mirror of
https://github.com/restic/rest-server.git
synced 2025-12-07 09:36:13 -08:00
@@ -1,15 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
restserver "github.com/restic/rest-server"
|
||||
"github.com/spf13/cobra"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
|
||||
restserver "github.com/restic/rest-server"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// cmdRoot is the base command when no other command has been specified.
|
||||
@@ -29,12 +29,35 @@ func init() {
|
||||
flags.StringVar(&restserver.Config.Log, "log", restserver.Config.Log, "log HTTP requests in the combined log format")
|
||||
flags.StringVar(&restserver.Config.Path, "path", restserver.Config.Path, "data directory")
|
||||
flags.BoolVar(&restserver.Config.TLS, "tls", restserver.Config.TLS, "turn on TLS support")
|
||||
flags.StringVar(&restserver.Config.TLSCert, "tls-cert", restserver.Config.TLSCert, "TLS certificate path")
|
||||
flags.StringVar(&restserver.Config.TLSKey, "tls-key", restserver.Config.TLSKey, "TLS key path")
|
||||
flags.BoolVar(&restserver.Config.AppendOnly, "append-only", restserver.Config.AppendOnly, "enable append only mode")
|
||||
flags.BoolVar(&restserver.Config.Prometheus, "prometheus", restserver.Config.Prometheus, "enable Prometheus metrics")
|
||||
}
|
||||
|
||||
var version = "manually"
|
||||
|
||||
func tlsSettings() (bool, string, string, error) {
|
||||
var key, cert string
|
||||
enabledTLS := restserver.Config.TLS
|
||||
if !enabledTLS && (restserver.Config.TLSKey != "" || restserver.Config.TLSCert != "") {
|
||||
return false, "", "", errors.New("requires enabled TLS")
|
||||
} else if !enabledTLS {
|
||||
return false, "", "", nil
|
||||
}
|
||||
if restserver.Config.TLSKey != "" {
|
||||
key = restserver.Config.TLSKey
|
||||
} else {
|
||||
key = filepath.Join(restserver.Config.Path, "private_key")
|
||||
}
|
||||
if restserver.Config.TLSCert != "" {
|
||||
cert = restserver.Config.TLSCert
|
||||
} else {
|
||||
cert = filepath.Join(restserver.Config.Path, "public_key")
|
||||
}
|
||||
return enabledTLS, key, cert, nil
|
||||
}
|
||||
|
||||
func runRoot(cmd *cobra.Command, args []string) error {
|
||||
log.SetFlags(0)
|
||||
|
||||
@@ -65,22 +88,24 @@ func runRoot(cmd *cobra.Command, args []string) error {
|
||||
log.Println("Authentication enabled")
|
||||
}
|
||||
|
||||
if !restserver.Config.TLS {
|
||||
enabledTLS, privateKey, publicKey, err := tlsSettings()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !enabledTLS {
|
||||
log.Printf("Starting server on %s\n", restserver.Config.Listen)
|
||||
err = http.ListenAndServe(restserver.Config.Listen, handler)
|
||||
} else {
|
||||
privateKey := filepath.Join(restserver.Config.Path, "private_key")
|
||||
publicKey := filepath.Join(restserver.Config.Path, "public_key")
|
||||
|
||||
log.Println("TLS enabled")
|
||||
log.Printf("Private key: %s", privateKey)
|
||||
log.Printf("Public key: %s", publicKey)
|
||||
log.Printf("Public key(certificate): %s", publicKey)
|
||||
log.Printf("Starting server on %s\n", restserver.Config.Listen)
|
||||
err = http.ListenAndServeTLS(restserver.Config.Listen, publicKey, privateKey, handler)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := cmdRoot.Execute(); err != nil {
|
||||
log.Fatalf("error: %v", err)
|
||||
|
||||
72
cmd/rest-server/main_test.go
Normal file
72
cmd/rest-server/main_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
restserver "github.com/restic/rest-server"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTLSSettings(t *testing.T) {
|
||||
type expected struct {
|
||||
TLSKey string
|
||||
TLSCert string
|
||||
Error bool
|
||||
}
|
||||
type passed struct {
|
||||
Path string
|
||||
TLS bool
|
||||
TLSKey string
|
||||
TLSCert string
|
||||
}
|
||||
|
||||
var tests = []struct {
|
||||
passed passed
|
||||
expected expected
|
||||
}{
|
||||
{passed{TLS: false}, expected{"", "", false}},
|
||||
{passed{TLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", false}},
|
||||
{passed{Path: "/tmp", TLS: true}, expected{"/tmp/private_key", "/tmp/public_key", false}},
|
||||
{passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}},
|
||||
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
|
||||
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}},
|
||||
{passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
|
||||
}
|
||||
|
||||
defaultConfig := restserver.Config
|
||||
for _, test := range tests {
|
||||
|
||||
t.Run("", func(t *testing.T) {
|
||||
defer func() { restserver.Config = defaultConfig }()
|
||||
if test.passed.Path != "" {
|
||||
restserver.Config.Path = test.passed.Path
|
||||
}
|
||||
restserver.Config.TLS = test.passed.TLS
|
||||
restserver.Config.TLSKey = test.passed.TLSKey
|
||||
restserver.Config.TLSCert = test.passed.TLSCert
|
||||
|
||||
gotTLS, gotKey, gotCert, err := tlsSettings()
|
||||
if err != nil && !test.expected.Error {
|
||||
t.Fatalf("tls_settings returned err (%v)", err)
|
||||
}
|
||||
if test.expected.Error {
|
||||
if err == nil {
|
||||
t.Fatalf("Error not returned properly (%v)", test)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
if gotTLS != test.passed.TLS {
|
||||
t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS)
|
||||
}
|
||||
wantKey := test.expected.TLSKey
|
||||
if gotKey != wantKey {
|
||||
t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey)
|
||||
}
|
||||
|
||||
wantCert := test.expected.TLSCert
|
||||
if gotCert != wantCert {
|
||||
t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user