Files
rest-server/cmd/rest-server/main_test.go
2025-02-17 22:25:25 +01:00

285 lines
7.7 KiB
Go

package main
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
restserver "github.com/restic/rest-server"
)
// TestTLSSettings checks how restServerApp.tlsSettings resolves the TLS
// key and certificate paths from the server configuration. The cases expect:
//   - TLS disabled and no key/cert given: TLS off, empty paths;
//   - TLS enabled without explicit paths: private_key/public_key inside the
//     data path (the app's default one, or the one set via Path);
//   - TLS enabled with explicit paths: those paths returned unchanged;
//   - a key and/or cert given while TLS is disabled: an error.
func TestTLSSettings(t *testing.T) {
	type expected struct {
		TLSKey  string
		TLSCert string
		Error   bool
	}
	type passed struct {
		Path    string
		TLS     bool
		TLSKey  string
		TLSCert string
	}
	var tests = []struct {
		passed   passed
		expected expected
	}{
		{passed{TLS: false}, expected{"", "", false}},
		{passed{TLS: true}, expected{
			filepath.Join(os.TempDir(), "restic/private_key"),
			filepath.Join(os.TempDir(), "restic/public_key"),
			false,
		}},
		{passed{
			Path: os.TempDir(),
			TLS:  true,
		}, expected{
			filepath.Join(os.TempDir(), "private_key"),
			filepath.Join(os.TempDir(), "public_key"),
			false,
		}},
		{passed{Path: os.TempDir(), TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}},
		{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
		{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}},
		{passed{Path: os.TempDir(), TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
	}
	for i, test := range tests {
		t.Run(fmt.Sprintf("case%d", i), func(t *testing.T) {
			// A fresh app per case, so no setting carries over between cases.
			app := newRestServerApp()
			if test.passed.Path != "" {
				app.Server.Path = test.passed.Path
			}
			app.Server.TLS = test.passed.TLS
			app.Server.TLSKey = test.passed.TLSKey
			app.Server.TLSCert = test.passed.TLSCert

			gotTLS, gotKey, gotCert, err := app.tlsSettings()
			if test.expected.Error {
				if err == nil {
					t.Fatalf("Error not returned properly (%+v)", test)
				}
				// The error was expected; the other return values are not checked.
				return
			}
			if err != nil {
				t.Fatalf("tls_settings returned err (%v)", err)
			}

			if gotTLS != test.passed.TLS {
				t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS)
			}
			if wantKey := test.expected.TLSKey; gotKey != wantKey {
				t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey)
			}
			if wantCert := test.expected.TLSCert; gotCert != wantCert {
				t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert)
			}
		})
	}
}
// TestGetHandler checks which authentication setups restserver.NewHandler
// accepts. Building a handler must fail when authentication is enabled but no
// .htpasswd exists in the data path. It must succeed when NoAuth is set, when
// a proxy-auth username header is configured, when a custom htpasswd path is
// given, and once a .htpasswd file exists in the data path.
func TestGetHandler(t *testing.T) {
	dir, err := os.MkdirTemp("", "rest-server-test")
	if err != nil {
		t.Fatal(err)
	}
	// os.Remove (rather than os.RemoveAll) fails on a non-empty directory, so
	// this also reports any file left behind. Deferred cleanups use t.Error:
	// the test body has already finished, so there is nothing left to stop.
	defer func() {
		if err := os.Remove(dir); err != nil {
			t.Error(err)
		}
	}()

	getHandler := restserver.NewHandler

	// With NoAuth = false and no .htpasswd
	_, err = getHandler(&restserver.Server{Path: dir})
	if err == nil {
		t.Errorf("NoAuth=false: expected error, got nil")
	}

	// With NoAuth = true and no .htpasswd
	_, err = getHandler(&restserver.Server{NoAuth: true, Path: dir})
	if err != nil {
		t.Errorf("NoAuth=true: expected no error, got %v", err)
	}

	// With NoAuth = false, no .htpasswd and ProxyAuth = X-Remote-User
	_, err = getHandler(&restserver.Server{Path: dir, ProxyAuthUsername: "X-Remote-User"})
	if err != nil {
		t.Errorf("NoAuth=false, ProxyAuthUsername = X-Remote-User: expected no error, got %v", err)
	}

	// With NoAuth = false and custom .htpasswd
	htpFile, err := os.CreateTemp(dir, "custom")
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := os.Remove(htpFile.Name()); err != nil {
			t.Error(err)
		}
	}()
	// Only the file name is needed. Close the handle now instead of leaking
	// it; on Windows, removing a file that is still open typically fails.
	if err := htpFile.Close(); err != nil {
		t.Fatal(err)
	}
	_, err = getHandler(&restserver.Server{HtpasswdPath: htpFile.Name()})
	if err != nil {
		t.Errorf("NoAuth=false with custom htpasswd: expected no error, got %v", err)
	}

	// Create .htpasswd
	htpasswd := filepath.Join(dir, ".htpasswd")
	err = os.WriteFile(htpasswd, []byte(""), 0644)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := os.Remove(htpasswd); err != nil {
			t.Error(err)
		}
	}()

	// With NoAuth = false and with .htpasswd
	_, err = getHandler(&restserver.Server{Path: dir})
	if err != nil {
		t.Errorf("NoAuth=false with .htpasswd: expected no error, got %v", err)
	}
}
// testServerWithArgs runs the app end to end for a test. It starts the app
// with the given command-line args, then calls cb, which can make requests
// against the running application.
//
// If cb returns an error wrapping a *url.Error (the type http.Client.Do
// returns for network failures, e.g. while the server is still starting), cb
// is retried every 100ms. Retrying stops when cb succeeds, when cb returns any
// other error, or when the context is done. The context is done once timeout
// elapses or the server command returns.
//
// When the client side finishes, the shared context is cancelled to stop the
// server. The returned error is non-nil if either the client callback or the
// server command failed.
func testServerWithArgs(args []string, timeout time.Duration, cb func(context.Context, *restServerApp) error) error {
	// create the app with passed args
	app := newRestServerApp()
	app.CmdRoot.SetArgs(args)

	// create context that will timeout
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// wait group for our client and server tasks
	jobs := &sync.WaitGroup{}
	jobs.Add(2)

	// run the server, saving the error
	var serverErr error
	go func() {
		defer jobs.Done()
		defer cancel() // if the server is stopped, no point keeping the client alive
		serverErr = app.CmdRoot.ExecuteContext(ctx)
	}()

	// run the client, saving the error
	var clientErr error
	go func() {
		defer jobs.Done()
		defer cancel() // once the client is done, stop the server
		var urlError *url.Error
		// execute in a loop, as we retry on network errors
		// (such as the server not having started yet)
		for {
			clientErr = cb(ctx, app)
			switch {
			case clientErr == nil:
				return // success, we're done
			case errors.As(clientErr, &urlError):
				// network error: wait and retry, the server may not be ready yet
				select {
				case <-time.After(time.Millisecond * 100):
					continue
				case <-ctx.Done(): // unless we run out of time first
					// Report why we gave up (DeadlineExceeded on timeout,
					// Canceled when the server stopped) and keep the last
					// network error, which usually explains the failure.
					clientErr = fmt.Errorf("%w (last client error: %v)", ctx.Err(), clientErr)
					return
				}
			default:
				return // other error type, we're done
			}
		}
	}()

	// wait for both to complete; Wait also makes the goroutines' writes to
	// clientErr and serverErr visible here
	jobs.Wait()

	// report back if either failed
	if clientErr != nil || serverErr != nil {
		return fmt.Errorf("client or server error, client: %v, server: %v", clientErr, serverErr)
	}
	return nil
}
// TestHttpListen starts the server with --no-auth on an ephemeral port and
// on a fixed port. Each run checks that GET requests for repository paths
// return the expected status codes. The fixed-port run is repeated to verify
// that the previous server shut down cleanly and released the port.
func TestHttpListen(t *testing.T) {
	td := t.TempDir()
	dataDir := filepath.Join(td, "data")

	// create some content and parent dirs
	if err := os.MkdirAll(filepath.Join(dataDir, "repo1"), 0700); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(filepath.Join(dataDir, "repo1", "config"), []byte("foo"), 0700); err != nil {
		t.Fatal(err)
	}

	for _, args := range [][]string{
		{"--no-auth", "--path", dataDir, "--listen", "127.0.0.1:0"},    // test ephemeral port
		{"--no-auth", "--path", dataDir, "--listen", "127.0.0.1:9000"}, // test "normal" port
		{"--no-auth", "--path", dataDir, "--listen", "127.0.0.1:9000"}, // test that server was shutdown cleanly and that we can re-use that port
	} {
		err := testServerWithArgs(args, time.Second*10, func(ctx context.Context, app *restServerApp) error {
			for _, test := range []struct {
				Path       string
				StatusCode int
			}{
				{"/repo1/", http.StatusMethodNotAllowed},
				{"/repo1/config", http.StatusOK},
				{"/repo2/config", http.StatusNotFound},
			} {
				listenAddr := app.ListenerAddress()
				if listenAddr == nil {
					return &url.Error{} // return this type of err, as we know this will retry
				}
				// Take everything after the last colon, so IPv6 listen
				// addresses such as "[::1]:8000" also yield the port.
				addr := listenAddr.String()
				port := addr[strings.LastIndexByte(addr, ':')+1:]

				req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:%s%s", port, test.Path), nil)
				if err != nil {
					return err
				}
				resp, err := http.DefaultClient.Do(req)
				if err != nil {
					return err
				}
				if err := resp.Body.Close(); err != nil {
					return err
				}
				if resp.StatusCode != test.StatusCode {
					return fmt.Errorf("expected %d from server, instead got %d (path %s)", test.StatusCode, resp.StatusCode, test.Path)
				}
			}
			return nil
		})
		if err != nil {
			t.Fatalf("args %q: %v", args, err)
		}
	}
}