Merge pull request #138 from MichaelEischer/cache-basic-auth

Cache basic auth credentials
This commit is contained in:
MichaelEischer
2022-07-02 21:19:46 +02:00
committed by GitHub
3 changed files with 143 additions and 10 deletions

View File

@@ -0,0 +1,9 @@
Enhancement: Cache basic auth credentials
To speed up the verification of basic auth credentials, rest-server now caches
passwords for a minute in memory. That way the expensive verification of basic
auth credentials can be skipped for most requests issued by a single restic
run. The password is kept in memory in a hashed form and not as plaintext.
https://github.com/restic/rest-server/issues/133
https://github.com/restic/rest-server/pull/138

View File

@@ -26,6 +26,8 @@ THE SOFTWARE.
import (
"crypto/sha1"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/csv"
"log"
@@ -42,15 +44,26 @@ import (
// CheckInterval represents how often we check for changes in htpasswd file.
const CheckInterval = 30 * time.Second
// PasswordCacheDuration represents how long authentication credentials are
// cached in memory after they were successfully verified. This allows avoiding
// repeatedly verifying the same authentication credentials.
const PasswordCacheDuration = time.Minute
// Lookup passwords in a htpasswd file. The entries must have been created with -s for SHA encryption.
type cacheEntry struct {
expiry time.Time
verifier []byte
}
// HtpasswdFile is a map for usernames to passwords.
type HtpasswdFile struct {
mutex sync.Mutex
path string
stat os.FileInfo
throttle chan struct{}
Users map[string]string
users map[string]string
cache map[string]cacheEntry
}
// NewHtpasswdFromFile reads the users and passwords from a htpasswd file and returns them. If an error is encountered,
@@ -68,6 +81,7 @@ func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
path: path,
stat: stat,
throttle: make(chan struct{}),
cache: make(map[string]cacheEntry),
}
if err := h.Reload(); err != nil {
@@ -76,6 +90,7 @@ func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
// Start a goroutine that limits reload checks to once per CheckInterval
go h.throttleTimer()
go h.expiryTimer()
go func() {
for range c {
@@ -100,6 +115,23 @@ func (h *HtpasswdFile) throttleTimer() {
}
}
func (h *HtpasswdFile) expiryTimer() {
for {
time.Sleep(5 * time.Second)
now := time.Now()
h.mutex.Lock()
var zeros [sha256.Size]byte
// try to wipe expired cache entries
for user, entry := range h.cache {
if entry.expiry.After(now) {
copy(entry.verifier, zeros[:])
delete(h.cache, user)
}
}
h.mutex.Unlock()
}
}
var validUsernameRegexp = regexp.MustCompile(`^[\p{L}\d@._-]+$`)
// Reload reloads the htpasswd file. If the reload fails, the Users map is not changed and the error is returned.
@@ -130,7 +162,14 @@ func (h *HtpasswdFile) Reload() error {
// Replace the Users map
h.mutex.Lock()
h.Users = users
var zeros [sha256.Size]byte
// try to wipe the old cache entries
for _, entry := range h.cache {
copy(entry.verifier, zeros[:])
}
h.cache = make(map[string]cacheEntry)
h.users = users
h.mutex.Unlock()
_ = r.Close()
@@ -172,35 +211,74 @@ func (h *HtpasswdFile) ReloadCheck() error {
return nil
}
var shaRe = regexp.MustCompile(`^{SHA}`)
var bcrRe = regexp.MustCompile(`^\$2b\$|^\$2a\$|^\$2y\$`)
// Validate returns true if password matches the stored password for user. If no password for user is stored, or the
// password is wrong, false is returned.
func (h *HtpasswdFile) Validate(user string, password string) bool {
_ = h.ReloadCheck()
hash := sha256.New()
// hash.Write can never fail
_, _ = hash.Write([]byte(user))
_, _ = hash.Write([]byte(":"))
_, _ = hash.Write([]byte(password))
h.mutex.Lock()
realPassword, exists := h.Users[user]
// avoid race conditions with cache replacements
cache := h.cache
hashedPassword, exists := h.users[user]
entry, cacheExists := h.cache[user]
h.mutex.Unlock()
if !exists {
return false
}
var shaRe = regexp.MustCompile(`^{SHA}`)
var bcrRe = regexp.MustCompile(`^\$2b\$|^\$2a\$|^\$2y\$`)
if cacheExists && subtle.ConstantTimeCompare(entry.verifier, hash.Sum(nil)) == 1 {
h.mutex.Lock()
// repurpose mutex to prevent concurrent cache updates
// extend cache entry
cache[user] = cacheEntry{
verifier: entry.verifier,
expiry: time.Now().Add(PasswordCacheDuration),
}
h.mutex.Unlock()
return true
}
isValid := isMatchingHashAndPassword(hashedPassword, password)
if !isValid {
log.Printf("Invalid htpasswd entry for %s.", user)
return false
}
h.mutex.Lock()
// repurpose mutex to prevent concurrent cache updates
cache[user] = cacheEntry{
verifier: hash.Sum(nil),
expiry: time.Now().Add(PasswordCacheDuration),
}
h.mutex.Unlock()
return true
}
func isMatchingHashAndPassword(hashedPassword string, password string) bool {
switch {
case shaRe.MatchString(realPassword):
case shaRe.MatchString(hashedPassword):
d := sha1.New()
_, _ = d.Write([]byte(password))
if realPassword[5:] == base64.StdEncoding.EncodeToString(d.Sum(nil)) {
if subtle.ConstantTimeCompare([]byte(hashedPassword[5:]), []byte(base64.StdEncoding.EncodeToString(d.Sum(nil)))) == 1 {
return true
}
case bcrRe.MatchString(realPassword):
err := bcrypt.CompareHashAndPassword([]byte(realPassword), []byte(password))
case bcrRe.MatchString(hashedPassword):
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
if err == nil {
return true
}
}
log.Printf("Invalid htpasswd entry for %s.", user)
return false
}

46
htpasswd_test.go Normal file
View File

@@ -0,0 +1,46 @@
package restserver
import (
"io/ioutil"
"os"
"testing"
)
func TestValidate(t *testing.T) {
user := "restic"
pwd := "$2y$05$z/OEmNQamd6m6LSegUErh.r/Owk9Xwmc5lxDheIuHY2Z7XiS6FtJm"
rawPwd := "test"
wrongPwd := "wrong"
tmpfile, err := ioutil.TempFile("", "rest-validate-")
if err != nil {
t.Fatal(err)
}
if _, err = tmpfile.Write([]byte(user + ":" + pwd + "\n")); err != nil {
t.Fatal(err)
}
if err = tmpfile.Close(); err != nil {
t.Fatal(err)
}
htpass, err := NewHtpasswdFromFile(tmpfile.Name())
if err != nil {
t.Fatal(err)
}
for i := 0; i < 10; i++ {
isValid := htpass.Validate(user, rawPwd)
if !isValid {
t.Fatal("correct password not accepted")
}
isValid = htpass.Validate(user, wrongPwd)
if isValid {
t.Fatal("wrong password accepted")
}
}
if err = os.Remove(tmpfile.Name()); err != nil {
t.Fatal(err)
}
}