diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index c59d4ad56c..f0c7c5a11a 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -27,7 +27,8 @@ const ( // Conn is a JSON RPC 2 client server connection. // Conn is bidirectional; it does not have a designated server or client end. type Conn struct { - seq int64 // must only be accessed using atomic operations + seq int64 // must only be accessed using atomic operations + writeMu sync.Mutex // protects writes to the stream stream Stream pendingMu sync.Mutex // protects the pending map pending map[ID]chan *Response @@ -65,7 +66,7 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e }() event.Metric(ctx, tag.Started.Of(1)) - n, err := c.stream.Write(ctx, notify) + n, err := c.write(ctx, notify) event.Metric(ctx, tag.SentBytes.Of(n)) return err } @@ -104,7 +105,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface c.pendingMu.Unlock() }() // now we are ready to send - n, err := c.stream.Write(ctx, call) + n, err := c.write(ctx, call) event.Metric(ctx, tag.SentBytes.Of(n)) if err != nil { // sending failed, we will never get a response, so don't leave it pending @@ -144,7 +145,7 @@ func replier(conn *Conn, req Request, spanDone func()) Replier { if err != nil { return err } - n, err := conn.stream.Write(ctx, response) + n, err := conn.write(ctx, response) event.Metric(ctx, tag.SentBytes.Of(n)) if err != nil { // TODO(iancottrell): if a stream write fails, we really need to shut down @@ -155,6 +156,12 @@ func replier(conn *Conn, req Request, spanDone func()) Replier { } } +func (c *Conn) write(ctx context.Context, msg Message) (int64, error) { + c.writeMu.Lock() + defer c.writeMu.Unlock() + return c.stream.Write(ctx, msg) +} + // Run blocks until the connection is terminated, and returns any error that // caused the termination. // It must be called exactly once for each Conn. diff --git a/internal/jsonrpc2/jsonrpc2_test.go b/internal/jsonrpc2/jsonrpc2_test.go index e586009b57..b3d67e8799 100644 --- a/internal/jsonrpc2/jsonrpc2_test.go +++ b/internal/jsonrpc2/jsonrpc2_test.go @@ -119,11 +119,7 @@ func run(ctx context.Context, t *testing.T, withHeaders bool, r io.ReadCloser, w wg.Add(1) go func() { defer func() { - // this will happen when Run returns, which means at least one of the - // streams has already been closed - // we close both streams anyway, this may be redundant but is safe - r.Close() - w.Close() + stream.Close() // and then signal that this connection is done wg.Done() }() diff --git a/internal/jsonrpc2/serve.go b/internal/jsonrpc2/serve.go index 1a8494145d..ff6e48884e 100644 --- a/internal/jsonrpc2/serve.go +++ b/internal/jsonrpc2/serve.go @@ -98,6 +98,7 @@ func Serve(ctx context.Context, ln net.Listener, server StreamServer, idleTimeou stream := NewHeaderStream(netConn, netConn) go func() { closedConns <- server.ServeStream(ctx, stream) + stream.Close() }() case err := <-doneListening: return err diff --git a/internal/jsonrpc2/stream.go b/internal/jsonrpc2/stream.go index c12eb003bd..f56b72bf81 100644 --- a/internal/jsonrpc2/stream.go +++ b/internal/jsonrpc2/stream.go @@ -10,38 +10,43 @@ import ( "encoding/json" "fmt" "io" + "net" "strconv" "strings" - "sync" + + "golang.org/x/tools/internal/fakenet" ) // Stream abstracts the transport mechanics from the JSON RPC protocol. // A Conn reads and writes messages using the stream it was provided on // construction, and assumes that each call to Read or Write fully transfers // a single message, or returns an error. +// A stream is not safe for concurrent use, it is expected it will be used by +// a single Conn in a safe manner. type Stream interface { // Read gets the next message from the stream. - // It is never called concurrently. Read(context.Context) (Message, int64, error) // Write sends a message to the stream. - // It must be safe for concurrent use. Write(context.Context, Message) (int64, error) + // Close closes the connection. + // Any blocked Read or Write operations will be unblocked and return errors. + Close() error } // NewRawStream returns a Stream built on top of an io.Reader and io.Writer. // The messages are sent with no wrapping, and rely on json decode consistency // to determine message boundaries. -func NewRawStream(in io.Reader, out io.Writer) Stream { +func NewRawStream(in io.ReadCloser, out io.WriteCloser) Stream { + conn := fakenet.NewConn("jsonrpc2.NewRawStream", in, out) return &rawStream{ - in: json.NewDecoder(in), - out: out, + conn: conn, + in: json.NewDecoder(conn), } } type rawStream struct { - in *json.Decoder - outMu sync.Mutex - out io.Writer + conn net.Conn + in *json.Decoder } func (s *rawStream) Read(ctx context.Context) (Message, int64, error) { @@ -68,26 +73,28 @@ func (s *rawStream) Write(ctx context.Context, msg Message) (int64, error) { if err != nil { return 0, fmt.Errorf("marshaling message: %v", err) } - s.outMu.Lock() - n, err := s.out.Write(data) - s.outMu.Unlock() + n, err := s.conn.Write(data) return int64(n), err } +func (s *rawStream) Close() error { + return s.conn.Close() +} + // NewHeaderStream returns a Stream built on top of an io.Reader and io.Writer. // The messages are sent with HTTP content length and MIME type headers. // This is the format used by LSP and others. -func NewHeaderStream(in io.Reader, out io.Writer) Stream { +func NewHeaderStream(in io.ReadCloser, out io.WriteCloser) Stream { + conn := fakenet.NewConn("jsonrpc2.NewHeaderStream", in, out) return &headerStream{ - in: bufio.NewReader(in), - out: out, + conn: conn, + in: bufio.NewReader(conn), } } type headerStream struct { - in *bufio.Reader - outMu sync.Mutex - out io.Writer + conn net.Conn + in *bufio.Reader } func (s *headerStream) Read(ctx context.Context) (Message, int64, error) { @@ -148,13 +155,15 @@ func (s *headerStream) Write(ctx context.Context, msg Message) (int64, error) { if err != nil { return 0, fmt.Errorf("marshaling message: %v", err) } - s.outMu.Lock() - defer s.outMu.Unlock() - n, err := fmt.Fprintf(s.out, "Content-Length: %v\r\n\r\n", len(data)) + n, err := fmt.Fprintf(s.conn, "Content-Length: %v\r\n\r\n", len(data)) total := int64(n) if err == nil { - n, err = s.out.Write(data) + n, err = s.conn.Write(data) total += int64(n) } return total, err } + +func (s *headerStream) Close() error { + return s.conn.Close() +} diff --git a/internal/lsp/protocol/log.go b/internal/lsp/protocol/log.go index 597553b47c..2c82c64a12 100644 --- a/internal/lsp/protocol/log.go +++ b/internal/lsp/protocol/log.go @@ -36,6 +36,10 @@ func (s *loggingStream) Write(ctx context.Context, msg jsonrpc2.Message) (int64, return count, err } +func (s *loggingStream) Close() error { + return s.stream.Close() +} + type req struct { method string start time.Time