diff --git a/src/pkg/http/serve_test.go b/src/pkg/http/serve_test.go index 2bb423b15f7..c172c6c42d9 100644 --- a/src/pkg/http/serve_test.go +++ b/src/pkg/http/serve_test.go @@ -14,6 +14,7 @@ import ( "io/ioutil" "os" "net" + "strings" "testing" "time" ) @@ -349,3 +350,85 @@ func TestServerTimeouts(t *testing.T) { l.Close() } + +// TestIdentityResponse verifies that a handler can unset +func TestIdentityResponse(t *testing.T) { + l, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("failed to listen on a port: %v", err) + } + defer l.Close() + urlBase := "http://" + l.Addr().String() + "/" + + handler := HandlerFunc(func(rw ResponseWriter, req *Request) { + rw.SetHeader("Content-Length", "3") + rw.SetHeader("Transfer-Encoding", req.FormValue("te")) + switch { + case req.FormValue("overwrite") == "1": + _, err := rw.Write([]byte("foo TOO LONG")) + if err != ErrContentLength { + t.Errorf("expected ErrContentLength; got %v", err) + } + case req.FormValue("underwrite") == "1": + rw.SetHeader("Content-Length", "500") + rw.Write([]byte("too short")) + default: + rw.Write([]byte("foo")) + } + }) + + server := &Server{Handler: handler} + go server.Serve(l) + + // Note: this relies on the assumption (which is true) that + // Get sends HTTP/1.1 or greater requests. Otherwise the + // server wouldn't have the choice to send back chunked + // responses. + for _, te := range []string{"", "identity"} { + url := urlBase + "?te=" + te + res, _, err := Get(url) + if err != nil { + t.Fatalf("error with Get of %s: %v", url, err) + } + if cl, expected := res.ContentLength, int64(3); cl != expected { + t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl) + } + if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected { + t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl) + } + if tl, expected := len(res.TransferEncoding), 0; tl != expected { + t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)", + url, expected, tl, res.TransferEncoding) + } + } + + // Verify that ErrContentLength is returned + url := urlBase + "?overwrite=1" + _, _, err = Get(url) + if err != nil { + t.Fatalf("error with Get of %s: %v", url, err) + } + + // Verify that the connection is closed when the declared Content-Length + // is larger than what the handler wrote. + conn, err := net.Dial("tcp", "", l.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n")) + if err != nil { + t.Fatalf("error writing: %v", err) + } + // The next ReadAll will hang for a failing test, so use a Timer instead + // to fail more traditionally + timer := time.AfterFunc(2e9, func() { + t.Fatalf("Timeout expired in ReadAll.") + }) + defer timer.Stop() + got, _ := ioutil.ReadAll(conn) + expectedSuffix := "\r\n\r\ntoo short" + if !strings.HasSuffix(string(got), expectedSuffix) { + t.Fatalf("Expected output to end with %q; got response body %q", + expectedSuffix, string(got)) + } +} diff --git a/src/pkg/http/server.go b/src/pkg/http/server.go index d16cadb3b5d..977c8c2297a 100644 --- a/src/pkg/http/server.go +++ b/src/pkg/http/server.go @@ -31,6 +31,7 @@ var ( ErrWriteAfterFlush = os.NewError("Conn.Write called after Flush") ErrBodyNotAllowed = os.NewError("http: response status code does not allow body") ErrHijacked = os.NewError("Conn has been hijacked") + ErrContentLength = os.NewError("Conn.Write wrote more than the declared Content-Length") ) // Objects implementing the Handler interface can be @@ -60,10 +61,10 @@ type ResponseWriter interface { // // Content-Type: text/html; charset=utf-8 // - // being sent. UTF-8 encoded HTML is the default setting for + // being sent. UTF-8 encoded HTML is the default setting for // Content-Type in this library, so users need not make that - // particular call. Calls to SetHeader after WriteHeader (or Write) - // are ignored. + // particular call. Calls to SetHeader after WriteHeader (or Write) + // are ignored. An empty value removes the header if previously set. SetHeader(string, string) // Write writes the data to the connection as part of an HTTP reply. @@ -108,6 +109,7 @@ type response struct { wroteContinue bool // 100 Continue response was written header map[string]string // reply header parameters written int64 // number of bytes written in body + contentLength int64 // explicitly-declared Content-Length; or -1 status int // status code passed to WriteHeader // close connection after this reply. set on request and @@ -170,33 +172,13 @@ func (c *conn) readRequest() (w *response, err os.Error) { w.conn = c w.req = req w.header = make(map[string]string) + w.contentLength = -1 // Expect 100 Continue support if req.expectsContinue() && req.ProtoAtLeast(1, 1) { // Wrap the Body reader with one that replies on the connection req.Body = &expectContinueReader{readCloser: req.Body, resp: w} } - - // Default output is HTML encoded in UTF-8. - w.SetHeader("Content-Type", "text/html; charset=utf-8") - w.SetHeader("Date", time.UTC().Format(TimeFormat)) - - if req.Method == "HEAD" { - // do nothing - } else if req.ProtoAtLeast(1, 1) { - // HTTP/1.1 or greater: use chunked transfer encoding - // to avoid closing the connection at EOF. - w.chunking = true - w.SetHeader("Transfer-Encoding", "chunked") - } else { - // HTTP version < 1.1: cannot do chunked transfer - // encoding, so signal EOF by closing connection. - // Will be overridden if the HTTP handler ends up - // writing a Content-Length and the client requested - // "Connection: keep-alive" - w.closeAfterReply = true - } - return w, nil } @@ -209,7 +191,10 @@ func (w *response) UsingTLS() bool { func (w *response) RemoteAddr() string { return w.conn.remoteAddr } // SetHeader implements the ResponseWriter.SetHeader method -func (w *response) SetHeader(hdr, val string) { w.header[CanonicalHeaderKey(hdr)] = val } +// An empty value removes the header from the map. +func (w *response) SetHeader(hdr, val string) { + w.header[CanonicalHeaderKey(hdr)] = val, val != "" +} // WriteHeader implements the ResponseWriter.WriteHeader method func (w *response) WriteHeader(code int) { @@ -225,13 +210,83 @@ func (w *response) WriteHeader(code int) { w.status = code if code == StatusNotModified { // Must not have body. - w.header["Content-Type"] = "", false - w.header["Transfer-Encoding"] = "", false - w.chunking = false + for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { + if w.header[header] != "" { + // TODO: return an error if WriteHeader gets a return parameter + // or set a flag on w to make future Writes() write an error page? + // for now just log and drop the header. + log.Printf("http: StatusNotModified response with header %q defined", header) + w.header[header] = "", false + } + } + } else { + // Default output is HTML encoded in UTF-8. + if w.header["Content-Type"] == "" { + w.SetHeader("Content-Type", "text/html; charset=utf-8") + } } + + if w.header["Date"] == "" { + w.SetHeader("Date", time.UTC().Format(TimeFormat)) + } + + // Check for a explicit (and valid) Content-Length header. + var hasCL bool + var contentLength int64 + if clenStr, ok := w.header["Content-Length"]; ok { + var err os.Error + contentLength, err = strconv.Atoi64(clenStr) + if err == nil { + hasCL = true + } else { + log.Printf("http: invalid Content-Length of %q sent", clenStr) + w.SetHeader("Content-Length", "") + } + } + + te, hasTE := w.header["Transfer-Encoding"] + if hasCL && hasTE && te != "identity" { + // TODO: return an error if WriteHeader gets a return parameter + // For now just ignore the Content-Length. + log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", + te, contentLength) + w.SetHeader("Content-Length", "") + hasCL = false + } + + if w.req.Method == "HEAD" { + // do nothing + } else if hasCL { + w.chunking = false + w.contentLength = contentLength + w.SetHeader("Transfer-Encoding", "") + } else if w.req.ProtoAtLeast(1, 1) { + // HTTP/1.1 or greater: use chunked transfer encoding + // to avoid closing the connection at EOF. + // TODO: this blows away any custom or stacked Transfer-Encoding they + // might have set. Deal with that as need arises once we have a valid + // use case. + w.chunking = true + w.SetHeader("Transfer-Encoding", "chunked") + } else { + // HTTP version < 1.1: cannot do chunked transfer + // encoding and we don't know the Content-Length so + // signal EOF by closing connection. + w.closeAfterReply = true + w.chunking = false // redundant + w.SetHeader("Transfer-Encoding", "") // in case already set + } + + if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { + _, connectionHeaderSet := w.header["Connection"] + if !connectionHeaderSet { + w.SetHeader("Connection", "keep-alive") + } + } + // Cannot use Content-Length with non-identity Transfer-Encoding. if w.chunking { - w.header["Content-Length"] = "", false + w.SetHeader("Content-Length", "") } if !w.req.ProtoAtLeast(1, 0) { return @@ -259,15 +314,6 @@ func (w *response) Write(data []byte) (n int, err os.Error) { return 0, ErrHijacked } if !w.wroteHeader { - if w.req.wantsHttp10KeepAlive() { - _, hasLength := w.header["Content-Length"] - if hasLength { - _, connectionHeaderSet := w.header["Connection"] - if !connectionHeaderSet { - w.header["Connection"] = "keep-alive" - } - } - } w.WriteHeader(StatusOK) } if len(data) == 0 { @@ -280,6 +326,9 @@ func (w *response) Write(data []byte) (n int, err os.Error) { } w.written += int64(len(data)) // ignoring errors, for errorKludge + if w.contentLength != -1 && w.written > w.contentLength { + return 0, ErrContentLength + } // TODO(rsc): if chunking happened after the buffering, // then there would be fewer chunk headers. @@ -369,6 +418,11 @@ func (w *response) finishRequest() { } w.conn.buf.Flush() w.req.Body.Close() + + if w.contentLength != -1 && w.contentLength != w.written { + // Did not write enough. Avoid getting out of sync. + w.closeAfterReply = true + } } // Flush implements the ResponseWriter.Flush method.