1
0
mirror of https://github.com/golang/go synced 2024-11-22 03:34:40 -07:00

http: fix chunking bug during content sniffing

R=golang-dev, bradfitz, gri
CC=golang-dev
https://golang.org/cl/4807044
This commit is contained in:
Russ Cox 2011-07-21 14:29:14 -04:00
parent 22853098a9
commit 301d8a6d4a
3 changed files with 87 additions and 16 deletions

View File

@ -9,6 +9,7 @@ package httptest
import ( import (
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"flag"
"fmt" "fmt"
"http" "http"
"net" "net"
@ -49,15 +50,34 @@ func newLocalListener() net.Listener {
return l return l
} }
// When debugging a particular http server-based test,
// this flag lets you run
// gotest -run=BrokenTest -httptest.serve=127.0.0.1:8000
// to start the broken server so you can interact with it manually.
var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks")
// NewServer starts and returns a new Server. // NewServer starts and returns a new Server.
// The caller should call Close when finished, to shut it down. // The caller should call Close when finished, to shut it down.
func NewServer(handler http.Handler) *Server { func NewServer(handler http.Handler) *Server {
ts := new(Server) ts := new(Server)
l := newLocalListener() var l net.Listener
if *serve != "" {
var err os.Error
l, err = net.Listen("tcp", *serve)
if err != nil {
panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
}
} else {
l = newLocalListener()
}
ts.Listener = &historyListener{l, make([]net.Conn, 0)} ts.Listener = &historyListener{l, make([]net.Conn, 0)}
ts.URL = "http://" + l.Addr().String() ts.URL = "http://" + l.Addr().String()
server := &http.Server{Handler: handler} server := &http.Server{Handler: handler}
go server.Serve(ts.Listener) go server.Serve(ts.Listener)
if *serve != "" {
fmt.Println(os.Stderr, "httptest: serving on", ts.URL)
select {}
}
return ts return ts
} }

View File

@ -255,9 +255,7 @@ func (w *response) WriteHeader(code int) {
} else { } else {
// If no content type, apply sniffing algorithm to body. // If no content type, apply sniffing algorithm to body.
if w.header.Get("Content-Type") == "" { if w.header.Get("Content-Type") == "" {
// NOTE(dsymonds): the sniffing mechanism in this file is currently broken. w.needSniff = true
//w.needSniff = true
w.header.Set("Content-Type", "text/html; charset=utf-8")
} }
} }
@ -364,10 +362,16 @@ func (w *response) sniff() {
fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n", DetectContentType(data)) fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n", DetectContentType(data))
io.WriteString(w.conn.buf, "\r\n") io.WriteString(w.conn.buf, "\r\n")
if w.chunking && len(data) > 0 { if len(data) == 0 {
return
}
if w.chunking {
fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) fmt.Fprintf(w.conn.buf, "%x\r\n", len(data))
} }
w.conn.buf.Write(data) _, err := w.conn.buf.Write(data)
if w.chunking && err == nil {
io.WriteString(w.conn.buf, "\r\n")
}
} }
// bodyAllowed returns true if a Write is allowed for this response type. // bodyAllowed returns true if a Write is allowed for this response type.
@ -401,12 +405,23 @@ func (w *response) Write(data []byte) (n int, err os.Error) {
var m int var m int
if w.needSniff { if w.needSniff {
// We need to sniff the beginning of the output to
// determine the content type. Accumulate the
// initial writes in w.conn.body.
body := w.conn.body body := w.conn.body
m = copy(body[len(body):], data) m = copy(body[len(body):cap(body)], data)
w.conn.body = body[:len(body)+m] w.conn.body = body[:len(body)+m]
if m == len(data) { if m == len(data) {
// Copied everything into the buffer.
// Wait for next write.
return m, nil return m, nil
} }
// Filled the buffer; more data remains.
// Sniff the content (flushes the buffer)
// and then proceed with the remainder
// of the data as a normal Write.
// Calling sniff clears needSniff.
w.sniff() w.sniff()
data = data[m:] data = data[m:]
} }

View File

@ -2,16 +2,22 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package http package http_test
import ( import (
"bytes"
. "http"
"http/httptest"
"io/ioutil"
"log"
"strconv"
"testing" "testing"
) )
var sniffTests = []struct { var sniffTests = []struct {
desc string desc string
data []byte data []byte
exp string contentType string
}{ }{
// Some nonsense. // Some nonsense.
{"Empty", []byte{}, "text/plain; charset=utf-8"}, {"Empty", []byte{}, "text/plain; charset=utf-8"},
@ -30,11 +36,41 @@ var sniffTests = []struct {
{"GIF 89a", []byte(`GIF89a...`), "image/gif"}, {"GIF 89a", []byte(`GIF89a...`), "image/gif"},
} }
func TestSniffing(t *testing.T) { func TestDetectContentType(t *testing.T) {
for _, st := range sniffTests { for _, tt := range sniffTests {
got := DetectContentType(st.data) ct := DetectContentType(tt.data)
if got != st.exp { if ct != tt.contentType {
t.Errorf("%v: sniffed as %v, want %v", st.desc, got, st.exp) t.Errorf("%v: DetectContentType = %q, want %q", tt.desc, ct, tt.contentType)
} }
} }
} }
func TestServerContentType(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
i, _ := strconv.Atoi(r.FormValue("i"))
tt := sniffTests[i]
n, err := w.Write(tt.data)
if n != len(tt.data) || err != nil {
log.Fatalf("%v: Write(%q) = %v, %v want %d, nil", tt.desc, tt.data, n, err, len(tt.data))
}
}))
defer ts.Close()
for i, tt := range sniffTests {
resp, err := Get(ts.URL + "/?i=" + strconv.Itoa(i))
if err != nil {
t.Errorf("%v: %v", tt.desc, err)
continue
}
if ct := resp.Header.Get("Content-Type"); ct != tt.contentType {
t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, tt.contentType)
}
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("%v: reading body: %v", tt.desc, err)
} else if !bytes.Equal(data, tt.data) {
t.Errorf("%v: data is %q, want %q", tt.desc, data, tt.data)
}
resp.Body.Close()
}
}