From 1db2e397555e244d064a4ef2c0cd15588063d37d Mon Sep 17 00:00:00 2001 From: Aaron Bieber Date: Mon, 17 Jun 2024 15:42:22 -0600 Subject: [PATCH] split out code into different files - add tests --- handler_test.go | 280 ++++++++++++++++++++++++++++++++++++++++++++++++ handlers.go | 158 +++++++++++++++++++++++++++ main.go | 236 ++-------------------------------------- progress.go | 72 +++++++++++++ store_test.go | 33 ++++++ user.go | 47 ++++++++ user_test.go | 16 +++ 7 files changed, 612 insertions(+), 230 deletions(-) create mode 100644 handler_test.go create mode 100644 handlers.go create mode 100644 progress.go create mode 100644 store_test.go create mode 100644 user.go create mode 100644 user_test.go diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..f5099a9 --- /dev/null +++ b/handler_test.go @@ -0,0 +1,280 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +var ( + testDB *Store + err error + token string + user = "arst" + password = "arstarst" + document = "arstarstarstarst" + usrJson = []byte( + fmt.Sprintf(`{"username": "%s", "password": "%s"}`, user, password), + ) + progressJson = []byte( + fmt.Sprintf(`{"device": "snake", "progress": "30", "document": "%s", "percentage": 0.1, "device_id": "1234", "timestamp": 1711992660}`, document), + ) +) + +func TestMain(t *testing.M) { + os.RemoveAll("./test_db") + os.MkdirAll("./test_db", 0755) + + t.Run() + + os.RemoveAll("./test_db") +} + +func TestCreate(t *testing.T) { + req, err := http.NewRequest("POST", "/users/create", bytes.NewBuffer(usrJson)) + if err != nil { + t.Fatal(err) + } + + db, err := NewStore("./test_db") + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + allow := true + handler := http.HandlerFunc(makeCreate(&allow, db)) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusCreated { + t.Errorf("expected %v, got %v\n", http.StatusCreated, status) + } + + u := &User{Username: user, Password: password} + token, err = db.Get(u.Key()) + if err != nil { + t.Fatal(err) + } + if token == "" { + t.Fatalf("token is empty, %q", token) + } +} + +func TestCreateDup(t *testing.T) { + req, err := http.NewRequest("POST", "/users/create", bytes.NewBuffer(usrJson)) + if err != nil { + t.Fatal(err) + } + + db, err := NewStore("./test_db") + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + allow := true + handler := http.HandlerFunc(makeCreate(&allow, db)) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusPaymentRequired { + t.Errorf("expected %v, got %v\n", http.StatusPaymentRequired, status) + } +} + +func TestAuth(t *testing.T) { + req, err := http.NewRequest("GET", "/users/auth", nil) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("x-auth-user", user) + req.Header.Set("x-auth-key", token) + + db, err := NewStore("./test_db") + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(makeAuth(db)) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected %v, got %v\n", http.StatusOK, status) + } +} + +func TestAuthDenied(t *testing.T) { + req, err := http.NewRequest("GET", "/users/auth", nil) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("x-auth-user", user) + req.Header.Set("x-auth-key", fmt.Sprintf("bad_%s", token)) + + db, err := NewStore("./test_db") + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(makeAuth(db)) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusUnauthorized { + t.Errorf("expected %v, got %v\n", http.StatusUnauthorized, status) + } +} + +func TestProgress(t *testing.T) { + req, err := http.NewRequest("PUT", "/syncs/progress", bytes.NewBuffer(progressJson)) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("x-auth-user", user) + req.Header.Set("x-auth-key", token) + + db, err := NewStore("./test_db") + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(makeProgress(db)) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected %v, got %v\n", http.StatusOK, status) + } +} + +func TestProgressDenied(t *testing.T) { + req, err := http.NewRequest("PUT", "/syncs/progress", bytes.NewBuffer(progressJson)) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("x-auth-user", user) + req.Header.Set("x-auth-key", fmt.Sprintf("%s_bad", token)) + + db, err := NewStore("./test_db") + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(makeProgress(db)) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusUnauthorized { + t.Errorf("expected %v, got %v\n", http.StatusUnauthorized, status) + } +} + +func TestGetProgress(t *testing.T) { + docURL := fmt.Sprintf("/sync/progress/%s", document) + req, err := http.NewRequest("GET", docURL, nil) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("x-auth-user", user) + req.Header.Set("x-auth-key", token) + + db, err := NewStore("./test_db") + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(makeDocSync(db)) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected %v, got %v\n", http.StatusOK, status) + } + + prog := &Progress{} + dec := json.NewDecoder(rr.Body) + err = dec.Decode(prog) + if err != nil { + t.Fatal(err) + } +} + +func TestGetProgressDenied(t *testing.T) { + req, err := http.NewRequest("GET", fmt.Sprintf("/syncs/progress/%s", document), nil) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("x-auth-user", user) + req.Header.Set("x-auth-key", fmt.Sprintf("%s_bad", token)) + + db, err := NewStore("./test_db") + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(makeDocSync(db)) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusUnauthorized { + t.Errorf("expected %v, got %v\n", http.StatusUnauthorized, status) + } +} + +func TestGetInvalidDoc(t *testing.T) { + req, err := http.NewRequest("GET", fmt.Sprintf("/syncs/progress/%s_fake", document), nil) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("x-auth-user", user) + req.Header.Set("x-auth-key", token) + + db, err := NewStore("./test_db") + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(makeDocSync(db)) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusInternalServerError { + t.Errorf("expected %v, got %v\n", http.StatusInternalServerError, status) + } +} + +func TestHealthCheck(t *testing.T) { + req, err := http.NewRequest("GET", "/healthcheck", nil) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(healthHandler) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected %v, got %v\n", http.StatusOK, status) + } +} diff --git a/handlers.go b/handlers.go new file mode 100644 index 0000000..80dd251 --- /dev/null +++ b/handlers.go @@ -0,0 +1,158 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "strings" + "time" +) + +func httpLog(r *http.Request) { + n := time.Now() + fmt.Printf("%s (%s) [%s] \"%s %s\" %03d\n", + r.RemoteAddr, + n.Format(time.RFC822Z), + r.Method, + r.URL.Path, + r.Proto, + r.ContentLength, + ) +} + +func makeAuth(d *Store) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + httpLog(r) + _, err := authUserFromHeader(d, r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"message": "Unauthorized"}`)) + return + } + w.Header().Add("Content-type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"authorized": "OK"}`)) + } +} + +func makeProgress(d *Store) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + httpLog(r) + u, err := authUserFromHeader(d, r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"message": "Unauthorized"}`)) + return + } + prog := Progress{} + dec := json.NewDecoder(r.Body) + err = dec.Decode(&prog) + if err != nil { + log.Println(err) + http.Error(w, "invalid document", http.StatusNotFound) + return + } + + prog.User = *u + prog.Save(d) + + w.Header().Add("Content-type", "application/json") + w.WriteHeader(200) + w.Write([]byte(fmt.Sprintf(`{"document": "%s", "timestamp": "%d"}`, prog.Document, prog.Timestamp))) + } +} + +func makeCreate(reg *bool, d *Store) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + httpLog(r) + if !*reg { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"message": "Registration disabled"}`)) + return + } + u := User{} + + dec := json.NewDecoder(r.Body) + err := dec.Decode(&u) + if err != nil { + log.Println(err) + http.Error(w, "Internal Error", http.StatusInternalServerError) + return + } + + _, err = d.Get(u.Key()) + if err != nil { + d.Set(u.Key(), u.Password) + } else { + log.Println(err) + http.Error(w, "Username is already registered", http.StatusPaymentRequired) + return + } + + w.Header().Add("Content-type", "application/json") + w.WriteHeader(201) + w.Write(u.Created()) + } +} + +func makeDocSync(d *Store) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + httpLog(r) + + // TODO: I have no idea why this PathValue returns "".. dirty hack + // to grab it from the URL anyway :( + doc := r.PathValue("document") + if doc == "" { + parts := strings.Split(r.URL.String(), "/") + doc = parts[len(parts)-1] + if doc == "" { + http.Error(w, "Invalid Request", http.StatusBadRequest) + return + } + } + + u, err := authUserFromHeader(d, r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"message": "Unauthorized"}`)) + return + } + prog := Progress{ + Document: doc, + User: *u, + } + + err = prog.Get(d) + if err != nil { + log.Println(err) + http.Error(w, "Internal Error", http.StatusInternalServerError) + return + } + + b, err := json.Marshal(prog) + if err != nil { + log.Println(err) + http.Error(w, "Internal Error", http.StatusInternalServerError) + return + } + + w.Header().Add("Content-type", "application/json") + w.WriteHeader(200) + w.Write(b) + } +} + +func healthHandler(w http.ResponseWriter, r *http.Request) { + httpLog(r) + w.Header().Add("Content-type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"state": "OK"}`)) +} + +func slashHandler(w http.ResponseWriter, r *http.Request) { + httpLog(r) + w.Header().Add("Content-type", "text/plain") + w.WriteHeader(200) + w.Write([]byte(`kogs: koreader sync server`)) +} diff --git a/main.go b/main.go index 22f97a3..04b180c 100644 --- a/main.go +++ b/main.go @@ -1,133 +1,13 @@ package main import ( - "encoding/json" "flag" - "fmt" "log" "net" "net/http" "os" - "strconv" - "time" ) -type User struct { - Username string `json:"username"` - Password string - AuthKey string -} - -func (u *User) Key() string { - return fmt.Sprintf("user:%s:key", u.Username) -} - -func (u *User) Auth(authKey string) bool { - return u.AuthKey == authKey -} - -func (u *User) Created() []byte { - j, _ := json.Marshal(u) - return j -} - -func authUserFromHeader(d *Store, r *http.Request) (*User, error) { - un := r.Header.Get("x-auth-user") - uk := r.Header.Get("x-auth-key") - - u := &User{ - Username: un, - } - storedKey, err := d.Get(u.Key()) - if err != nil { - // No user - return nil, err - } - - u.AuthKey = string(storedKey) - if u.Auth(uk) { - return u, nil - } - - return nil, fmt.Errorf("Unauthorized") -} - -type Progress struct { - Device string `json:"device"` - Progress string `json:"progress"` - Document string `json:"document"` - Percentage float64 `json:"percentage"` - DeviceID string `json:"device_id"` - Timestamp int64 `json:"timestamp"` - User User -} - -func (p *Progress) DocKey() string { - return fmt.Sprintf("user:%s:document:%s", p.User.Username, p.Document) -} - -func (p *Progress) Save(d *Store) { - d.Set(p.DocKey()+"_percent", fmt.Sprintf("%f", p.Percentage)) - d.Set(p.DocKey()+"_progress", p.Progress) - d.Set(p.DocKey()+"_device", p.Device) - d.Set(p.DocKey()+"_device_id", p.DeviceID) - d.Set(p.DocKey()+"_timestamp", fmt.Sprintf("%d", (time.Now().Unix()))) -} - -func (p *Progress) Get(d *Store) error { - if p.Document == "" { - return fmt.Errorf("invalid document") - } - pct, err := d.Get(p.DocKey() + "_percent") - if err != nil { - return err - } - p.Percentage, _ = strconv.ParseFloat(string(pct), 64) - - prog, err := d.Get(p.DocKey() + "_progress") - if err != nil { - return err - } - p.Progress = string(prog) - - dev, err := d.Get(p.DocKey() + "_device") - if err != nil { - return err - } - p.Device = string(dev) - - devID, err := d.Get(p.DocKey() + "_device_id") - if err != nil { - return err - } - p.DeviceID = string(devID) - - ts, err := d.Get(p.DocKey() + "_timestamp") - if err != nil { - return err - } - stamp, err := strconv.ParseInt(string(ts), 10, 64) - if err != nil { - return err - } - - p.Timestamp = stamp - - return nil -} - -func httpLog(r *http.Request) { - n := time.Now() - fmt.Printf("%s (%s) [%s] \"%s %s\" %03d\n", - r.RemoteAddr, - n.Format(time.RFC822Z), - r.Method, - r.URL.Path, - r.Proto, - r.ContentLength, - ) -} - func main() { reg := flag.Bool("reg", true, "enable user registration") listen := flag.String("listen", ":8383", "interface and port to listen on") @@ -151,116 +31,12 @@ func main() { } mux := http.NewServeMux() - mux.HandleFunc("POST /users/create", func(w http.ResponseWriter, r *http.Request) { - httpLog(r) - if !*reg { - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte(`{"message": "Registration disabled"}`)) - return - } - u := User{} - - dec := json.NewDecoder(r.Body) - err := dec.Decode(&u) - if err != nil { - log.Println(err) - http.Error(w, "Internal Error", http.StatusInternalServerError) - return - } - - _, err = d.Get(u.Key()) - if err != nil { - d.Set(u.Key(), u.Password) - } else { - log.Println("user exists", err) - http.Error(w, "Username is already registered", http.StatusPaymentRequired) - return - } - - w.Header().Add("Content-type", "application/json") - w.WriteHeader(201) - w.Write(u.Created()) - }) - mux.HandleFunc("GET /users/auth", func(w http.ResponseWriter, r *http.Request) { - httpLog(r) - _, err := authUserFromHeader(d, r) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte(`{"message": "Unauthorized"}`)) - return - } - w.Header().Add("Content-type", "application/json") - w.WriteHeader(200) - w.Write([]byte(`{"authorized": "OK"}`)) - }) - - mux.HandleFunc("PUT /syncs/progress", func(w http.ResponseWriter, r *http.Request) { - httpLog(r) - u, err := authUserFromHeader(d, r) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte(`{"message": "Unauthorized"}`)) - return - } - prog := Progress{} - dec := json.NewDecoder(r.Body) - err = dec.Decode(&prog) - if err != nil { - log.Println(err) - http.Error(w, "Internal Error", http.StatusInternalServerError) - return - } - prog.User = *u - prog.Save(d) - - w.Header().Add("Content-type", "application/json") - w.WriteHeader(200) - w.Write([]byte(fmt.Sprintf(`{"document": "%s", "timestamp": "%d"}`, prog.Document, prog.Timestamp))) - }) - mux.HandleFunc("GET /syncs/progress/{document}", func(w http.ResponseWriter, r *http.Request) { - httpLog(r) - u, err := authUserFromHeader(d, r) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte(`{"message": "Unauthorized"}`)) - return - } - prog := Progress{ - Document: r.PathValue("document"), - User: *u, - } - err = prog.Get(d) - if err != nil { - log.Println(err) - http.Error(w, "Internal Error", http.StatusInternalServerError) - return - } - - b, err := json.Marshal(prog) - if err != nil { - log.Println(err) - http.Error(w, "Internal Error", http.StatusInternalServerError) - return - } - - w.Header().Add("Content-type", "application/json") - w.WriteHeader(200) - w.Write(b) - }) - - mux.HandleFunc("GET /healthcheck", func(w http.ResponseWriter, r *http.Request) { - httpLog(r) - w.Header().Add("Content-type", "application/json") - w.WriteHeader(200) - w.Write([]byte(`{"state": "OK"}`)) - }) - - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - httpLog(r) - w.Header().Add("Content-type", "text/plain") - w.WriteHeader(200) - w.Write([]byte(`kogs: koreader sync server`)) - }) + mux.HandleFunc("POST /users/create", makeCreate(reg, d)) + mux.HandleFunc("GET /users/auth", makeAuth(d)) + mux.HandleFunc("GET /syncs/progress/{document}", makeDocSync(d)) + mux.HandleFunc("PUT /syncs/progress", makeProgress(d)) + mux.HandleFunc("GET /healthcheck", healthHandler) + mux.HandleFunc("/", slashHandler) s := http.Server{ Handler: mux, diff --git a/progress.go b/progress.go new file mode 100644 index 0000000..b88be6b --- /dev/null +++ b/progress.go @@ -0,0 +1,72 @@ +package main + +import ( + "fmt" + "strconv" + "time" +) + +type Progress struct { + Device string `json:"device"` + Progress string `json:"progress"` + Document string `json:"document"` + Percentage float64 `json:"percentage"` + DeviceID string `json:"device_id"` + Timestamp int64 `json:"timestamp"` + User User +} + +func (p *Progress) DocKey() string { + return fmt.Sprintf("user:%s:document:%s", p.User.Username, p.Document) +} + +func (p *Progress) Save(d *Store) { + d.Set(p.DocKey()+"_percent", fmt.Sprintf("%f", p.Percentage)) + d.Set(p.DocKey()+"_progress", p.Progress) + d.Set(p.DocKey()+"_device", p.Device) + d.Set(p.DocKey()+"_device_id", p.DeviceID) + d.Set(p.DocKey()+"_timestamp", fmt.Sprintf("%d", (time.Now().Unix()))) +} + +func (p *Progress) Get(d *Store) error { + if p.Document == "" { + return fmt.Errorf("invalid document") + } + + pct, err := d.Get(p.DocKey() + "_percent") + if err != nil { + return err + } + p.Percentage, _ = strconv.ParseFloat(string(pct), 64) + + prog, err := d.Get(p.DocKey() + "_progress") + if err != nil { + return err + } + p.Progress = string(prog) + + dev, err := d.Get(p.DocKey() + "_device") + if err != nil { + return err + } + p.Device = string(dev) + + devID, err := d.Get(p.DocKey() + "_device_id") + if err != nil { + return err + } + p.DeviceID = string(devID) + + ts, err := d.Get(p.DocKey() + "_timestamp") + if err != nil { + return err + } + stamp, err := strconv.ParseInt(string(ts), 10, 64) + if err != nil { + return err + } + + p.Timestamp = stamp + + return nil +} diff --git a/store_test.go b/store_test.go new file mode 100644 index 0000000..2d0fc6c --- /dev/null +++ b/store_test.go @@ -0,0 +1,33 @@ +package main + +import "testing" + +var ( + db *Store + dir string + value = "a value" +) + +func TestStore(t *testing.T) { + dir = t.TempDir() + db, err = NewStore(dir) + if err != nil { + t.Fatal(err) + } + + db.Set("somekey", value) + + val, err := db.Get("somekey") + if err != nil { + t.Fatal(err) + } + + if val != value { + t.Errorf("expected %q, got %q\n", value, val) + } + + val, err = db.Get("fakekey") + if err == nil { + t.Errorf("expected %q, got %q\n", "error", val) + } +} diff --git a/user.go b/user.go new file mode 100644 index 0000000..93f3206 --- /dev/null +++ b/user.go @@ -0,0 +1,47 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/http" +) + +type User struct { + Username string `json:"username"` + Password string + AuthKey string +} + +func (u *User) Key() string { + return fmt.Sprintf("user:%s:key", u.Username) +} + +func (u *User) Auth(authKey string) bool { + return u.AuthKey == authKey +} + +func (u *User) Created() []byte { + j, _ := json.Marshal(u) + return j +} + +func authUserFromHeader(d *Store, r *http.Request) (*User, error) { + un := r.Header.Get("x-auth-user") + uk := r.Header.Get("x-auth-key") + + u := &User{ + Username: un, + } + storedKey, err := d.Get(u.Key()) + if err != nil { + // No user + return nil, err + } + + u.AuthKey = string(storedKey) + if u.Auth(uk) { + return u, nil + } + + return nil, fmt.Errorf("Unauthorized") +} diff --git a/user_test.go b/user_test.go new file mode 100644 index 0000000..c5bb6f2 --- /dev/null +++ b/user_test.go @@ -0,0 +1,16 @@ +package main + +import "testing" + +func TestUserKey(t *testing.T) { + u := &User{ + Username: "arst", + } + + key := u.Key() + exp := "user:arst:key" + + if key != exp { + t.Errorf("expected %q, got %q\n", exp, key) + } +}