diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index 9be5ebfa4f..9a722e752a 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -1,12 +1,16 @@ -// This file is autogenerated using x/tools/cmd/bundle from -// https://go-review.googlesource.com/#/c/15850/ -// Usage: -// $ bundle golang.org/x/net/http2 net/http http2 > /tmp/x.go; mv /tmp/x.go $GOROOT/src/net/http/h2_bundle.go - -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. +// Code generated by golang.org/x/tools/cmd/bundle command: +// $ bundle golang.org/x/net/http2 net/http http2 +// Package http2 implements the HTTP/2 protocol. +// +// This is a work in progress. This package is low-level and intended +// to be used directly by very few people. Most users will use it +// indirectly through integration with the net/http package. See +// ConfigureServer. That ConfigureServer call will likely be automatic +// or available via an empty import in the future. +// +// See http://http2.github.io/ +// package http import ( @@ -16,7 +20,9 @@ import ( "encoding/binary" "errors" "fmt" + "golang.org/x/net/http2/hpack" "io" + "io/ioutil" "log" "net" "net/url" @@ -26,74 +32,8 @@ import ( "strings" "sync" "time" - - "golang.org/x/net/http2/hpack" ) -// buffer is an io.ReadWriteCloser backed by a fixed size buffer. -// It never allocates, but moves old data as new data is written. -type http2buffer struct { - buf []byte - r, w int - closed bool - err error // err to return to reader -} - -var ( - http2errReadEmpty = errors.New("read from empty buffer") - http2errWriteClosed = errors.New("write on closed buffer") - http2errWriteFull = errors.New("write on full buffer") -) - -// Read copies bytes from the buffer into p. -// It is an error to read when no data is available. -func (b *http2buffer) Read(p []byte) (n int, err error) { - n = copy(p, b.buf[b.r:b.w]) - b.r += n - if b.closed && b.r == b.w { - err = b.err - } else if b.r == b.w && n == 0 { - err = http2errReadEmpty - } - return n, err -} - -// Len returns the number of bytes of the unread portion of the buffer. -func (b *http2buffer) Len() int { - return b.w - b.r -} - -// Write copies bytes from p into the buffer. -// It is an error to write more data than the buffer can hold. -func (b *http2buffer) Write(p []byte) (n int, err error) { - if b.closed { - return 0, http2errWriteClosed - } - - if b.r > 0 && len(p) > len(b.buf)-b.w { - copy(b.buf, b.buf[b.r:b.w]) - b.w -= b.r - b.r = 0 - } - - n = copy(b.buf[b.w:], p) - b.w += n - if n < len(p) { - err = http2errWriteFull - } - return n, err -} - -// Close marks the buffer as closed. Future calls to Write will -// return an error. Future calls to Read, once the buffer is -// empty, will return err. -func (b *http2buffer) Close(err error) { - if !b.closed { - b.closed = true - b.err = err - } -} - // An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. type http2ErrCode uint32 @@ -166,6 +106,56 @@ type http2goAwayFlowError struct{} func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" } +// fixedBuffer is an io.ReadWriter backed by a fixed size buffer. +// It never allocates, but moves old data as new data is written. +type http2fixedBuffer struct { + buf []byte + r, w int +} + +var ( + http2errReadEmpty = errors.New("read from empty fixedBuffer") + http2errWriteFull = errors.New("write on full fixedBuffer") +) + +// Read copies bytes from the buffer into p. +// It is an error to read when no data is available. +func (b *http2fixedBuffer) Read(p []byte) (n int, err error) { + if b.r == b.w { + return 0, http2errReadEmpty + } + n = copy(p, b.buf[b.r:b.w]) + b.r += n + if b.r == b.w { + b.r = 0 + b.w = 0 + } + return n, nil +} + +// Len returns the number of bytes of the unread portion of the buffer. +func (b *http2fixedBuffer) Len() int { + return b.w - b.r +} + +// Write copies bytes from p into the buffer. +// It is an error to write more data than the buffer can hold. +func (b *http2fixedBuffer) Write(p []byte) (n int, err error) { + + if b.r > 0 && len(p) > len(b.buf)-b.w { + copy(b.buf, b.buf[b.r:b.w]) + b.w -= b.r + b.r = 0 + } + + n = copy(b.buf[b.w:], p) + b.w += n + if n < len(p) { + err = http2errWriteFull + } + return n, err +} + // flow is the flow control window's size. type http2flow struct { // n is the number of DATA bytes we're allowed to send. @@ -834,7 +824,7 @@ func http2parseUnknownFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { // See http://http2.github.io/http2-spec/#rfc.section.6.9 type http2WindowUpdateFrame struct { http2FrameHeader - Increment uint32 + Increment uint32 // never read with high bit set } func http2parseWindowUpdateFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { @@ -1707,37 +1697,90 @@ func (w *http2bufferedWriter) Flush() error { return err } +func http2mustUint31(v int32) uint32 { + if v < 0 || v > 2147483647 { + panic("out of range") + } + return uint32(v) +} + +// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like +// io.Pipe except there are no PipeReader/PipeWriter halves, and the +// underlying buffer is an interface. (io.Pipe is always unbuffered) type http2pipe struct { - b http2buffer - c sync.Cond - m sync.Mutex + mu sync.Mutex + c sync.Cond // c.L must point to + b http2pipeBuffer + err error // read error once empty. non-nil means closed. +} + +type http2pipeBuffer interface { + Len() int + io.Writer + io.Reader } // Read waits until data is available and copies bytes // from the buffer into p. -func (r *http2pipe) Read(p []byte) (n int, err error) { - r.c.L.Lock() - defer r.c.L.Unlock() - for r.b.Len() == 0 && !r.b.closed { - r.c.Wait() +func (p *http2pipe) Read(d []byte) (n int, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + for { + if p.b.Len() > 0 { + return p.b.Read(d) + } + if p.err != nil { + return 0, p.err + } + p.c.Wait() } - return r.b.Read(p) } +var http2errClosedPipeWrite = errors.New("write on closed buffer") + // Write copies bytes from p into the buffer and wakes a reader. // It is an error to write more data than the buffer can hold. -func (w *http2pipe) Write(p []byte) (n int, err error) { - w.c.L.Lock() - defer w.c.L.Unlock() - defer w.c.Signal() - return w.b.Write(p) +func (p *http2pipe) Write(d []byte) (n int, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + defer p.c.Signal() + if p.err != nil { + return 0, http2errClosedPipeWrite + } + return p.b.Write(d) } -func (c *http2pipe) Close(err error) { - c.c.L.Lock() - defer c.c.L.Unlock() - defer c.c.Signal() - c.b.Close(err) +// CloseWithError causes Reads to wake up and return the +// provided err after all data has been read. +// +// The error must be non-nil. +func (p *http2pipe) CloseWithError(err error) { + if err == nil { + panic("CloseWithError must be non-nil") + } + p.mu.Lock() + defer p.mu.Unlock() + if p.c.L == nil { + p.c.L = &p.mu + } + defer p.c.Signal() + if p.err == nil { + p.err = err + } +} + +// Err returns the error (if any) first set with CloseWithError. +// This is the error which will be returned after the reader is exhausted. +func (p *http2pipe) Err() error { + p.mu.Lock() + defer p.mu.Unlock() + return p.err } const ( @@ -1750,6 +1793,7 @@ const ( var ( http2errClientDisconnected = errors.New("client disconnected") http2errClosedBody = errors.New("body closed by handler") + http2errHandlerComplete = errors.New("http2: request body closed due to handler exiting") http2errStreamClosed = errors.New("http2: stream closed") ) @@ -2361,37 +2405,43 @@ var http2errChanPool = sync.Pool{ New: func() interface{} { return make(chan error, 1) }, } -// writeDataFromHandler writes the data described in req to stream.id. -// -// The flow control currently happens in the Handler where it waits -// for 1 or more bytes to be available to then write here. So at this -// point we know that we have flow control. But this might have to -// change when priority is implemented, so the serve goroutine knows -// the total amount of bytes waiting to be sent and can can have more -// scheduling decisions available. -func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, writeData *http2writeData) error { +var http2writeDataPool = sync.Pool{ + New: func() interface{} { return new(http2writeData) }, +} + +// writeDataFromHandler writes DATA response frames from a handler on +// the given stream. +func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte, endStream bool) error { ch := http2errChanPool.Get().(chan error) + writeArg := http2writeDataPool.Get().(*http2writeData) + *writeArg = http2writeData{stream.id, data, endStream} err := sc.writeFrameFromHandler(http2frameWriteMsg{ - write: writeData, + write: writeArg, stream: stream, done: ch, }) if err != nil { return err } + var frameWriteDone bool // the frame write is done (successfully or not) select { case err = <-ch: + frameWriteDone = true case <-sc.doneServing: return http2errClientDisconnected case <-stream.cw: select { case err = <-ch: + frameWriteDone = true default: return http2errStreamClosed } } http2errChanPool.Put(ch) + if frameWriteDone { + http2writeDataPool.Put(writeArg) + } return err } @@ -2490,7 +2540,7 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { errCancel := http2StreamError{st.id, http2ErrCodeCancel} sc.resetStream(errCancel) case http2stateHalfClosedRemote: - sc.closeStream(st, nil) + sc.closeStream(st, http2errHandlerComplete) } } @@ -2738,7 +2788,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { } delete(sc.streams, st.id) if p := st.body; p != nil { - p.Close(err) + p.CloseWithError(err) } st.cw.Close() sc.writeSched.forgetStream(st.id) @@ -2818,7 +2868,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { data := f.Data() if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { - st.body.Close(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) + st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) return http2StreamError{id, http2ErrCodeStreamClosed} } if len(data) > 0 { @@ -2838,10 +2888,10 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { } if f.StreamEnded() { if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes { - st.body.Close(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes", + st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes", st.declBodyBytes, st.bodyBytes)) } else { - st.body.Close(io.EOF) + st.body.CloseWithError(io.EOF) } st.state = http2stateHalfClosedRemote } @@ -3034,9 +3084,8 @@ func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request } if bodyOpen { body.pipe = &http2pipe{ - b: http2buffer{buf: make([]byte, http2initialWindowSize)}, + b: &http2fixedBuffer{buf: make([]byte, http2initialWindowSize)}, } - body.pipe.c.L = &body.pipe.m if vv, ok := rp.header["Content-Length"]; ok { req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) @@ -3126,7 +3175,10 @@ type http2bodyReadMsg struct { // and schedules flow control tokens to be sent. func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int) { sc.serveG.checkNotOn() - sc.bodyReadCh <- http2bodyReadMsg{st, n} + select { + case sc.bodyReadCh <- http2bodyReadMsg{st, n}: + case <-sc.doneServing: + } } func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) { @@ -3192,7 +3244,7 @@ type http2requestBody struct { func (b *http2requestBody) Close() error { if b.pipe != nil { - b.pipe.Close(http2errClosedBody) + b.pipe.CloseWithError(http2errClosedBody) } b.closed = true return nil @@ -3247,7 +3299,6 @@ type http2responseWriterState struct { wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. sentHeader bool // have we sent the header frame? handlerDone bool // handler has finished - curWrite http2writeData closeNotifierMu sync.Mutex // guards closeNotifierCh closeNotifierCh chan bool // nil until first used @@ -3296,11 +3347,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { return 0, nil } - curWrite := &rws.curWrite - curWrite.streamID = rws.stream.id - curWrite.p = p - curWrite.endStream = rws.handlerDone - if err := rws.conn.writeDataFromHandler(rws.stream, curWrite); err != nil { + if err := rws.conn.writeDataFromHandler(rws.stream, p, rws.handlerDone); err != nil { return 0, err } return len(p), nil @@ -3423,34 +3470,67 @@ func (w *http2responseWriter) handlerDone() { http2responseWriterStatePool.Put(rws) } +const ( + // transportDefaultConnFlow is how many connection-level flow control + // tokens we give the server at start-up, past the default 64k. + http2transportDefaultConnFlow = 1 << 30 + + // transportDefaultStreamFlow is how many stream-level flow + // control tokens we announce to the peer, and how many bytes + // we buffer per stream. + http2transportDefaultStreamFlow = 4 << 20 + + // transportDefaultStreamMinRefresh is the minimum number of bytes we'll send + // a stream-level WINDOW_UPDATE for at a time. + http2transportDefaultStreamMinRefresh = 4 << 10 +) + +// Transport is an HTTP/2 Transport. +// +// A Transport internally caches connections to servers. It is safe +// for concurrent use by multiple goroutines. type http2Transport struct { - Fallback RoundTripper + // DialTLS specifies an optional dial function for creating + // TLS connections for requests. + // + // If DialTLS is nil, tls.Dial is used. + // + // If the returned net.Conn has a ConnectionState method like tls.Conn, + // it will be used to set http.Response.TLS. + DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error) - // TODO: remove this and make more general with a TLS dial hook, like http - InsecureTLSDial bool + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + // TODO: switch to RWMutex + // TODO: add support for sharing conns based on cert names + // (e.g. share conn for googleapis.com and appspot.com) connMu sync.Mutex conns map[string][]*http2clientConn // key is host:port } +// clientConn is the state of a single HTTP/2 client connection to an +// HTTP/2 server. type http2clientConn struct { t *http2Transport - tconn *tls.Conn + tconn net.Conn tlsState *tls.ConnectionState connKey []string // key(s) this connection is cached in, in t.conns + // readLoop goroutine fields: readerDone chan struct{} // closed on error readerErr error // set before readerDone is closed - hdec *hpack.Decoder - nextRes *Response - mu sync.Mutex + mu sync.Mutex // guards following + cond *sync.Cond // hold mu; broadcast on flow/closed changes + flow http2flow // our conn-level flow control quota (cs.flow is per stream) + inflow http2flow // peer's conn-level flow control closed bool - goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received - streams map[uint32]*http2clientStream + goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received + streams map[uint32]*http2clientStream // client-initiated nextStreamID uint32 bw *bufio.Writer - werr error // first write error that has occurred br *bufio.Reader fr *http2Framer // Settings from peer: @@ -3459,13 +3539,36 @@ type http2clientConn struct { initialWindowSize uint32 hbuf bytes.Buffer // HPACK encoder writes into this henc *hpack.Encoder + freeBuf [][]byte + + wmu sync.Mutex // held while writing; acquire AFTER wmu if holding both + werr error // first write error that has occurred } +// clientStream is the state for a single HTTP/2 stream. One of these +// is created for each Transport.RoundTrip call. type http2clientStream struct { - ID uint32 - resc chan http2resAndError - pw *io.PipeWriter - pr *io.PipeReader + cc *http2clientConn + ID uint32 + resc chan http2resAndError + bufPipe http2pipe // buffered pipe with the flow-controlled response payload + + flow http2flow // guarded by cc.mu + inflow http2flow // guarded by cc.mu + + peerReset chan struct{} // closed on peer reset + resetErr error // populated before peerReset is closed +} + +// checkReset reports any error sent in a RST_STREAM frame by the +// server. +func (cs *http2clientStream) checkReset() error { + select { + case <-cs.peerReset: + return cs.resetErr + default: + return nil + } } type http2stickyErrWriter struct { @@ -3484,10 +3587,7 @@ func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) { func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { if req.URL.Scheme != "https" { - if t.Fallback == nil { - return nil, errors.New("http2: unsupported scheme and no Fallback") - } - return t.Fallback.RoundTrip(req) + return nil, errors.New("http2: unsupported scheme") } host, port, err := net.SplitHostPort(req.URL.Host) @@ -3556,57 +3656,111 @@ func http2filterOutClientConn(in []*http2clientConn, exclude *http2clientConn) [ out = append(out, v) } } + + if len(in) != len(out) { + in[len(in)-1] = nil + } return out } -func (t *http2Transport) getClientConn(host, port string) (*http2clientConn, error) { +// AddIdleConn adds c as an idle conn for Transport. +// It assumes that c has not yet exchanged SETTINGS frames. +// The addr maybe be either "host" or "host:port". +func (t *http2Transport) AddIdleConn(addr string, c *tls.Conn) error { + var key string + _, _, err := net.SplitHostPort(addr) + if err == nil { + key = addr + } else { + key = addr + ":443" + } + cc, err := t.newClientConn(key, c) + if err != nil { + return err + } + + t.addConn(key, cc) + return nil +} + +func (t *http2Transport) addConn(key string, cc *http2clientConn) { t.connMu.Lock() defer t.connMu.Unlock() - - key := net.JoinHostPort(host, port) - - for _, cc := range t.conns[key] { - if cc.canTakeNewRequest() { - return cc, nil - } - } if t.conns == nil { t.conns = make(map[string][]*http2clientConn) } - cc, err := t.newClientConn(host, port, key) + t.conns[key] = append(t.conns[key], cc) +} + +func (t *http2Transport) getClientConn(host, port string) (*http2clientConn, error) { + key := net.JoinHostPort(host, port) + + t.connMu.Lock() + for _, cc := range t.conns[key] { + if cc.canTakeNewRequest() { + t.connMu.Unlock() + return cc, nil + } + } + t.connMu.Unlock() + + cc, err := t.dialClientConn(host, port, key) if err != nil { return nil, err } - t.conns[key] = append(t.conns[key], cc) + t.addConn(key, cc) return cc, nil } -func (t *http2Transport) newClientConn(host, port, key string) (*http2clientConn, error) { - cfg := &tls.Config{ - ServerName: host, - NextProtos: []string{http2NextProtoTLS}, - InsecureSkipVerify: t.InsecureTLSDial, - } - tconn, err := tls.Dial("tcp", net.JoinHostPort(host, port), cfg) +func (t *http2Transport) dialClientConn(host, port, key string) (*http2clientConn, error) { + tconn, err := t.dialTLS()("tcp", net.JoinHostPort(host, port), t.newTLSConfig(host)) if err != nil { return nil, err } - if err := tconn.Handshake(); err != nil { + return t.newClientConn(key, tconn) +} + +func (t *http2Transport) newTLSConfig(host string) *tls.Config { + cfg := new(tls.Config) + if t.TLSClientConfig != nil { + *cfg = *t.TLSClientConfig + } + cfg.NextProtos = []string{http2NextProtoTLS} + cfg.ServerName = host + return cfg +} + +func (t *http2Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) { + if t.DialTLS != nil { + return t.DialTLS + } + return t.dialTLSDefault +} + +func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) { + cn, err := tls.Dial(network, addr, cfg) + if err != nil { return nil, err } - if !t.InsecureTLSDial { - if err := tconn.VerifyHostname(cfg.ServerName); err != nil { + if err := cn.Handshake(); err != nil { + return nil, err + } + if !cfg.InsecureSkipVerify { + if err := cn.VerifyHostname(cfg.ServerName); err != nil { return nil, err } } - state := tconn.ConnectionState() + state := cn.ConnectionState() if p := state.NegotiatedProtocol; p != http2NextProtoTLS { - - return nil, fmt.Errorf("bad protocol: %v", p) + return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS) } if !state.NegotiatedProtocolIsMutual { - return nil, errors.New("could not negotiate protocol mutually") + return nil, errors.New("http2: could not negotiate protocol mutually") } + return cn, nil +} + +func (t *http2Transport) newClientConn(key string, tconn net.Conn) (*http2clientConn, error) { if _, err := tconn.Write(http2clientPreface); err != nil { return nil, err } @@ -3615,7 +3769,6 @@ func (t *http2Transport) newClientConn(host, port, key string) (*http2clientConn t: t, tconn: tconn, connKey: []string{key}, - tlsState: &state, readerDone: make(chan struct{}), nextStreamID: 1, maxFrameSize: 16 << 10, @@ -3623,14 +3776,28 @@ func (t *http2Transport) newClientConn(host, port, key string) (*http2clientConn maxConcurrentStreams: 1000, streams: make(map[uint32]*http2clientStream), } + cc.cond = sync.NewCond(&cc.mu) + cc.flow.add(int32(http2initialWindowSize)) + cc.bw = bufio.NewWriter(http2stickyErrWriter{tconn, &cc.werr}) cc.br = bufio.NewReader(tconn) cc.fr = http2NewFramer(cc.bw, cc.br) cc.henc = hpack.NewEncoder(&cc.hbuf) - cc.fr.WriteSettings() + type connectionStater interface { + ConnectionState() tls.ConnectionState + } + if cs, ok := tconn.(connectionStater); ok { + state := cs.ConnectionState() + cc.tlsState = &state + } - cc.fr.WriteWindowUpdate(0, 1<<30) + cc.fr.WriteSettings( + http2Setting{ID: http2SettingEnablePush, Val: 0}, + http2Setting{ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, + ) + cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow) + cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize) cc.bw.Flush() if cc.werr != nil { return nil, cc.werr @@ -3662,8 +3829,6 @@ func (t *http2Transport) newClientConn(host, port, key string) (*http2clientConn return nil }) - cc.hdec = hpack.NewDecoder(http2initialHeaderTableSize, cc.onNewHeaderField) - go cc.readLoop() return cc, nil } @@ -3695,6 +3860,46 @@ func (cc *http2clientConn) closeIfIdle() { cc.tconn.Close() } +const http2maxAllocFrameSize = 512 << 10 + +// frameBuffer returns a scratch buffer suitable for writing DATA frames. +// They're capped at the min of the peer's max frame size or 512KB +// (kinda arbitrarily), but definitely capped so we don't allocate 4GB +// bufers. +func (cc *http2clientConn) frameScratchBuffer() []byte { + cc.mu.Lock() + size := cc.maxFrameSize + if size > http2maxAllocFrameSize { + size = http2maxAllocFrameSize + } + for i, buf := range cc.freeBuf { + if len(buf) >= int(size) { + cc.freeBuf[i] = nil + cc.mu.Unlock() + return buf[:size] + } + } + cc.mu.Unlock() + return make([]byte, size) +} + +func (cc *http2clientConn) putFrameScratchBuffer(buf []byte) { + cc.mu.Lock() + defer cc.mu.Unlock() + const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate. + if len(cc.freeBuf) < maxBufs { + cc.freeBuf = append(cc.freeBuf, buf) + return + } + for i, old := range cc.freeBuf { + if old == nil { + cc.freeBuf[i] = buf + return + } + } + +} + func (cc *http2clientConn) roundTrip(req *Request) (*Response, error) { cc.mu.Lock() @@ -3704,14 +3909,17 @@ func (cc *http2clientConn) roundTrip(req *Request) (*Response, error) { } cs := cc.newStream() - hasBody := false + hasBody := req.Body != nil hdrs := cc.encodeHeaders(req) first := true - for len(hdrs) > 0 { + + cc.wmu.Lock() + frameSize := int(cc.maxFrameSize) + for len(hdrs) > 0 && cc.werr == nil { chunk := hdrs - if len(chunk) > int(cc.maxFrameSize) { - chunk = chunk[:cc.maxFrameSize] + if len(chunk) > frameSize { + chunk = chunk[:frameSize] } hdrs = hdrs[len(chunk):] endHeaders := len(hdrs) == 0 @@ -3729,24 +3937,133 @@ func (cc *http2clientConn) roundTrip(req *Request) (*Response, error) { } cc.bw.Flush() werr := cc.werr + cc.wmu.Unlock() cc.mu.Unlock() - if hasBody { - - } - if werr != nil { return nil, werr } - re := <-cs.resc - if re.err != nil { - return nil, re.err + var bodyCopyErrc chan error + var gotResHeaders chan struct{} // closed on resheaders + if hasBody { + bodyCopyErrc = make(chan error, 1) + gotResHeaders = make(chan struct{}) + go func() { + bodyCopyErrc <- cs.writeRequestBody(req.Body, gotResHeaders) + }() + } + + for { + select { + case re := <-cs.resc: + if gotResHeaders != nil { + close(gotResHeaders) + } + if re.err != nil { + return nil, re.err + } + res := re.res + res.Request = req + res.TLS = cc.tlsState + return res, nil + case err := <-bodyCopyErrc: + if err != nil { + return nil, err + } + } + } +} + +var http2errServerResponseBeforeRequestBody = errors.New("http2: server sent response while still writing request body") + +func (cs *http2clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan struct{}) error { + cc := cs.cc + sentEnd := false + buf := cc.frameScratchBuffer() + defer cc.putFrameScratchBuffer(buf) + + for !sentEnd { + var sawEOF bool + n, err := io.ReadFull(body, buf) + if err == io.ErrUnexpectedEOF { + sawEOF = true + err = nil + } else if err == io.EOF { + break + } else if err != nil { + return err + } + + toWrite := buf[:n] + for len(toWrite) > 0 && err == nil { + var allowed int32 + allowed, err = cs.awaitFlowControl(int32(len(toWrite))) + if err != nil { + return err + } + + cc.wmu.Lock() + select { + case <-gotResHeaders: + err = http2errServerResponseBeforeRequestBody + case <-cs.peerReset: + err = cs.resetErr + default: + data := toWrite[:allowed] + toWrite = toWrite[allowed:] + sentEnd = sawEOF && len(toWrite) == 0 + err = cc.fr.WriteData(cs.ID, sentEnd, data) + } + cc.wmu.Unlock() + } + if err != nil { + return err + } + } + + var err error + + cc.wmu.Lock() + if !sentEnd { + err = cc.fr.WriteData(cs.ID, true, nil) + } + if ferr := cc.bw.Flush(); ferr != nil && err == nil { + err = ferr + } + cc.wmu.Unlock() + + return err +} + +// awaitFlowControl waits for [1, min(maxBytes, cc.cs.maxFrameSize)] flow +// control tokens from the server. +// It returns either the non-zero number of tokens taken or an error +// if the stream is dead. +func (cs *http2clientStream) awaitFlowControl(maxBytes int32) (taken int32, err error) { + cc := cs.cc + cc.mu.Lock() + defer cc.mu.Unlock() + for { + if cc.closed { + return 0, http2errClientConnClosed + } + if err := cs.checkReset(); err != nil { + return 0, err + } + if a := cs.flow.available(); a > 0 { + take := a + if take > maxBytes { + take = maxBytes + } + if take > int32(cc.maxFrameSize) { + take = int32(cc.maxFrameSize) + } + cs.flow.take(take) + return take, nil + } + cc.cond.Wait() } - res := re.res - res.Request = req - res.TLS = cc.tlsState - return res, nil } // requires cc.mu be held. @@ -3758,14 +4075,9 @@ func (cc *http2clientConn) encodeHeaders(req *Request) []byte { host = req.URL.Host } - path := req.URL.Path - if path == "" { - path = "/" - } - cc.writeHeader(":authority", host) cc.writeHeader(":method", req.Method) - cc.writeHeader(":path", path) + cc.writeHeader(":path", req.URL.RequestURI()) cc.writeHeader(":scheme", "https") for k, vv := range req.Header { @@ -3792,9 +4104,15 @@ type http2resAndError struct { // requires cc.mu be held. func (cc *http2clientConn) newStream() *http2clientStream { cs := &http2clientStream{ - ID: cc.nextStreamID, - resc: make(chan http2resAndError, 1), + cc: cc, + ID: cc.nextStreamID, + resc: make(chan http2resAndError, 1), + peerReset: make(chan struct{}), } + cs.flow.add(int32(cc.initialWindowSize)) + cs.flow.setConnFlow(&cc.flow) + cs.inflow.add(http2transportDefaultStreamFlow) + cs.inflow.setConnFlow(&cc.inflow) cc.nextStreamID += 2 cc.streams[cs.ID] = cs return cs @@ -3810,32 +4128,75 @@ func (cc *http2clientConn) streamByID(id uint32, andRemove bool) *http2clientStr return cs } -// runs in its own goroutine. -func (cc *http2clientConn) readLoop() { - defer cc.t.removeClientConn(cc) - defer close(cc.readerDone) - - activeRes := map[uint32]*http2clientStream{} - - defer func() { - err := cc.readerErr - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - for _, cs := range activeRes { - cs.pw.CloseWithError(err) - } - }() +// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. +type http2clientConnReadLoop struct { + cc *http2clientConn + activeRes map[uint32]*http2clientStream // keyed by streamID // continueStreamID is the stream ID we're waiting for // continuation frames for. - var continueStreamID uint32 + continueStreamID uint32 + hdec *hpack.Decoder + + // Fields reset on each HEADERS: + nextRes *Response + sawRegHeader bool // saw non-pseudo header + reqMalformed error // non-nil once known to be malformed +} + +// readLoop runs in its own goroutine and reads and dispatches frames. +func (cc *http2clientConn) readLoop() { + rl := &http2clientConnReadLoop{ + cc: cc, + activeRes: make(map[uint32]*http2clientStream), + } + + rl.hdec = hpack.NewDecoder(http2initialHeaderTableSize, rl.onNewHeaderField) + + defer rl.cleanup() + cc.readerErr = rl.run() + if ce, ok := cc.readerErr.(http2ConnectionError); ok { + cc.wmu.Lock() + cc.fr.WriteGoAway(0, http2ErrCode(ce), nil) + cc.wmu.Unlock() + } +} + +func (rl *http2clientConnReadLoop) cleanup() { + cc := rl.cc + defer cc.tconn.Close() + defer cc.t.removeClientConn(cc) + defer close(cc.readerDone) + + err := cc.readerErr + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + cc.mu.Lock() + for _, cs := range rl.activeRes { + cs.bufPipe.CloseWithError(err) + } + for _, cs := range cc.streams { + select { + case cs.resc <- http2resAndError{err: err}: + default: + } + } + cc.closed = true + cc.cond.Broadcast() + cc.mu.Unlock() +} + +func (rl *http2clientConnReadLoop) run() error { + cc := rl.cc for { f, err := cc.fr.ReadFrame() - if err != nil { - cc.readerErr = err - return + if se, ok := err.(http2StreamError); ok { + + return se + } else if err != nil { + return err } cc.vlogf("Transport received %v: %#v", f.Header(), f) @@ -3843,105 +4204,292 @@ func (cc *http2clientConn) readLoop() { _, isContinue := f.(*http2ContinuationFrame) if isContinue { - if streamID != continueStreamID { - cc.logf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID) - cc.readerErr = http2ConnectionError(http2ErrCodeProtocol) - return + if streamID != rl.continueStreamID { + cc.logf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, rl.continueStreamID) + return http2ConnectionError(http2ErrCodeProtocol) } - } else if continueStreamID != 0 { + } else if rl.continueStreamID != 0 { - cc.logf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID) - cc.readerErr = http2ConnectionError(http2ErrCodeProtocol) - return - } - - if streamID%2 == 0 { - - continue - } - streamEnded := false - if ff, ok := f.(http2streamEnder); ok { - streamEnded = ff.StreamEnded() - } - - cs := cc.streamByID(streamID, streamEnded) - if cs == nil { - cc.logf("Received frame for untracked stream ID %d", streamID) - continue + cc.logf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, rl.continueStreamID) + return http2ConnectionError(http2ErrCodeProtocol) } switch f := f.(type) { case *http2HeadersFrame: - cc.nextRes = &Response{ - Proto: "HTTP/2.0", - ProtoMajor: 2, - Header: make(Header), - } - cs.pr, cs.pw = io.Pipe() - cc.hdec.Write(f.HeaderBlockFragment()) + err = rl.processHeaders(f) case *http2ContinuationFrame: - cc.hdec.Write(f.HeaderBlockFragment()) + err = rl.processContinuation(f) case *http2DataFrame: - if http2VerboseLogs { - cc.logf("DATA: %q", f.Data()) - } - cs.pw.Write(f.Data()) + err = rl.processData(f) case *http2GoAwayFrame: - cc.t.removeClientConn(cc) - if f.ErrCode != 0 { - - cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) - } - cc.setGoAway(f) + err = rl.processGoAway(f) + case *http2RSTStreamFrame: + err = rl.processResetStream(f) + case *http2SettingsFrame: + err = rl.processSettings(f) + case *http2PushPromiseFrame: + err = rl.processPushPromise(f) + case *http2WindowUpdateFrame: + err = rl.processWindowUpdate(f) default: cc.logf("Transport: unhandled response frame type %T", f) } - headersEnded := false - if he, ok := f.(http2headersEnder); ok { - headersEnded = he.HeadersEnded() - if headersEnded { - continueStreamID = 0 - } else { - continueStreamID = streamID - } - } - - if streamEnded { - cs.pw.Close() - delete(activeRes, streamID) - } - if headersEnded { - if cs == nil { - panic("couldn't find stream") - } - - cc.nextRes.Body = cs.pr - res := cc.nextRes - activeRes[streamID] = cs - cs.resc <- http2resAndError{res: res} + if err != nil { + return err } } } -func (cc *http2clientConn) onNewHeaderField(f hpack.HeaderField) { +func (rl *http2clientConnReadLoop) processHeaders(f *http2HeadersFrame) error { + rl.sawRegHeader = false + rl.reqMalformed = nil + rl.nextRes = &Response{ + Proto: "HTTP/2.0", + ProtoMajor: 2, + Header: make(Header), + } + return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded(), f.StreamEnded()) +} +func (rl *http2clientConnReadLoop) processContinuation(f *http2ContinuationFrame) error { + return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded(), f.StreamEnded()) +} + +func (rl *http2clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID uint32, headersEnded, streamEnded bool) error { + cc := rl.cc + cs := cc.streamByID(streamID, streamEnded) + if cs == nil { + + return nil + } + _, err := rl.hdec.Write(frag) + if err != nil { + return err + } + if !headersEnded { + rl.continueStreamID = cs.ID + return nil + } + + rl.continueStreamID = 0 + + if rl.reqMalformed != nil { + cs.resc <- http2resAndError{err: rl.reqMalformed} + rl.cc.writeStreamReset(cs.ID, http2ErrCodeProtocol, rl.reqMalformed) + return nil + } + + res := rl.nextRes + if streamEnded { + res.Body = http2noBody + } else { + buf := new(bytes.Buffer) + cs.bufPipe = http2pipe{b: buf} + res.Body = http2transportResponseBody{cs} + } + rl.activeRes[cs.ID] = cs + cs.resc <- http2resAndError{res: res} + rl.nextRes = nil + return nil +} + +// transportResponseBody is the concrete type of Transport.RoundTrip's +// Response.Body. It is an io.ReadCloser. On Read, it reads from cs.body. +// On Close it sends RST_STREAM if EOF wasn't already seen. +type http2transportResponseBody struct { + cs *http2clientStream +} + +func (b http2transportResponseBody) Read(p []byte) (n int, err error) { + n, err = b.cs.bufPipe.Read(p) + if n == 0 { + return + } + + cs := b.cs + cc := cs.cc + cc.mu.Lock() + defer cc.mu.Unlock() + + var connAdd, streamAdd int32 + + if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 { + connAdd = http2transportDefaultConnFlow - v + cc.inflow.add(connAdd) + } + if err == nil { + if v := cs.inflow.available(); v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { + streamAdd = http2transportDefaultStreamFlow - v + cs.inflow.add(streamAdd) + } + } + if connAdd != 0 || streamAdd != 0 { + cc.wmu.Lock() + defer cc.wmu.Unlock() + if connAdd != 0 { + cc.fr.WriteWindowUpdate(0, http2mustUint31(connAdd)) + } + if streamAdd != 0 { + cc.fr.WriteWindowUpdate(cs.ID, http2mustUint31(streamAdd)) + } + cc.bw.Flush() + } + return +} + +func (b http2transportResponseBody) Close() error { + if b.cs.bufPipe.Err() != io.EOF { + + b.cs.cc.writeStreamReset(b.cs.ID, http2ErrCodeCancel, nil) + } + return nil +} + +func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { + cc := rl.cc + cs := cc.streamByID(f.StreamID, f.StreamEnded()) + if cs == nil { + return nil + } + data := f.Data() + if http2VerboseLogs { + rl.cc.logf("DATA: %q", data) + } + + cc.mu.Lock() + if cs.inflow.available() >= int32(len(data)) { + cs.inflow.take(int32(len(data))) + } else { + cc.mu.Unlock() + return http2ConnectionError(http2ErrCodeFlowControl) + } + cc.mu.Unlock() + + if _, err := cs.bufPipe.Write(data); err != nil { + return err + } + + if f.StreamEnded() { + cs.bufPipe.CloseWithError(io.EOF) + delete(rl.activeRes, cs.ID) + } + return nil +} + +func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { + cc := rl.cc + cc.t.removeClientConn(cc) + if f.ErrCode != 0 { + + cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) + } + cc.setGoAway(f) + return nil +} + +func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error { + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() + return f.ForeachSetting(func(s http2Setting) error { + switch s.ID { + case http2SettingMaxFrameSize: + cc.maxFrameSize = s.Val + case http2SettingMaxConcurrentStreams: + cc.maxConcurrentStreams = s.Val + case http2SettingInitialWindowSize: + + cc.initialWindowSize = s.Val + default: + + cc.vlogf("Unhandled Setting: %v", s) + } + return nil + }) +} + +func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error { + cc := rl.cc + cs := cc.streamByID(f.StreamID, false) + if f.StreamID != 0 && cs == nil { + return nil + } + + cc.mu.Lock() + defer cc.mu.Unlock() + + fl := &cc.flow + if cs != nil { + fl = &cs.flow + } + if !fl.add(int32(f.Increment)) { + return http2ConnectionError(http2ErrCodeFlowControl) + } + cc.cond.Broadcast() + return nil +} + +func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error { + cs := rl.cc.streamByID(f.StreamID, true) + if cs == nil { + + return nil + } + select { + case <-cs.peerReset: + + default: + err := http2StreamError{cs.ID, f.ErrCode} + cs.resetErr = err + close(cs.peerReset) + cs.bufPipe.CloseWithError(err) + } + delete(rl.activeRes, cs.ID) + return nil +} + +func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) error { + + return http2ConnectionError(http2ErrCodeProtocol) +} + +func (cc *http2clientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) { + + cc.wmu.Lock() + cc.fr.WriteRSTStream(streamID, code) + cc.wmu.Unlock() +} + +// onNewHeaderField runs on the readLoop goroutine whenever a new +// hpack header field is decoded. +func (rl *http2clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) { + cc := rl.cc if http2VerboseLogs { cc.logf("Header field: %+v", f) } - if f.Name == ":status" { - code, err := strconv.Atoi(f.Value) - if err != nil { - panic("TODO: be graceful") + isPseudo := strings.HasPrefix(f.Name, ":") + if isPseudo { + if rl.sawRegHeader { + rl.reqMalformed = errors.New("http2: invalid pseudo header after regular header") + return } - cc.nextRes.Status = f.Value + " " + StatusText(code) - cc.nextRes.StatusCode = code - return - } - if strings.HasPrefix(f.Name, ":") { + switch f.Name { + case ":status": + code, err := strconv.Atoi(f.Value) + if err != nil { + rl.reqMalformed = errors.New("http2: invalid :status") + return + } + rl.nextRes.Status = f.Value + " " + StatusText(code) + rl.nextRes.StatusCode = code + default: - return + rl.reqMalformed = fmt.Errorf("http2: unknown response pseudo header %q", f.Name) + } + } else { + rl.sawRegHeader = true + rl.nextRes.Header.Add(CanonicalHeaderKey(f.Name), f.Value) } - cc.nextRes.Header.Add(CanonicalHeaderKey(f.Name), f.Value) } func (cc *http2clientConn) logf(format string, args ...interface{}) { @@ -3962,6 +4510,8 @@ func (t *http2Transport) logf(format string, args ...interface{}) { log.Printf(format, args...) } +var http2noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) + // writeFramer is implemented by any type that is used to write frames. type http2writeFramer interface { writeFrame(http2writeContext) error