Store our handlers in a slice and lock them when they are being used.

Hopefully this helps with #3 (at least the HTTP 423 errors).
This commit is contained in:
Aaron Bieber 2021-12-09 05:16:00 -07:00
parent 0205c58868
commit 4fdac0116f

62
main.go
View File

@ -15,6 +15,7 @@ import (
"path/filepath" "path/filepath"
"regexp" "regexp"
"strings" "strings"
"sync"
"text/template" "text/template"
"time" "time"
@ -52,9 +53,22 @@ var (
templ *template.Template templ *template.Template
) )
type userHandlers struct { type userHandler struct {
dav *webdav.Handler mu sync.Mutex
fs http.Handler dav *webdav.Handler
fs http.Handler
name string
}
type userHandlers []userHandler
func (u userHandlers) find(name string) *userHandler {
for _, usr := range u {
if usr.name == name {
return &usr
}
}
return nil
} }
var ( var (
@ -62,7 +76,7 @@ var (
davDir string davDir string
fullListen string fullListen string
genHtpass bool genHtpass bool
handlers map[string]userHandlers handlers userHandlers
listen string listen string
passPath string passPath string
tlsCert string tlsCert string
@ -74,7 +88,6 @@ var pledges = "stdio wpath rpath cpath tty inet dns unveil"
func init() { func init() {
users = make(map[string]string) users = make(map[string]string)
handlers = make(map[string]userHandlers)
dir, err := filepath.Abs(filepath.Dir(os.Args[0])) dir, err := filepath.Abs(filepath.Dir(os.Args[0]))
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
@ -153,7 +166,10 @@ func prompt(prompt string, secure bool) (string, error) {
} }
input = string(b) input = string(b)
} else { } else {
fmt.Scanln(&input) _, err := fmt.Scanln(&input)
if err != nil {
return "", err
}
} }
return input, nil return input, nil
} }
@ -182,12 +198,15 @@ func main() {
log.Fatalln(err) log.Fatalln(err)
} }
defer f.Close()
if _, err := f.WriteString(fmt.Sprintf("%s:%s\n", user, hash)); err != nil { if _, err := f.WriteString(fmt.Sprintf("%s:%s\n", user, hash)); err != nil {
log.Fatalln(err) log.Fatalln(err)
} }
err = f.Close()
if err != nil {
log.Fatalln(err)
}
fmt.Printf("Added %q to %q\n", user, passPath) fmt.Printf("Added %q to %q\n", user, passPath)
os.Exit(0) os.Exit(0)
@ -209,13 +228,17 @@ func main() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer p.Close()
ht := csv.NewReader(p) ht := csv.NewReader(p)
ht.Comma = ':' ht.Comma = ':'
ht.Comment = '#' ht.Comment = '#'
ht.TrimLeadingSpace = true ht.TrimLeadingSpace = true
err = p.Close()
if err != nil {
log.Fatal(err)
}
entries, err := ht.ReadAll() entries, err := ht.ReadAll()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -229,22 +252,24 @@ func main() {
if auth { if auth {
for u := range users { for u := range users {
uPath := path.Join(davDir, u) uPath := path.Join(davDir, u)
handlers[u] = userHandlers{ handlers = append(handlers, userHandler{
name: u,
dav: &webdav.Handler{ dav: &webdav.Handler{
LockSystem: webdav.NewMemLS(), LockSystem: webdav.NewMemLS(),
FileSystem: webdav.Dir(uPath), FileSystem: webdav.Dir(uPath),
}, },
fs: http.FileServer(http.Dir(uPath)), fs: http.FileServer(http.Dir(uPath)),
} })
} }
} else { } else {
handlers[""] = userHandlers{ handlers = append(handlers, userHandler{
name: "",
dav: &webdav.Handler{ dav: &webdav.Handler{
LockSystem: webdav.NewMemLS(), LockSystem: webdav.NewMemLS(),
FileSystem: webdav.Dir(davDir), FileSystem: webdav.Dir(davDir),
}, },
fs: http.FileServer(http.Dir(davDir)), fs: http.FileServer(http.Dir(davDir)),
} })
} }
mux := http.NewServeMux() mux := http.NewServeMux()
@ -272,7 +297,16 @@ func main() {
} }
} }
handler := handlers[user] handler := handlers.find(user)
if handler == nil {
http.NotFound(w, r)
return
}
handler.mu.Lock()
defer handler.mu.Unlock()
userPath := path.Join(davDir, user) userPath := path.Join(davDir, user)
fullPath := path.Join(davDir, user, r.URL.Path) fullPath := path.Join(davDir, user, r.URL.Path)