diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index 00d1c58cf0..bb94e5ffea 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -469,10 +469,7 @@ func testCancelRequestMidBody(t *testing.T, h2 bool) { // Tests that clients can send trailers to a server and that the server can read them. func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) } -func TestTrailersClientToServer_h2(t *testing.T) { - t.Skip("skipping in http2 mode; golang.org/issue/13557") - testTrailersClientToServer(t, h2Mode) -} +func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) } func testTrailersClientToServer(t *testing.T, h2 bool) { defer afterTest(t) diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index b793f18416..5e4b9c0141 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -4425,9 +4425,32 @@ func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) { // exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. var http2errRequestCanceled = errors.New("net/http: request canceled") -func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { - cc.mu.Lock() +func http2commaSeparatedTrailers(req *Request) (string, error) { + keys := make([]string, 0, len(req.Trailer)) + for k := range req.Trailer { + k = CanonicalHeaderKey(k) + switch k { + case "Transfer-Encoding", "Trailer", "Content-Length": + return "", &http2badStringError{"invalid Trailer key", k} + } + keys = append(keys, k) + } + if len(keys) > 0 { + sort.Strings(keys) + return strings.Join(keys, ","), nil + } + return "", nil +} + +func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { + trailers, err := http2commaSeparatedTrailers(req) + if err != nil { + return nil, err + } + hasTrailers := trailers != "" + + cc.mu.Lock() if cc.closed || !cc.canTakeNewRequestLocked() { cc.mu.Unlock() return nil, http2errClientConnUnusable @@ -4445,33 +4468,10 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { cs.requestedGzip = true } - hdrs := cc.encodeHeaders(req, cs.requestedGzip) - first := true - + hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers) cc.wmu.Lock() - frameSize := int(cc.maxFrameSize) - for len(hdrs) > 0 && cc.werr == nil { - chunk := hdrs - if len(chunk) > frameSize { - chunk = chunk[:frameSize] - } - hdrs = hdrs[len(chunk):] - endHeaders := len(hdrs) == 0 - if first { - cc.fr.WriteHeaders(http2HeadersFrameParam{ - StreamID: cs.ID, - BlockFragment: chunk, - EndStream: !hasBody, - EndHeaders: endHeaders, - }) - first = false - } else { - cc.fr.WriteContinuation(cs.ID, endHeaders, chunk) - } - } - - cc.bw.Flush() - werr := cc.werr + endStream := !hasBody && !hasTrailers + werr := cc.writeHeaders(cs.ID, endStream, hdrs) cc.wmu.Unlock() cc.mu.Unlock() @@ -4514,6 +4514,34 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } } +// requires cc.wmu be held +func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte) error { + first := true + frameSize := int(cc.maxFrameSize) + for len(hdrs) > 0 && cc.werr == nil { + chunk := hdrs + if len(chunk) > frameSize { + chunk = chunk[:frameSize] + } + hdrs = hdrs[len(chunk):] + endHeaders := len(hdrs) == 0 + if first { + cc.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: streamID, + BlockFragment: chunk, + EndStream: endStream, + EndHeaders: endHeaders, + }) + first = false + } else { + cc.fr.WriteContinuation(streamID, endHeaders, chunk) + } + } + + cc.bw.Flush() + return cc.werr +} + // errAbortReqBodyWrite is an internal error value. // It doesn't escape to callers. var http2errAbortReqBodyWrite = errors.New("http2: aborting request body write") @@ -4532,6 +4560,9 @@ func (cs *http2clientStream) writeRequestBody(body io.ReadCloser) (err error) { } }() + req := cs.req + hasTrailers := req.Trailer != nil + var sawEOF bool for !sawEOF { n, err := body.Read(buf) @@ -4552,7 +4583,7 @@ func (cs *http2clientStream) writeRequestBody(body io.ReadCloser) (err error) { cc.wmu.Lock() data := remain[:allowed] remain = remain[allowed:] - sentEnd = sawEOF && len(remain) == 0 + sentEnd = sawEOF && len(remain) == 0 && !hasTrailers err = cc.fr.WriteData(cs.ID, sentEnd, data) if err == nil { @@ -4567,7 +4598,18 @@ func (cs *http2clientStream) writeRequestBody(body io.ReadCloser) (err error) { cc.wmu.Lock() if !sentEnd { - err = cc.fr.WriteData(cs.ID, true, nil) + var trls []byte + if hasTrailers { + cc.mu.Lock() + trls = cc.encodeTrailers(req) + cc.mu.Unlock() + } + + if len(trls) > 0 { + err = cc.writeHeaders(cs.ID, true, trls) + } else { + err = cc.fr.WriteData(cs.ID, true, nil) + } } if ferr := cc.bw.Flush(); ferr != nil && err == nil { err = ferr @@ -4611,8 +4653,15 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er } } +type http2badStringError struct { + what string + str string +} + +func (e *http2badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) } + // requires cc.mu be held. -func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool) []byte { +func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string) []byte { cc.hbuf.Reset() host := req.Host @@ -4624,6 +4673,9 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool) []byt cc.writeHeader(":method", req.Method) cc.writeHeader(":path", req.URL.RequestURI()) cc.writeHeader(":scheme", "https") + if trailers != "" { + cc.writeHeader("trailer", trailers) + } for k, vv := range req.Header { lowKey := strings.ToLower(k) @@ -4640,6 +4692,19 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool) []byt return cc.hbuf.Bytes() } +// requires cc.mu be held. +func (cc *http2ClientConn) encodeTrailers(req *Request) []byte { + cc.hbuf.Reset() + for k, vv := range req.Trailer { + + lowKey := strings.ToLower(k) + for _, v := range vv { + cc.writeHeader(lowKey, v) + } + } + return cc.hbuf.Bytes() +} + func (cc *http2ClientConn) writeHeader(name, value string) { cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) }