fix: shutdown gracefully on TERM or INT signals

This allows server listen resources to be cleaned up appropriately.
This commit is contained in:
Adam Eijdenberg
2024-02-05 10:24:54 +11:00
parent 3ce6aaf2b6
commit 8becd574cb
3 changed files with 260 additions and 88 deletions

View File

@@ -0,0 +1,7 @@
Change: Server is now shutdown cleanly on TERM or INT signals
Server now listens for TERM and INT signals and cleanly closes down the http.Server and listener.
This is particularly useful when listening on a unix socket, as the server will remove the socket file from it shuts down.
https://github.com/restic/rest-server/pull/273

View File

@@ -1,159 +1,196 @@
package main
import (
"context"
"errors"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/pprof"
"sync"
"syscall"
restserver "github.com/restic/rest-server"
"github.com/spf13/cobra"
)
type restServerApp struct {
CmdRoot *cobra.Command
Server restserver.Server
CpuProfile string
listenerAddressMu sync.Mutex
listenerAddress net.Addr // set after startup
}
// cmdRoot is the base command when no other command has been specified.
var cmdRoot = &cobra.Command{
Use: "rest-server",
Short: "Run a REST server for use with restic",
SilenceErrors: true,
SilenceUsage: true,
RunE: runRoot,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) != 0 {
return fmt.Errorf("rest-server expects no arguments - unknown argument: %s", args[0])
}
return nil
},
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
}
func newRestServerApp() *restServerApp {
rv := &restServerApp{
CmdRoot: &cobra.Command{
Use: "rest-server",
Short: "Run a REST server for use with restic",
SilenceErrors: true,
SilenceUsage: true,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) != 0 {
return fmt.Errorf("rest-server expects no arguments - unknown argument: %s", args[0])
}
return nil
},
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
},
Server: restserver.Server{
Path: filepath.Join(os.TempDir(), "restic"),
Listen: ":8000",
},
}
rv.CmdRoot.RunE = rv.runRoot
flags := rv.CmdRoot.Flags()
var server = restserver.Server{
Path: filepath.Join(os.TempDir(), "restic"),
Listen: ":8000",
}
var (
cpuProfile string
)
func init() {
flags := cmdRoot.Flags()
flags.StringVar(&cpuProfile, "cpu-profile", cpuProfile, "write CPU profile to file")
flags.BoolVar(&server.Debug, "debug", server.Debug, "output debug messages")
flags.StringVar(&server.Listen, "listen", server.Listen, "listen address")
flags.StringVar(&server.Log, "log", server.Log, "write HTTP requests in the combined log format to the specified `filename` (use \"-\" for logging to stdout)")
flags.Int64Var(&server.MaxRepoSize, "max-size", server.MaxRepoSize, "the maximum size of the repository in bytes")
flags.StringVar(&server.Path, "path", server.Path, "data directory")
flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support")
flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path")
flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path")
flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication")
flags.StringVar(&server.HtpasswdPath, "htpasswd-file", server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload,
flags.StringVar(&rv.CpuProfile, "cpu-profile", rv.CpuProfile, "write CPU profile to file")
flags.BoolVar(&rv.Server.Debug, "debug", rv.Server.Debug, "output debug messages")
flags.StringVar(&rv.Server.Listen, "listen", rv.Server.Listen, "listen address")
flags.StringVar(&rv.Server.Log, "log", rv.Server.Log, "write HTTP requests in the combined log format to the specified `filename` (use \"-\" for logging to stdout)")
flags.Int64Var(&rv.Server.MaxRepoSize, "max-size", rv.Server.MaxRepoSize, "the maximum size of the repository in bytes")
flags.StringVar(&rv.Server.Path, "path", rv.Server.Path, "data directory")
flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support")
flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path")
flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path")
flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable .htpasswd authentication")
flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
flags.BoolVar(&rv.Server.NoVerifyUpload, "no-verify-upload", rv.Server.NoVerifyUpload,
"do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device")
flags.BoolVar(&server.AppendOnly, "append-only", server.AppendOnly, "enable append only mode")
flags.BoolVar(&server.PrivateRepos, "private-repos", server.PrivateRepos, "users can only access their private repo")
flags.BoolVar(&server.Prometheus, "prometheus", server.Prometheus, "enable Prometheus metrics")
flags.BoolVar(&server.PrometheusNoAuth, "prometheus-no-auth", server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint")
flags.BoolVar(&rv.Server.AppendOnly, "append-only", rv.Server.AppendOnly, "enable append only mode")
flags.BoolVar(&rv.Server.PrivateRepos, "private-repos", rv.Server.PrivateRepos, "users can only access their private repo")
flags.BoolVar(&rv.Server.Prometheus, "prometheus", rv.Server.Prometheus, "enable Prometheus metrics")
flags.BoolVar(&rv.Server.PrometheusNoAuth, "prometheus-no-auth", rv.Server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint")
return rv
}
var version = "0.12.1-dev"
func tlsSettings() (bool, string, string, error) {
func (app *restServerApp) tlsSettings() (bool, string, string, error) {
var key, cert string
if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") {
if !app.Server.TLS && (app.Server.TLSKey != "" || app.Server.TLSCert != "") {
return false, "", "", errors.New("requires enabled TLS")
} else if !server.TLS {
} else if !app.Server.TLS {
return false, "", "", nil
}
if server.TLSKey != "" {
key = server.TLSKey
if app.Server.TLSKey != "" {
key = app.Server.TLSKey
} else {
key = filepath.Join(server.Path, "private_key")
key = filepath.Join(app.Server.Path, "private_key")
}
if server.TLSCert != "" {
cert = server.TLSCert
if app.Server.TLSCert != "" {
cert = app.Server.TLSCert
} else {
cert = filepath.Join(server.Path, "public_key")
cert = filepath.Join(app.Server.Path, "public_key")
}
return server.TLS, key, cert, nil
return app.Server.TLS, key, cert, nil
}
func runRoot(cmd *cobra.Command, args []string) error {
// returns the address that the app is listening on.
// returns nil if the application hasn't finished starting yet
func (app *restServerApp) ListenerAddress() net.Addr {
app.listenerAddressMu.Lock()
defer app.listenerAddressMu.Unlock()
return app.listenerAddress
}
func (app *restServerApp) runRoot(cmd *cobra.Command, args []string) error {
log.SetFlags(0)
log.Printf("Data directory: %s", server.Path)
log.Printf("Data directory: %s", app.Server.Path)
if cpuProfile != "" {
f, err := os.Create(cpuProfile)
if app.CpuProfile != "" {
f, err := os.Create(app.CpuProfile)
if err != nil {
return err
}
defer f.Close()
if err := pprof.StartCPUProfile(f); err != nil {
return err
}
log.Println("CPU profiling enabled")
defer pprof.StopCPUProfile()
// clean profiling shutdown on sigint
sigintCh := make(chan os.Signal, 1)
go func() {
for range sigintCh {
pprof.StopCPUProfile()
log.Println("Stopped CPU profiling")
err := f.Close()
if err != nil {
log.Printf("error closing CPU profile file: %v", err)
}
os.Exit(130)
}
}()
signal.Notify(sigintCh, syscall.SIGINT)
log.Println("CPU profiling enabled")
defer log.Println("Stopped CPU profiling")
}
if server.NoAuth {
if app.Server.NoAuth {
log.Println("Authentication disabled")
} else {
log.Println("Authentication enabled")
}
handler, err := restserver.NewHandler(&server)
handler, err := restserver.NewHandler(&app.Server)
if err != nil {
log.Fatalf("error: %v", err)
}
if server.PrivateRepos {
if app.Server.PrivateRepos {
log.Println("Private repositories enabled")
} else {
log.Println("Private repositories disabled")
}
enabledTLS, privateKey, publicKey, err := tlsSettings()
enabledTLS, privateKey, publicKey, err := app.tlsSettings()
if err != nil {
return err
}
listener, err := findListener(server.Listen)
listener, err := findListener(app.Server.Listen)
if err != nil {
return fmt.Errorf("unable to listen: %w", err)
}
if !enabledTLS {
err = http.Serve(listener, handler)
} else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = http.ServeTLS(listener, handler, publicKey, privateKey)
// set listener address, this is useful for tests
app.listenerAddressMu.Lock()
app.listenerAddress = listener.Addr()
app.listenerAddressMu.Unlock()
srv := &http.Server{
Handler: handler,
}
return err
// run server in background
go func() {
if !enabledTLS {
err = srv.Serve(listener)
} else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = srv.ServeTLS(listener, publicKey, privateKey)
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("listen and serve returned err: %v", err)
}
}()
// wait until done
<-app.CmdRoot.Context().Done()
// gracefully shutdown server
if err := srv.Shutdown(context.Background()); err != nil {
return fmt.Errorf("server shutdown returned an err: %w", err)
}
log.Println("shutdown cleanly")
return nil
}
func main() {
if err := cmdRoot.Execute(); err != nil {
// create context to be notified on interrupt or term signal so that we can shutdown cleanly
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
if err := newRestServerApp().CmdRoot.ExecuteContext(ctx); err != nil {
log.Fatalf("error: %v", err)
}
}

View File

@@ -1,10 +1,18 @@
package main
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
restserver "github.com/restic/rest-server"
)
@@ -47,17 +55,17 @@ func TestTLSSettings(t *testing.T) {
}
for _, test := range tests {
app := newRestServerApp()
t.Run("", func(t *testing.T) {
// defer func() { restserver.Server = defaultConfig }()
if test.passed.Path != "" {
server.Path = test.passed.Path
app.Server.Path = test.passed.Path
}
server.TLS = test.passed.TLS
server.TLSKey = test.passed.TLSKey
server.TLSCert = test.passed.TLSCert
app.Server.TLS = test.passed.TLS
app.Server.TLSKey = test.passed.TLSKey
app.Server.TLSCert = test.passed.TLSCert
gotTLS, gotKey, gotCert, err := tlsSettings()
gotTLS, gotKey, gotCert, err := app.tlsSettings()
if err != nil && !test.expected.Error {
t.Fatalf("tls_settings returned err (%v)", err)
}
@@ -146,3 +154,123 @@ func TestGetHandler(t *testing.T) {
t.Errorf("NoAuth=false with .htpasswd: expected no error, got %v", err)
}
}
// helper method to test the app. Starts app with passed arguments,
// then will call the callback function which can make requests against
// the application. If the callback function fails due to errors returned
// by http.Do() (i.e. *url.Error), then it will be retried until successful,
// or the passed timeout passes.
func testServerWithArgs(args []string, timeout time.Duration, cb func(context.Context, *restServerApp) error) error {
// create the app with passed args
app := newRestServerApp()
app.CmdRoot.SetArgs(args)
// create context that will timeout
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// wait group for our client and server tasks
jobs := &sync.WaitGroup{}
jobs.Add(2)
// run the server, saving the error
var serverErr error
go func() {
defer jobs.Done()
defer cancel() // if the server is stopped, no point keep the client alive
serverErr = app.CmdRoot.ExecuteContext(ctx)
}()
// run the client, saving the error
var clientErr error
go func() {
defer jobs.Done()
defer cancel() // once the client is done, stop the server
var urlError *url.Error
// execute in loop, as we will retry for network errors
// (such as the server hasn't started yet)
for {
clientErr = cb(ctx, app)
switch {
case clientErr == nil:
return // success, we're done
case errors.As(clientErr, &urlError):
// if a network error (url.Error), then wait and retry
// as server may not be ready yet
select {
case <-time.After(time.Millisecond * 100):
continue
case <-ctx.Done(): // unless we run out of time first
clientErr = context.Canceled
return
}
default:
return // other error type, we're done
}
}
}()
// wait for both to complete
jobs.Wait()
// report back if either failed
if clientErr != nil || serverErr != nil {
return fmt.Errorf("client or server error, client: %v, server: %v", clientErr, serverErr)
}
return nil
}
func TestHttpListen(t *testing.T) {
td := t.TempDir()
// create some content and parent dirs
if err := os.MkdirAll(filepath.Join(td, "data", "repo1"), 0700); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(td, "data", "repo1", "config"), []byte("foo"), 0700); err != nil {
t.Fatal(err)
}
for _, args := range [][]string{
{"--no-auth", "--path", filepath.Join(td, "data"), "--listen", ":0"}, // test emphemeral port
{"--no-auth", "--path", filepath.Join(td, "data"), "--listen", ":9000"}, // test "normal" port
{"--no-auth", "--path", filepath.Join(td, "data"), "--listen", ":9000"}, // test that server was shutdown cleanly and that we can re-use that port
} {
err := testServerWithArgs(args, time.Second, func(ctx context.Context, app *restServerApp) error {
for _, test := range []struct {
Path string
StatusCode int
}{
{"/repo1/", http.StatusMethodNotAllowed},
{"/repo1/config", http.StatusOK},
{"/repo2/config", http.StatusNotFound},
} {
listenAddr := app.ListenerAddress()
if listenAddr == nil {
return &url.Error{} // return this type of err, as we know this will retry
}
port := strings.Split(listenAddr.String(), ":")[1]
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:%s%s", port, test.Path), nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
resp.Body.Close()
if resp.StatusCode != test.StatusCode {
return fmt.Errorf("expected %d from server, instead got %d (path %s)", test.StatusCode, resp.StatusCode, test.Path)
}
}
return nil
})
if err != nil {
t.Fatal(err)
}
}
}