From 545e4f38e0c177484ffa409c2fa1265423a5855f Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 13 Sep 2023 12:02:38 -0400 Subject: [PATCH] net/http: ServeMux handles extended patterns Modify ServeMux to handle patterns with methods and wildcards. Remove the map and list of patterns. Instead patterns are registered and matched using a routing tree. We also reorganize the code around "trailing-slash redirection," the feature whereby a trailing slash is added to a path if it doesn't match an existing one. The existing code checked the map of paths twice, but searching the tree twice would be needlessly expensive. The rewrite searches the tree once, and then again only if a trailing-slash redirection is possible. There are a few omitted features in this CL, indicated with TODOs. Change-Id: Ifaef59f6c8c7b7131dc4a5d0f101cc22887bdc74 Reviewed-on: https://go-review.googlesource.com/c/go/+/528039 Run-TryBot: Jonathan Amsterdam TryBot-Result: Gopher Robot Reviewed-by: Damien Neil --- src/net/http/server.go | 331 ++++++++++++++++++++---------------- src/net/http/server_test.go | 86 +++++++++- 2 files changed, 267 insertions(+), 150 deletions(-) diff --git a/src/net/http/server.go b/src/net/http/server.go index 0d75b87765..26df238495 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -23,7 +23,6 @@ import ( urlpkg "net/url" "path" "runtime" - "sort" "strconv" "strings" "sync" @@ -2281,6 +2280,9 @@ func RedirectHandler(url string, code int) Handler { return &redirectHandler{url, code} } +// TODO(jba): rewrite the following doc for enhanced patterns (proposal +// https://go.dev/issue/61410). + // ServeMux is an HTTP request multiplexer. // It matches the URL of each incoming request against a list of registered // patterns and calls the handler for the pattern that @@ -2317,19 +2319,15 @@ func RedirectHandler(url string, code int) Handler { // header, stripping the port number and redirecting any request containing . or // .. elements or repeated slashes to an equivalent, cleaner URL. type ServeMux struct { - mu sync.RWMutex - m map[string]muxEntry - es []muxEntry // slice of entries sorted from longest to shortest. - hosts bool // whether any patterns contain hostnames -} - -type muxEntry struct { - h Handler - pattern string + mu sync.RWMutex + tree routingNode + patterns []*pattern } // NewServeMux allocates and returns a new ServeMux. -func NewServeMux() *ServeMux { return new(ServeMux) } +func NewServeMux() *ServeMux { + return &ServeMux{} +} // DefaultServeMux is the default ServeMux used by Serve. var DefaultServeMux = &defaultServeMux @@ -2371,66 +2369,6 @@ func stripHostPort(h string) string { return host } -// Find a handler on a handler map given a path string. -// Most-specific (longest) pattern wins. -func (mux *ServeMux) match(path string) (h Handler, pattern string) { - // Check for exact match first. - v, ok := mux.m[path] - if ok { - return v.h, v.pattern - } - - // Check for longest valid match. mux.es contains all patterns - // that end in / sorted from longest to shortest. - for _, e := range mux.es { - if strings.HasPrefix(path, e.pattern) { - return e.h, e.pattern - } - } - return nil, "" -} - -// redirectToPathSlash determines if the given path needs appending "/" to it. -// This occurs when a handler for path + "/" was already registered, but -// not for path itself. If the path needs appending to, it creates a new -// URL, setting the path to u.Path + "/" and returning true to indicate so. -func (mux *ServeMux) redirectToPathSlash(host, path string, u *url.URL) (*url.URL, bool) { - mux.mu.RLock() - shouldRedirect := mux.shouldRedirectRLocked(host, path) - mux.mu.RUnlock() - if !shouldRedirect { - return u, false - } - path = path + "/" - u = &url.URL{Path: path, RawQuery: u.RawQuery} - return u, true -} - -// shouldRedirectRLocked reports whether the given path and host should be redirected to -// path+"/". This should happen if a handler is registered for path+"/" but -// not path -- see comments at ServeMux. -func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool { - p := []string{path, host + path} - - for _, c := range p { - if _, exist := mux.m[c]; exist { - return false - } - } - - n := len(path) - if n == 0 { - return false - } - for _, c := range p { - if _, exist := mux.m[c+"/"]; exist { - return path[n-1] != '/' - } - } - - return false -} - // Handler returns the handler to use for the given request, // consulting r.Method, r.Host, and r.URL.Path. It always returns // a non-nil handler. If the path is not in its canonical form, the @@ -2442,61 +2380,144 @@ func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool { // // Handler also returns the registered pattern that matches the // request or, in the case of internally-generated redirects, -// the pattern that will match after following the redirect. +// the path that will match after following the redirect. // // If there is no registered handler that applies to the request, // Handler returns a “page not found” handler and an empty pattern. func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { + return mux.findHandler(r) +} + +// findHandler finds a handler for a request. +// If there is a matching handler, it returns it and the pattern that matched. +// Otherwise it returns a Redirect or NotFound handler with the path that would match +// after the redirect. +func (mux *ServeMux) findHandler(r *Request) (h Handler, patStr string) { + var n *routingNode + // TODO(jba): use escaped path. This is an independent change that is also part + // of proposal https://go.dev/issue/61410. + path := r.URL.Path // CONNECT requests are not canonicalized. if r.Method == "CONNECT" { // If r.URL.Path is /tree and its handler is not registered, // the /tree -> /tree/ redirect applies to CONNECT requests // but the path canonicalization does not. - if u, ok := mux.redirectToPathSlash(r.URL.Host, r.URL.Path, r.URL); ok { + _, _, u := mux.handler(r.URL.Host, r.Method, path, r.URL) + if u != nil { return RedirectHandler(u.String(), StatusMovedPermanently), u.Path } + // Redo the match, this time with r.Host instead of r.URL.Host. + // Pass a nil URL to skip the trailing-slash redirect logic. + n, _, _ = mux.handler(r.Host, r.Method, path, nil) + } else { + // All other requests have any port stripped and path cleaned + // before passing to mux.handler. + host := stripHostPort(r.Host) + path = cleanPath(path) - return mux.handler(r.Host, r.URL.Path) + // If the given path is /tree and its handler is not registered, + // redirect for /tree/. + var u *url.URL + n, _, u = mux.handler(host, r.Method, path, r.URL) + if u != nil { + return RedirectHandler(u.String(), StatusMovedPermanently), u.Path + } + if path != r.URL.Path { + // Redirect to cleaned path. + patStr := "" + if n != nil { + patStr = n.pattern.String() + } + u := &url.URL{Path: path, RawQuery: r.URL.RawQuery} + return RedirectHandler(u.String(), StatusMovedPermanently), patStr + } } - - // All other requests have any port stripped and path cleaned - // before passing to mux.handler. - host := stripHostPort(r.Host) - path := cleanPath(r.URL.Path) - - // If the given path is /tree and its handler is not registered, - // redirect for /tree/. - if u, ok := mux.redirectToPathSlash(host, path, r.URL); ok { - return RedirectHandler(u.String(), StatusMovedPermanently), u.Path + if n == nil { + // TODO(jba): support 405 (MethodNotAllowed) by checking for patterns with different methods. + return NotFoundHandler(), "" } - - if path != r.URL.Path { - _, pattern = mux.handler(host, path) - u := &url.URL{Path: path, RawQuery: r.URL.RawQuery} - return RedirectHandler(u.String(), StatusMovedPermanently), pattern - } - - return mux.handler(host, r.URL.Path) + return n.handler, n.pattern.String() } -// handler is the main implementation of Handler. +// handler looks up a node in the tree that matches the host, method and path. // The path is known to be in canonical form, except for CONNECT methods. -func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) { + +// If the url argument is non-nil, handler also deals with trailing-slash +// redirection: when a path doesn't match exactly, the match is tried again + +// after appending "/" to the path. If that second match succeeds, the last +// return value is the URL to redirect to. +// +// TODO(jba): give this a better name. For now we're keeping the name of the closest +// corresponding function in the original code. +func (mux *ServeMux) handler(host, method, path string, u *url.URL) (_ *routingNode, matches []string, redirectTo *url.URL) { mux.mu.RLock() defer mux.mu.RUnlock() - // Host-specific pattern takes precedence over generic ones - if mux.hosts { - h, pattern = mux.match(host + path) + n, matches := mux.tree.match(host, method, path) + // If we have an exact match, or we were asked not to try trailing-slash redirection, + // then we're done. + if !exactMatch(n, path) && u != nil { + // If there is an exact match with a trailing slash, then redirect. + path += "/" + n2, _ := mux.tree.match(host, method, path) + if exactMatch(n2, path) { + return nil, nil, &url.URL{Path: path, RawQuery: u.RawQuery} + } } - if h == nil { - h, pattern = mux.match(path) + return n, matches, nil +} + +// exactMatch reports whether the node's pattern exactly matches the path. +// As a special case, if the node is nil, exactMatch return false. +// +// Before wildcards were introduced, it was clear that an exact match meant +// that the pattern and path were the same string. The only other possibility +// was that a trailing-slash pattern, like "/", matched a path longer than +// it, like "/a". +// +// With wildcards, we define an inexact match as any one where a multi wildcard +// matches a non-empty string. All other matches are exact. +// For example, these are all exact matches: +// +// pattern path +// /a /a +// /{x} /a +// /a/{$} /a/ +// /a/ /a/ +// +// The last case has a multi wildcard (implicitly), but the match is exact because +// the wildcard matches the empty string. +// +// Examples of matches that are not exact: +// +// pattern path +// / /a +// /a/{x...} /a/b +func exactMatch(n *routingNode, path string) bool { + if n == nil { + return false } - if h == nil { - h, pattern = NotFoundHandler(), "" + // We can't directly implement the definition (empty match for multi + // wildcard) because we don't record a match for anonymous multis. + + // If there is no multi, the match is exact. + if !n.pattern.lastSegment().multi { + return true } - return + + // If the path doesn't end in a trailing slash, then the multi match + // is non-empty. + if len(path) > 0 && path[len(path)-1] != '/' { + return false + } + // Only patterns ending in {$} or a multi wildcard can + // match a path with a trailing slash. + // For the match to be exact, the number of pattern + // segments should be the same as the number of slashes in the path. + // E.g. "/a/b/{$}" and "/a/b/{...}" exactly match "/a/b/", but "/a/" does not. + return len(n.pattern.segments) == strings.Count(path, "/") } // ServeHTTP dispatches the request to the handler whose @@ -2509,71 +2530,83 @@ func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { w.WriteHeader(StatusBadRequest) return } - h, _ := mux.Handler(r) + h, _ := mux.findHandler(r) + // TODO(jba); save matches in Request. h.ServeHTTP(w, r) } +// The four functions below all call register so that callerLocation +// always refers to user code. + // Handle registers the handler for the given pattern. // If a handler already exists for pattern, Handle panics. func (mux *ServeMux) Handle(pattern string, handler Handler) { - mux.mu.Lock() - defer mux.mu.Unlock() - - if pattern == "" { - panic("http: invalid pattern") - } - if handler == nil { - panic("http: nil handler") - } - if _, exist := mux.m[pattern]; exist { - panic("http: multiple registrations for " + pattern) - } - - if mux.m == nil { - mux.m = make(map[string]muxEntry) - } - e := muxEntry{h: handler, pattern: pattern} - mux.m[pattern] = e - if pattern[len(pattern)-1] == '/' { - mux.es = appendSorted(mux.es, e) - } - - if pattern[0] != '/' { - mux.hosts = true - } -} - -func appendSorted(es []muxEntry, e muxEntry) []muxEntry { - n := len(es) - i := sort.Search(n, func(i int) bool { - return len(es[i].pattern) < len(e.pattern) - }) - if i == n { - return append(es, e) - } - // we now know that i points at where we want to insert - es = append(es, muxEntry{}) // try to grow the slice in place, any entry works. - copy(es[i+1:], es[i:]) // Move shorter entries down - es[i] = e - return es + mux.register(pattern, handler) } // HandleFunc registers the handler function for the given pattern. func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { - if handler == nil { - panic("http: nil handler") - } - mux.Handle(pattern, HandlerFunc(handler)) + mux.register(pattern, HandlerFunc(handler)) } // Handle registers the handler for the given pattern in [DefaultServeMux]. // The documentation for [ServeMux] explains how patterns are matched. -func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } +func Handle(pattern string, handler Handler) { + DefaultServeMux.register(pattern, handler) +} // HandleFunc registers the handler function for the given pattern in [DefaultServeMux]. // The documentation for [ServeMux] explains how patterns are matched. func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { - DefaultServeMux.HandleFunc(pattern, handler) + DefaultServeMux.register(pattern, HandlerFunc(handler)) +} + +func (mux *ServeMux) register(pattern string, handler Handler) { + if err := mux.registerErr(pattern, handler); err != nil { + panic(err) + } +} + +func (mux *ServeMux) registerErr(pattern string, handler Handler) error { + if pattern == "" { + return errors.New("http: invalid pattern") + } + if handler == nil { + return errors.New("http: nil handler") + } + if f, ok := handler.(HandlerFunc); ok && f == nil { + return errors.New("http: nil handler") + } + + pat, err := parsePattern(pattern) + if err != nil { + return err + } + + // Get the caller's location, for better conflict error messages. + // Skip register and whatever calls it. + _, file, line, ok := runtime.Caller(3) + if !ok { + pat.loc = "unknown location" + } else { + pat.loc = fmt.Sprintf("%s:%d", file, line) + } + + mux.mu.Lock() + defer mux.mu.Unlock() + // Check for conflict. + // This makes a quadratic number of calls to conflictsWith: we check + // each pattern against every other pattern. + // TODO(jba): add indexing to speed this up. + for _, pat2 := range mux.patterns { + if pat.conflictsWith(pat2) { + return fmt.Errorf("pattern %q (registered at %s) conflicts with pattern %q (registered at %s)", + pat, pat.loc, pat2, pat2.loc) + } + } + mux.tree.addPattern(pat, handler) + mux.patterns = append(mux.patterns, pat) + return nil } // Serve accepts incoming HTTP connections on the listener l, diff --git a/src/net/http/server_test.go b/src/net/http/server_test.go index d17c5c1e7e..0e7bdb2f37 100644 --- a/src/net/http/server_test.go +++ b/src/net/http/server_test.go @@ -8,6 +8,7 @@ package http import ( "fmt" + "net/url" "testing" "time" ) @@ -64,6 +65,85 @@ func TestServerTLSHandshakeTimeout(t *testing.T) { } } +type handler struct{ i int } + +func (handler) ServeHTTP(ResponseWriter, *Request) {} + +func TestFindHandler(t *testing.T) { + mux := NewServeMux() + for _, ph := range []struct { + pat string + h Handler + }{ + {"/", &handler{1}}, + {"/foo/", &handler{2}}, + {"/foo", &handler{3}}, + {"/bar/", &handler{4}}, + {"//foo", &handler{5}}, + } { + mux.Handle(ph.pat, ph.h) + } + + for _, test := range []struct { + method string + path string + wantHandler string + }{ + {"GET", "/", "&http.handler{i:1}"}, + {"GET", "//", `&http.redirectHandler{url:"/", code:301}`}, + {"GET", "/foo/../bar/./..//baz", `&http.redirectHandler{url:"/baz", code:301}`}, + {"GET", "/foo", "&http.handler{i:3}"}, + {"GET", "/foo/x", "&http.handler{i:2}"}, + {"GET", "/bar/x", "&http.handler{i:4}"}, + {"GET", "/bar", `&http.redirectHandler{url:"/bar/", code:301}`}, + {"CONNECT", "/", "&http.handler{i:1}"}, + {"CONNECT", "//", "&http.handler{i:1}"}, + {"CONNECT", "//foo", "&http.handler{i:5}"}, + {"CONNECT", "/foo/../bar/./..//baz", "&http.handler{i:2}"}, + {"CONNECT", "/foo", "&http.handler{i:3}"}, + {"CONNECT", "/foo/x", "&http.handler{i:2}"}, + {"CONNECT", "/bar/x", "&http.handler{i:4}"}, + {"CONNECT", "/bar", `&http.redirectHandler{url:"/bar/", code:301}`}, + } { + var r Request + r.Method = test.method + r.Host = "example.com" + r.URL = &url.URL{Path: test.path} + gotH, _ := mux.findHandler(&r) + got := fmt.Sprintf("%#v", gotH) + if got != test.wantHandler { + t.Errorf("%s %q: got %q, want %q", test.method, test.path, got, test.wantHandler) + } + } +} + +func TestExactMatch(t *testing.T) { + for _, test := range []struct { + pattern string + path string + want bool + }{ + {"", "/a", false}, + {"/", "/a", false}, + {"/a", "/a", true}, + {"/a/{x...}", "/a/b", false}, + {"/a/{x}", "/a/b", true}, + {"/a/b/", "/a/b/", true}, + {"/a/b/{$}", "/a/b/", true}, + {"/a/", "/a/b/", false}, + } { + var n *routingNode + if test.pattern != "" { + pat := mustParsePattern(t, test.pattern) + n = &routingNode{pattern: pat} + } + got := exactMatch(n, test.path) + if got != test.want { + t.Errorf("%q, %s: got %t, want %t", test.pattern, test.path, got, test.want) + } + } +} + func BenchmarkServerMatch(b *testing.B) { fn := func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "OK") @@ -90,7 +170,11 @@ func BenchmarkServerMatch(b *testing.B) { "/products/", "/products/3/image.jpg"} b.StartTimer() for i := 0; i < b.N; i++ { - if h, p := mux.match(paths[i%len(paths)]); h != nil && p == "" { + r, err := NewRequest("GET", "http://example.com/"+paths[i%len(paths)], nil) + if err != nil { + b.Fatal(err) + } + if h, p := mux.findHandler(r); h != nil && p == "" { b.Error("impossible") } }