1
0
mirror of https://github.com/golang/go synced 2024-11-20 04:44:40 -07:00

net/http: populate ContentLength in HEAD responses

Also fixes a necessary TODO in the process.

Fixes #4126

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6869053
This commit is contained in:
Brad Fitzpatrick 2012-12-05 22:36:23 -08:00
parent 755e13877f
commit 53d091c5ff
7 changed files with 73 additions and 17 deletions

View File

@ -527,3 +527,38 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
} }
} }
// Verify Response.ContentLength is populated. http://golang.org/issue/4126
func TestClientHeadContentLength(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if v := r.FormValue("cl"); v != "" {
w.Header().Set("Content-Length", v)
}
}))
defer ts.Close()
tests := []struct {
suffix string
want int64
}{
{"/?cl=1234", 1234},
{"/?cl=0", 0},
{"", -1},
}
for _, tt := range tests {
req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil)
res, err := DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
if res.ContentLength != tt.want {
t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want)
}
bs, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if len(bs) != 0 {
t.Errorf("Unexpected content: %q", bs)
}
}
}

View File

@ -49,7 +49,7 @@ type Response struct {
Body io.ReadCloser Body io.ReadCloser
// ContentLength records the length of the associated content. The // ContentLength records the length of the associated content. The
// value -1 indicates that the length is unknown. Unless RequestMethod // value -1 indicates that the length is unknown. Unless Request.Method
// is "HEAD", values >= 0 indicate that the given number of bytes may // is "HEAD", values >= 0 indicate that the given number of bytes may
// be read from Body. // be read from Body.
ContentLength int64 ContentLength int64
@ -178,7 +178,7 @@ func (r *Response) ProtoAtLeast(major, minor int) bool {
// StatusCode // StatusCode
// ProtoMajor // ProtoMajor
// ProtoMinor // ProtoMinor
// RequestMethod // Request.Method
// TransferEncoding // TransferEncoding
// Trailer // Trailer
// Body // Body

View File

@ -193,7 +193,7 @@ var respTests = []respTest{
Request: dummyReq("HEAD"), Request: dummyReq("HEAD"),
Header: Header{}, Header: Header{},
Close: true, Close: true,
ContentLength: 0, ContentLength: -1,
}, },
"", "",

View File

@ -614,7 +614,7 @@ func (w *response) finishRequest() {
// HTTP/1.0 clients keep their "keep-alive" connections alive, and for // HTTP/1.0 clients keep their "keep-alive" connections alive, and for
// HTTP/1.1 clients is just as good as the alternative: sending a // HTTP/1.1 clients is just as good as the alternative: sending a
// chunked response and immediately sending the zero-length EOF chunk. // chunked response and immediately sending the zero-length EOF chunk.
if w.written == 0 && w.header.get("Content-Length") == "" { if w.written == 0 && w.header.get("Content-Length") == "" && w.req.Method != "HEAD" {
w.header.Set("Content-Length", "0") w.header.Set("Content-Length", "0")
} }
// If this was an HTTP/1.0 request with keep-alive and we sent a // If this was an HTTP/1.0 request with keep-alive and we sent a

View File

@ -294,10 +294,19 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
return err return err
} }
t.ContentLength, err = fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding) realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
if err != nil { if err != nil {
return err return err
} }
if isResponse && t.RequestMethod == "HEAD" {
if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil {
return err
} else {
t.ContentLength = n
}
} else {
t.ContentLength = realLength
}
// Trailer // Trailer
t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding) t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding)
@ -310,7 +319,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
// See RFC2616, section 4.4. // See RFC2616, section 4.4.
switch msg.(type) { switch msg.(type) {
case *Response: case *Response:
if t.ContentLength == -1 && if realLength == -1 &&
!chunked(t.TransferEncoding) && !chunked(t.TransferEncoding) &&
bodyAllowedForStatus(t.StatusCode) { bodyAllowedForStatus(t.StatusCode) {
// Unbounded body. // Unbounded body.
@ -323,11 +332,11 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
switch { switch {
case chunked(t.TransferEncoding): case chunked(t.TransferEncoding):
t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
case t.ContentLength >= 0: case realLength >= 0:
// TODO: limit the Content-Length. This is an easy DoS vector. // TODO: limit the Content-Length. This is an easy DoS vector.
t.Body = &body{Reader: io.LimitReader(r, t.ContentLength), closing: t.Close} t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close}
default: default:
// t.ContentLength < 0, i.e. "Content-Length" not mentioned in header // realLength < 0, i.e. "Content-Length" not mentioned in header
if t.Close { if t.Close {
// Close semantics (i.e. HTTP/1.0) // Close semantics (i.e. HTTP/1.0)
t.Body = &body{Reader: r, closing: t.Close} t.Body = &body{Reader: r, closing: t.Close}
@ -434,9 +443,9 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
// Logic based on Content-Length // Logic based on Content-Length
cl := strings.TrimSpace(header.get("Content-Length")) cl := strings.TrimSpace(header.get("Content-Length"))
if cl != "" { if cl != "" {
n, err := strconv.ParseInt(cl, 10, 64) n, err := parseContentLength(cl)
if err != nil || n < 0 { if err != nil {
return -1, &badStringError{"bad Content-Length", cl} return -1, err
} }
return n, nil return n, nil
} else { } else {
@ -641,3 +650,18 @@ func (b *body) Close() error {
} }
return nil return nil
} }
// parseContentLength trims whitespace from s and returns -1 if no value
// is set, or the value if it's >= 0.
func parseContentLength(cl string) (int64, error) {
cl = strings.TrimSpace(cl)
if cl == "" {
return -1, nil
}
n, err := strconv.ParseInt(cl, 10, 64)
if err != nil || n < 0 {
return 0, &badStringError{"bad Content-Length", cl}
}
return n, nil
}

View File

@ -604,10 +604,7 @@ func (pc *persistConn) readLoop() {
alive = false alive = false
} }
// TODO(bradfitz): this hasBody conflicts with the defition hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0
// above which excludes HEAD requests. Is this one
// incomplete?
hasBody := resp != nil && resp.ContentLength != 0
var waitForBodyRead chan bool var waitForBodyRead chan bool
if hasBody { if hasBody {
lastbody = resp.Body lastbody = resp.Body

View File

@ -464,7 +464,7 @@ func TestTransportHeadResponses(t *testing.T) {
if e, g := "123", res.Header.Get("Content-Length"); e != g { if e, g := "123", res.Header.Get("Content-Length"); e != g {
t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
} }
if e, g := int64(0), res.ContentLength; e != g { if e, g := int64(123), res.ContentLength; e != g {
t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
} }
} }