diff --git a/src/pkg/http/httptest/server.go b/src/pkg/http/httptest/server.go index 879f04f33c2..2ec36d04cf6 100644 --- a/src/pkg/http/httptest/server.go +++ b/src/pkg/http/httptest/server.go @@ -9,6 +9,7 @@ package httptest import ( "crypto/rand" "crypto/tls" + "flag" "fmt" "http" "net" @@ -49,15 +50,34 @@ func newLocalListener() net.Listener { return l } +// When debugging a particular http server-based test, +// this flag lets you run +// gotest -run=BrokenTest -httptest.serve=127.0.0.1:8000 +// to start the broken server so you can interact with it manually. +var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks") + // NewServer starts and returns a new Server. // The caller should call Close when finished, to shut it down. func NewServer(handler http.Handler) *Server { ts := new(Server) - l := newLocalListener() + var l net.Listener + if *serve != "" { + var err os.Error + l, err = net.Listen("tcp", *serve) + if err != nil { + panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err)) + } + } else { + l = newLocalListener() + } ts.Listener = &historyListener{l, make([]net.Conn, 0)} ts.URL = "http://" + l.Addr().String() server := &http.Server{Handler: handler} go server.Serve(ts.Listener) + if *serve != "" { + fmt.Println(os.Stderr, "httptest: serving on", ts.URL) + select {} + } return ts } diff --git a/src/pkg/http/server.go b/src/pkg/http/server.go index b3fb8e101c3..f14ef8c04b9 100644 --- a/src/pkg/http/server.go +++ b/src/pkg/http/server.go @@ -255,9 +255,7 @@ func (w *response) WriteHeader(code int) { } else { // If no content type, apply sniffing algorithm to body. if w.header.Get("Content-Type") == "" { - // NOTE(dsymonds): the sniffing mechanism in this file is currently broken. - //w.needSniff = true - w.header.Set("Content-Type", "text/html; charset=utf-8") + w.needSniff = true } } @@ -364,10 +362,16 @@ func (w *response) sniff() { fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n", DetectContentType(data)) io.WriteString(w.conn.buf, "\r\n") - if w.chunking && len(data) > 0 { + if len(data) == 0 { + return + } + if w.chunking { fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) } - w.conn.buf.Write(data) + _, err := w.conn.buf.Write(data) + if w.chunking && err == nil { + io.WriteString(w.conn.buf, "\r\n") + } } // bodyAllowed returns true if a Write is allowed for this response type. @@ -401,12 +405,23 @@ func (w *response) Write(data []byte) (n int, err os.Error) { var m int if w.needSniff { + // We need to sniff the beginning of the output to + // determine the content type. Accumulate the + // initial writes in w.conn.body. body := w.conn.body - m = copy(body[len(body):], data) + m = copy(body[len(body):cap(body)], data) w.conn.body = body[:len(body)+m] if m == len(data) { + // Copied everything into the buffer. + // Wait for next write. return m, nil } + + // Filled the buffer; more data remains. + // Sniff the content (flushes the buffer) + // and then proceed with the remainder + // of the data as a normal Write. + // Calling sniff clears needSniff. w.sniff() data = data[m:] } diff --git a/src/pkg/http/sniff_test.go b/src/pkg/http/sniff_test.go index 770496f4051..2d01807f695 100644 --- a/src/pkg/http/sniff_test.go +++ b/src/pkg/http/sniff_test.go @@ -2,16 +2,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http +package http_test import ( + "bytes" + . "http" + "http/httptest" + "io/ioutil" + "log" + "strconv" "testing" ) var sniffTests = []struct { - desc string - data []byte - exp string + desc string + data []byte + contentType string }{ // Some nonsense. {"Empty", []byte{}, "text/plain; charset=utf-8"}, @@ -30,11 +36,41 @@ var sniffTests = []struct { {"GIF 89a", []byte(`GIF89a...`), "image/gif"}, } -func TestSniffing(t *testing.T) { - for _, st := range sniffTests { - got := DetectContentType(st.data) - if got != st.exp { - t.Errorf("%v: sniffed as %v, want %v", st.desc, got, st.exp) +func TestDetectContentType(t *testing.T) { + for _, tt := range sniffTests { + ct := DetectContentType(tt.data) + if ct != tt.contentType { + t.Errorf("%v: DetectContentType = %q, want %q", tt.desc, ct, tt.contentType) } } } + +func TestServerContentType(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + i, _ := strconv.Atoi(r.FormValue("i")) + tt := sniffTests[i] + n, err := w.Write(tt.data) + if n != len(tt.data) || err != nil { + log.Fatalf("%v: Write(%q) = %v, %v want %d, nil", tt.desc, tt.data, n, err, len(tt.data)) + } + })) + defer ts.Close() + + for i, tt := range sniffTests { + resp, err := Get(ts.URL + "/?i=" + strconv.Itoa(i)) + if err != nil { + t.Errorf("%v: %v", tt.desc, err) + continue + } + if ct := resp.Header.Get("Content-Type"); ct != tt.contentType { + t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, tt.contentType) + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("%v: reading body: %v", tt.desc, err) + } else if !bytes.Equal(data, tt.data) { + t.Errorf("%v: data is %q, want %q", tt.desc, data, tt.data) + } + resp.Body.Close() + } +}