mirror of
https://github.com/restic/rest-server.git
synced 2026-04-06 18:21:50 -07:00
285 lines
7.7 KiB
Go
285 lines
7.7 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
restserver "github.com/restic/rest-server"
|
|
)
|
|
|
|
func TestTLSSettings(t *testing.T) {
|
|
type expected struct {
|
|
TLSKey string
|
|
TLSCert string
|
|
Error bool
|
|
}
|
|
type passed struct {
|
|
Path string
|
|
TLS bool
|
|
TLSKey string
|
|
TLSCert string
|
|
}
|
|
|
|
var tests = []struct {
|
|
passed passed
|
|
expected expected
|
|
}{
|
|
{passed{TLS: false}, expected{"", "", false}},
|
|
{passed{TLS: true}, expected{
|
|
filepath.Join(os.TempDir(), "restic/private_key"),
|
|
filepath.Join(os.TempDir(), "restic/public_key"),
|
|
false,
|
|
}},
|
|
{passed{
|
|
Path: os.TempDir(),
|
|
TLS: true,
|
|
}, expected{
|
|
filepath.Join(os.TempDir(), "private_key"),
|
|
filepath.Join(os.TempDir(), "public_key"),
|
|
false,
|
|
}},
|
|
{passed{Path: os.TempDir(), TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}},
|
|
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
|
|
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}},
|
|
{passed{Path: os.TempDir(), TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
app := newRestServerApp()
|
|
t.Run("", func(t *testing.T) {
|
|
// defer func() { restserver.Server = defaultConfig }()
|
|
if test.passed.Path != "" {
|
|
app.Server.Path = test.passed.Path
|
|
}
|
|
app.Server.TLS = test.passed.TLS
|
|
app.Server.TLSKey = test.passed.TLSKey
|
|
app.Server.TLSCert = test.passed.TLSCert
|
|
|
|
gotTLS, gotKey, gotCert, err := app.tlsSettings()
|
|
if err != nil && !test.expected.Error {
|
|
t.Fatalf("tls_settings returned err (%v)", err)
|
|
}
|
|
if test.expected.Error {
|
|
if err == nil {
|
|
t.Fatalf("Error not returned properly (%v)", test)
|
|
} else {
|
|
return
|
|
}
|
|
}
|
|
if gotTLS != test.passed.TLS {
|
|
t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS)
|
|
}
|
|
wantKey := test.expected.TLSKey
|
|
if gotKey != wantKey {
|
|
t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey)
|
|
}
|
|
|
|
wantCert := test.expected.TLSCert
|
|
if gotCert != wantCert {
|
|
t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert)
|
|
}
|
|
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetHandler(t *testing.T) {
|
|
dir, err := os.MkdirTemp("", "rest-server-test")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer func() {
|
|
err := os.Remove(dir)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}()
|
|
|
|
getHandler := restserver.NewHandler
|
|
|
|
// With NoAuth = false and no .htpasswd
|
|
_, err = getHandler(&restserver.Server{Path: dir})
|
|
if err == nil {
|
|
t.Errorf("NoAuth=false: expected error, got nil")
|
|
}
|
|
|
|
// With NoAuth = true and no .htpasswd
|
|
_, err = getHandler(&restserver.Server{NoAuth: true, Path: dir})
|
|
if err != nil {
|
|
t.Errorf("NoAuth=true: expected no error, got %v", err)
|
|
}
|
|
|
|
// With NoAuth = false, no .htpasswd and ProxyAuth = X-Remote-User
|
|
_, err = getHandler(&restserver.Server{Path: dir, ProxyAuthUsername: "X-Remote-User"})
|
|
if err != nil {
|
|
t.Errorf("NoAuth=false, ProxyAuthUsername = X-Remote-User: expected no error, got %v", err)
|
|
}
|
|
|
|
// With NoAuth = false and custom .htpasswd
|
|
htpFile, err := os.CreateTemp(dir, "custom")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer func() {
|
|
err := os.Remove(htpFile.Name())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}()
|
|
_, err = getHandler(&restserver.Server{HtpasswdPath: htpFile.Name()})
|
|
if err != nil {
|
|
t.Errorf("NoAuth=false with custom htpasswd: expected no error, got %v", err)
|
|
}
|
|
|
|
// Create .htpasswd
|
|
htpasswd := filepath.Join(dir, ".htpasswd")
|
|
err = os.WriteFile(htpasswd, []byte(""), 0644)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer func() {
|
|
err := os.Remove(htpasswd)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}()
|
|
|
|
// With NoAuth = false and with .htpasswd
|
|
_, err = getHandler(&restserver.Server{Path: dir})
|
|
if err != nil {
|
|
t.Errorf("NoAuth=false with .htpasswd: expected no error, got %v", err)
|
|
}
|
|
}
|
|
|
|
// helper method to test the app. Starts app with passed arguments,
|
|
// then will call the callback function which can make requests against
|
|
// the application. If the callback function fails due to errors returned
|
|
// by http.Do() (i.e. *url.Error), then it will be retried until successful,
|
|
// or the passed timeout passes.
|
|
func testServerWithArgs(args []string, timeout time.Duration, cb func(context.Context, *restServerApp) error) error {
|
|
// create the app with passed args
|
|
app := newRestServerApp()
|
|
app.CmdRoot.SetArgs(args)
|
|
|
|
// create context that will timeout
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
|
|
// wait group for our client and server tasks
|
|
jobs := &sync.WaitGroup{}
|
|
jobs.Add(2)
|
|
|
|
// run the server, saving the error
|
|
var serverErr error
|
|
go func() {
|
|
defer jobs.Done()
|
|
defer cancel() // if the server is stopped, no point keep the client alive
|
|
serverErr = app.CmdRoot.ExecuteContext(ctx)
|
|
}()
|
|
|
|
// run the client, saving the error
|
|
var clientErr error
|
|
go func() {
|
|
defer jobs.Done()
|
|
defer cancel() // once the client is done, stop the server
|
|
|
|
var urlError *url.Error
|
|
|
|
// execute in loop, as we will retry for network errors
|
|
// (such as the server hasn't started yet)
|
|
for {
|
|
clientErr = cb(ctx, app)
|
|
switch {
|
|
case clientErr == nil:
|
|
return // success, we're done
|
|
case errors.As(clientErr, &urlError):
|
|
// if a network error (url.Error), then wait and retry
|
|
// as server may not be ready yet
|
|
select {
|
|
case <-time.After(time.Millisecond * 100):
|
|
continue
|
|
case <-ctx.Done(): // unless we run out of time first
|
|
clientErr = context.Canceled
|
|
return
|
|
}
|
|
default:
|
|
return // other error type, we're done
|
|
}
|
|
}
|
|
}()
|
|
|
|
// wait for both to complete
|
|
jobs.Wait()
|
|
|
|
// report back if either failed
|
|
if clientErr != nil || serverErr != nil {
|
|
return fmt.Errorf("client or server error, client: %v, server: %v", clientErr, serverErr)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func TestHttpListen(t *testing.T) {
|
|
td := t.TempDir()
|
|
|
|
// create some content and parent dirs
|
|
if err := os.MkdirAll(filepath.Join(td, "data", "repo1"), 0700); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(filepath.Join(td, "data", "repo1", "config"), []byte("foo"), 0700); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
for _, args := range [][]string{
|
|
{"--no-auth", "--path", filepath.Join(td, "data"), "--listen", "127.0.0.1:0"}, // test emphemeral port
|
|
{"--no-auth", "--path", filepath.Join(td, "data"), "--listen", "127.0.0.1:9000"}, // test "normal" port
|
|
{"--no-auth", "--path", filepath.Join(td, "data"), "--listen", "127.0.0.1:9000"}, // test that server was shutdown cleanly and that we can re-use that port
|
|
} {
|
|
err := testServerWithArgs(args, time.Second*10, func(ctx context.Context, app *restServerApp) error {
|
|
for _, test := range []struct {
|
|
Path string
|
|
StatusCode int
|
|
}{
|
|
{"/repo1/", http.StatusMethodNotAllowed},
|
|
{"/repo1/config", http.StatusOK},
|
|
{"/repo2/config", http.StatusNotFound},
|
|
} {
|
|
listenAddr := app.ListenerAddress()
|
|
if listenAddr == nil {
|
|
return &url.Error{} // return this type of err, as we know this will retry
|
|
}
|
|
port := strings.Split(listenAddr.String(), ":")[1]
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:%s%s", port, test.Path), nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = resp.Body.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if resp.StatusCode != test.StatusCode {
|
|
return fmt.Errorf("expected %d from server, instead got %d (path %s)", test.StatusCode, resp.StatusCode, test.Path)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
}
|