From 3ddc711cc19939ba7176e37bb2488cfe61e5dc1b Mon Sep 17 00:00:00 2001 From: Ysbrand van Eijck Date: Wed, 8 Dec 2021 13:42:48 +0100 Subject: [PATCH] Adds header authentication --- main.go | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/main.go b/main.go index 8727c7e..6317b3f 100644 --- a/main.go +++ b/main.go @@ -58,7 +58,7 @@ type userHandlers struct { } var ( - auth bool + auth string davDir string fullListen string genHtpass bool @@ -85,7 +85,7 @@ func init() { flag.StringVar(&tlsCert, "tlscert", "", "TLS certificate.") flag.StringVar(&tlsKey, "tlskey", "", "TLS key.") flag.StringVar(&passPath, "htpass", fmt.Sprintf("%s/.htpasswd", dir), "Path to .htpasswd file..") - flag.BoolVar(&auth, "auth", true, "Enable HTTP Basic Authentication.") + flag.StringVar(&auth, "auth", "basic", "Enable HTTP Basic Authentication.") flag.BoolVar(&genHtpass, "gen", false, "Generate a .htpasswd file or add a new entry to an existing file.") flag.Parse() @@ -100,7 +100,6 @@ func init() { if err != nil { log.Fatalln(err) } - } func authenticate(user string, pass string) bool { @@ -200,7 +199,7 @@ func main() { _, fErr := os.Stat(passPath) if os.IsNotExist(fErr) { - if auth { + if auth == "basic" || auth == "header" { fmt.Println("No .htpasswd file found!") os.Exit(1) } @@ -226,7 +225,7 @@ func main() { } } - if auth { + if auth == "basic" || auth == "header" { for u := range users { uPath := path.Join(davDir, u) handlers[u] = userHandlers{ @@ -263,8 +262,23 @@ func main() { return } - if auth { + if auth == "basic" { user, pass, ok = r.BasicAuth() + if !(ok && authenticate(user, pass)) { + w.Header().Set("WWW-Authenticate", `Basic realm="widdler"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + } else if auth == "header" { + var prefix = "Auth" + for name, values := range r.Header { + if strings.HasPrefix(name, prefix) { + user = strings.TrimLeft(name, prefix) + pass = values[0] + ok = true + } + } + if !(ok && authenticate(user, pass)) { w.Header().Set("WWW-Authenticate", `Basic realm="widdler"`) http.Error(w, "Unauthorized", http.StatusUnauthorized) @@ -353,5 +367,4 @@ func main() { log.Printf("Listening for HTTP on 'http://%s'", listen) log.Fatalln(s.Serve(lis)) } - }