1
0
mirror of https://github.com/golang/go synced 2024-11-18 16:34:51 -07:00

internal/jsonrpc2: extract logic to handler hooks

Change-Id: Ief531e4b68fcb0dbc71e263c185fb285a9042479
Reviewed-on: https://go-review.googlesource.com/c/tools/+/185983
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
This commit is contained in:
Ian Cottrell 2019-07-12 00:43:12 -04:00
parent b32ec66a23
commit 8b927904ee
6 changed files with 255 additions and 82 deletions

View File

@ -6,8 +6,6 @@ package jsonrpc2
import (
"context"
"encoding/json"
"time"
)
// Handler is the interface used to hook into the mesage handling of an rpc
@ -38,7 +36,26 @@ type Handler interface {
// method is the method name specified in the message
// payload is the parameters for a call or notification, and the result for a
// response
Log(direction Direction, id *ID, elapsed time.Duration, method string, payload *json.RawMessage, err *Error)
// Request is called near the start of processing any request.
Request(ctx context.Context, direction Direction, r *WireRequest) context.Context
// Response is called near the start of processing any response.
Response(ctx context.Context, direction Direction, r *WireResponse) context.Context
// Done is called when any request is fully processed.
// For calls, this means the response has also been processed, for notifies
// this is as soon as the message has been written to the stream.
// If err is set, it implies the request failed.
Done(ctx context.Context, err error)
// Read is called with a count each time some data is read from the stream.
// The read calls are delayed until after the data has been interpreted so
// that it can be attributed to a request/response.
Read(ctx context.Context, bytes int64) context.Context
// Wrote is called each time some data is written to the stream.
Wrote(ctx context.Context, bytes int64) context.Context
// Error is called with errors that cannot be delivered through the normal
// mechanisms, for instance a failure to process a notify cannot be delivered
// back to the other party.
Error(ctx context.Context, err error)
}
// Direction is used to indicate to a logger whether the logged message was being
@ -73,9 +90,27 @@ func (EmptyHandler) Cancel(ctx context.Context, conn *Conn, id ID, cancelled boo
return false
}
func (EmptyHandler) Log(direction Direction, id *ID, elapsed time.Duration, method string, payload *json.RawMessage, err *Error) {
func (EmptyHandler) Request(ctx context.Context, direction Direction, r *WireRequest) context.Context {
return ctx
}
func (EmptyHandler) Response(ctx context.Context, direction Direction, r *WireResponse) context.Context {
return ctx
}
func (EmptyHandler) Done(ctx context.Context, err error) {
}
func (EmptyHandler) Read(ctx context.Context, bytes int64) context.Context {
return ctx
}
func (EmptyHandler) Wrote(ctx context.Context, bytes int64) context.Context {
return ctx
}
func (EmptyHandler) Error(ctx context.Context, err error) {}
type defaultHandler struct{ EmptyHandler }
func (defaultHandler) Deliver(ctx context.Context, r *Request, delivered bool) bool {

View File

@ -11,6 +11,7 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
@ -28,7 +29,7 @@ type Conn struct {
stream Stream
err error
pendingMu sync.Mutex // protects the pending map
pending map[ID]chan *wireResponse
pending map[ID]chan *WireResponse
handlingMu sync.Mutex // protects the handling map
handling map[ID]*Request
}
@ -47,18 +48,11 @@ const (
type Request struct {
conn *Conn
cancel context.CancelFunc
start time.Time
state requestState
nextRequest chan struct{}
// Method is a string containing the method name to invoke.
Method string
// Params is either a struct or an array with the parameters of the method.
Params *json.RawMessage
// The id of this request, used to tie the response back to the request.
// Will be either a string or a number. If not set, the request is a notify,
// and no response is possible.
ID *ID
// The Wire values of the request.
WireRequest
}
type rpcStats struct {
@ -115,9 +109,9 @@ func NewErrorf(code int64, format string, args ...interface{}) *Error {
// You must call Run for the connection to be active.
func NewConn(s Stream) *Conn {
conn := &Conn{
handlers: []Handler{defaultHandler{}},
handlers: []Handler{defaultHandler{}, &tracer{}},
stream: s,
pending: make(map[ID]chan *wireResponse),
pending: make(map[ID]chan *WireResponse),
handling: make(map[ID]*Request),
}
return conn
@ -150,14 +144,11 @@ func (c *Conn) Cancel(id ID) {
// It will return as soon as the notification has been sent, as no response is
// possible.
func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (err error) {
ctx, rpcStats := start(ctx, false, method, nil)
defer rpcStats.end(ctx, &err)
jsonParams, err := marshalToRaw(params)
if err != nil {
return fmt.Errorf("marshalling notify parameters: %v", err)
}
request := &wireRequest{
request := &WireRequest{
Method: method,
Params: jsonParams,
}
@ -166,10 +157,17 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e
return fmt.Errorf("marshalling notify request: %v", err)
}
for _, h := range c.handlers {
h.Log(Send, nil, -1, request.Method, request.Params, nil)
ctx = h.Request(ctx, Send, request)
}
defer func() {
for _, h := range c.handlers {
h.Done(ctx, err)
}
}()
n, err := c.stream.Write(ctx, data)
telemetry.SentBytes.Record(ctx, n)
for _, h := range c.handlers {
ctx = h.Wrote(ctx, n)
}
return err
}
@ -179,13 +177,11 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e
func (c *Conn) Call(ctx context.Context, method string, params, result interface{}) (err error) {
// generate a new request identifier
id := ID{Number: atomic.AddInt64(&c.seq, 1)}
ctx, rpcStats := start(ctx, false, method, &id)
defer rpcStats.end(ctx, &err)
jsonParams, err := marshalToRaw(params)
if err != nil {
return fmt.Errorf("marshalling call parameters: %v", err)
}
request := &wireRequest{
request := &WireRequest{
ID: &id,
Method: method,
Params: jsonParams,
@ -195,9 +191,12 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
if err != nil {
return fmt.Errorf("marshalling call request: %v", err)
}
for _, h := range c.handlers {
ctx = h.Request(ctx, Send, request)
}
// we have to add ourselves to the pending map before we send, otherwise we
// are racing the response
rchan := make(chan *wireResponse)
rchan := make(chan *WireResponse)
c.pendingMu.Lock()
c.pending[id] = rchan
c.pendingMu.Unlock()
@ -206,14 +205,15 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
c.pendingMu.Lock()
delete(c.pending, id)
c.pendingMu.Unlock()
for _, h := range c.handlers {
h.Done(ctx, err)
}
}()
// now we are ready to send
before := time.Now()
for _, h := range c.handlers {
h.Log(Send, request.ID, -1, request.Method, request.Params, nil)
}
n, err := c.stream.Write(ctx, data)
telemetry.SentBytes.Record(ctx, n)
for _, h := range c.handlers {
ctx = h.Wrote(ctx, n)
}
if err != nil {
// sending failed, we will never get a response, so don't leave it pending
return err
@ -221,9 +221,8 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
// now wait for the response
select {
case response := <-rchan:
elapsed := time.Since(before)
for _, h := range c.handlers {
h.Log(Receive, response.ID, elapsed, request.Method, response.Result, response.Error)
ctx = h.Response(ctx, Receive, response)
}
// is it an error response?
if response.Error != nil {
@ -283,9 +282,6 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro
if r.IsNotify() {
return fmt.Errorf("reply not invoked with a valid call")
}
ctx, close := trace.StartSpan(ctx, r.Method+":reply")
defer close()
// reply ends the handling phase of a call, so if we are not yet
// parallel we should be now. The go routine is allowed to continue
// to do work after replying, which is why it is important to unlock
@ -293,12 +289,11 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro
r.Parallel()
r.state = requestReplied
elapsed := time.Since(r.start)
var raw *json.RawMessage
if err == nil {
raw, err = marshalToRaw(result)
}
response := &wireResponse{
response := &WireResponse{
Result: raw,
ID: r.ID,
}
@ -314,10 +309,12 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro
return err
}
for _, h := range r.conn.handlers {
h.Log(Send, response.ID, elapsed, r.Method, response.Result, response.Error)
ctx = h.Response(ctx, Send, response)
}
n, err := r.conn.stream.Write(ctx, data)
telemetry.SentBytes.Record(ctx, n)
for _, h := range r.conn.handlers {
ctx = h.Wrote(ctx, n)
}
if err != nil {
// TODO(iancottrell): if a stream write fails, we really need to shut down
@ -374,7 +371,7 @@ func (c *Conn) Run(ctx context.Context) error {
// a badly formed message arrived, log it and continue
// we trust the stream to have isolated the error to just this message
for _, h := range c.handlers {
h.Log(Receive, nil, -1, "", nil, NewErrorf(0, "unmarshal failed: %v", err))
h.Error(ctx, fmt.Errorf("unmarshal failed: %v", err))
}
continue
}
@ -382,19 +379,23 @@ func (c *Conn) Run(ctx context.Context) error {
switch {
case msg.Method != "":
// if method is set it must be a request
reqCtx, cancelReq := context.WithCancel(ctx)
reqCtx, rpcStats := start(reqCtx, true, msg.Method, msg.ID)
telemetry.ReceivedBytes.Record(ctx, n)
ctx, cancelReq := context.WithCancel(ctx)
thisRequest := nextRequest
nextRequest = make(chan struct{})
req := &Request{
conn: c,
cancel: cancelReq,
nextRequest: nextRequest,
start: time.Now(),
Method: msg.Method,
Params: msg.Params,
ID: msg.ID,
WireRequest: WireRequest{
VersionTag: msg.VersionTag,
Method: msg.Method,
Params: msg.Params,
ID: msg.ID,
},
}
for _, h := range c.handlers {
ctx = h.Request(ctx, Receive, &req.WireRequest)
ctx = h.Read(ctx, n)
}
c.setHandling(req, true)
go func() {
@ -403,16 +404,17 @@ func (c *Conn) Run(ctx context.Context) error {
defer func() {
c.setHandling(req, false)
if !req.IsNotify() && req.state < requestReplied {
req.Reply(reqCtx, nil, NewErrorf(CodeInternalError, "method %q did not reply", req.Method))
req.Reply(ctx, nil, NewErrorf(CodeInternalError, "method %q did not reply", req.Method))
}
req.Parallel()
rpcStats.end(reqCtx, nil)
for _, h := range c.handlers {
h.Done(ctx, err)
}
cancelReq()
}()
delivered := false
for _, h := range c.handlers {
h.Log(Receive, req.ID, -1, req.Method, req.Params, nil)
if h.Deliver(reqCtx, req, delivered) {
if h.Deliver(ctx, req, delivered) {
delivered = true
}
}
@ -426,7 +428,7 @@ func (c *Conn) Run(ctx context.Context) error {
}
c.pendingMu.Unlock()
// and send the reply to the channel
response := &wireResponse{
response := &WireResponse{
Result: msg.Result,
Error: msg.Error,
ID: msg.ID,
@ -435,7 +437,7 @@ func (c *Conn) Run(ctx context.Context) error {
close(rchan)
default:
for _, h := range c.handlers {
h.Log(Receive, nil, -1, "", nil, NewErrorf(0, "message not a call, notify or response, ignoring"))
h.Error(ctx, fmt.Errorf("message not a call, notify or response, ignoring"))
}
}
}
@ -449,3 +451,49 @@ func marshalToRaw(obj interface{}) (*json.RawMessage, error) {
raw := json.RawMessage(data)
return &raw, nil
}
type statsKeyType int
const statsKey = statsKeyType(0)
type tracer struct {
}
func (h *tracer) Deliver(ctx context.Context, r *Request, delivered bool) bool {
return false
}
func (h *tracer) Cancel(ctx context.Context, conn *Conn, id ID, cancelled bool) bool {
return false
}
func (h *tracer) Request(ctx context.Context, direction Direction, r *WireRequest) context.Context {
ctx, stats := start(ctx, direction == Receive, r.Method, r.ID)
ctx = context.WithValue(ctx, statsKey, stats)
return ctx
}
func (h *tracer) Response(ctx context.Context, direction Direction, r *WireResponse) context.Context {
return ctx
}
func (h *tracer) Done(ctx context.Context, err error) {
stats, ok := ctx.Value(statsKey).(*rpcStats)
if ok && stats != nil {
stats.end(ctx, &err)
}
}
func (h *tracer) Read(ctx context.Context, bytes int64) context.Context {
telemetry.SentBytes.Record(ctx, bytes)
return ctx
}
func (h *tracer) Wrote(ctx context.Context, bytes int64) context.Context {
telemetry.ReceivedBytes.Record(ctx, bytes)
return ctx
}
func (h *tracer) Error(ctx context.Context, err error) {
log.Printf("%v", err)
}

View File

@ -108,7 +108,7 @@ func run(ctx context.Context, t *testing.T, withHeaders bool, r io.ReadCloser, w
stream = jsonrpc2.NewStream(r, w)
}
conn := jsonrpc2.NewConn(stream)
conn.AddHandler(handle{})
conn.AddHandler(&handle{log: *logRPC})
go func() {
defer func() {
r.Close()
@ -121,9 +121,11 @@ func run(ctx context.Context, t *testing.T, withHeaders bool, r io.ReadCloser, w
return conn
}
type handle struct{ jsonrpc2.EmptyHandler }
type handle struct {
log bool
}
func (handle) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool {
func (h *handle) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool {
switch r.Method {
case "no_args":
if r.Params != nil {
@ -158,18 +160,43 @@ func (handle) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool)
return true
}
func (handle) Log(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err *jsonrpc2.Error) {
if !*logRPC {
return
}
switch {
case err != nil:
log.Printf("%v failure [%v] %s %v", direction, id, method, err)
case id == nil:
log.Printf("%v notification %s %s", direction, method, *payload)
case elapsed >= 0:
log.Printf("%v response in %v [%v] %s %s", direction, elapsed, id, method, *payload)
default:
log.Printf("%v call [%v] %s %s", direction, id, method, *payload)
}
func (h *handle) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID, cancelled bool) bool {
return false
}
func (h *handle) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
if h.log {
if r.ID != nil {
log.Printf("%v call [%v] %s %s", direction, r.ID, r.Method, r.Params)
} else {
log.Printf("%v notification %s %s", direction, r.Method, r.Params)
}
ctx = context.WithValue(ctx, "method", r.Method)
ctx = context.WithValue(ctx, "start", time.Now())
}
return ctx
}
func (h *handle) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context {
if h.log {
method := ctx.Value("method")
elapsed := time.Since(ctx.Value("start").(time.Time))
log.Printf("%v response in %v [%v] %s %s", direction, elapsed, r.ID, method, r.Result)
}
return ctx
}
func (h *handle) Done(ctx context.Context, err error) {
}
func (h *handle) Read(ctx context.Context, bytes int64) context.Context {
return ctx
}
func (h *handle) Wrote(ctx context.Context, bytes int64) context.Context {
return ctx
}
func (h *handle) Error(ctx context.Context, err error) {
log.Printf("%v", err)
}

View File

@ -34,8 +34,8 @@ const (
CodeServerOverloaded = -32000
)
// wireRequest is sent to a server to represent a Call or Notify operaton.
type wireRequest struct {
// WireRequest is sent to a server to represent a Call or Notify operaton.
type WireRequest struct {
// VersionTag is always encoded as the string "2.0"
VersionTag VersionTag `json:"jsonrpc"`
// Method is a string containing the method name to invoke.
@ -48,11 +48,11 @@ type wireRequest struct {
ID *ID `json:"id,omitempty"`
}
// wireResponse is a reply to a Request.
// WireResponse is a reply to a Request.
// It will always have the ID field set to tie it back to a request, and will
// have either the Result or Error fields set depending on whether it is a
// success or failure response.
type wireResponse struct {
type WireResponse struct {
// VersionTag is always encoded as the string "2.0"
VersionTag VersionTag `json:"jsonrpc"`
// Result is the response value, and is required on success.

View File

@ -120,6 +120,18 @@ type handler struct {
out io.Writer
}
type rpcStats struct {
method string
direction jsonrpc2.Direction
id *jsonrpc2.ID
payload *json.RawMessage
start time.Time
}
type statsKeyType int
const statsKey = statsKeyType(0)
func (h *handler) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool {
return false
}
@ -128,7 +140,63 @@ func (h *handler) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.I
return false
}
func (h *handler) Log(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err *jsonrpc2.Error) {
func (h *handler) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context {
if !h.trace {
return ctx
}
stats := &rpcStats{
method: r.Method,
direction: direction,
start: time.Now(),
payload: r.Params,
}
ctx = context.WithValue(ctx, statsKey, stats)
return ctx
}
func (h *handler) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context {
stats := h.getStats(ctx)
h.log(direction, r.ID, 0, stats.method, r.Result, nil)
return ctx
}
func (h *handler) Done(ctx context.Context, err error) {
if !h.trace {
return
}
stats := h.getStats(ctx)
h.log(stats.direction, stats.id, time.Since(stats.start), stats.method, stats.payload, err)
}
func (h *handler) Read(ctx context.Context, bytes int64) context.Context {
return ctx
}
func (h *handler) Wrote(ctx context.Context, bytes int64) context.Context {
return ctx
}
const eol = "\r\n\r\n\r\n"
func (h *handler) Error(ctx context.Context, err error) {
if !h.trace {
return
}
stats := h.getStats(ctx)
h.log(stats.direction, stats.id, 0, stats.method, nil, err)
}
func (h *handler) getStats(ctx context.Context) *rpcStats {
stats, ok := ctx.Value(statsKey).(*rpcStats)
if !ok || stats == nil {
stats = &rpcStats{
method: "???",
}
}
return stats
}
func (h *handler) log(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err error) {
if !h.trace {
return
}

View File

@ -6,8 +6,6 @@ package protocol
import (
"context"
"encoding/json"
"time"
"golang.org/x/tools/internal/jsonrpc2"
"golang.org/x/tools/internal/lsp/telemetry/trace"
@ -17,7 +15,7 @@ import (
type DocumentUri = string
type canceller struct{}
type canceller struct{ jsonrpc2.EmptyHandler }
type clientHandler struct {
canceller
@ -42,9 +40,6 @@ func (canceller) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID
return true
}
func (canceller) Log(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err *jsonrpc2.Error) {
}
func NewClient(stream jsonrpc2.Stream, client Client) (*jsonrpc2.Conn, Server, xlog.Logger) {
log := xlog.New(NewLogger(client))
conn := jsonrpc2.NewConn(stream)