diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index a964113c67..cdbc3dce7c 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -35,12 +35,26 @@ type Conn struct { pendingMu sync.Mutex // protects the pending map pending map[ID]chan *wireResponse handlingMu sync.Mutex // protects the handling map - handling map[ID]handling + handling map[ID]*Request } +type requestState int + +const ( + requestWaiting = requestState(iota) + requestSerial + requestParallel + requestReplied + requestDone +) + // Request is sent to a server to represent a Call or Notify operaton. type Request struct { - conn *Conn + conn *Conn + cancel context.CancelFunc + start time.Time + state requestState + nextRequest chan struct{} // Method is a string containing the method name to invoke. Method string @@ -52,12 +66,6 @@ type Request struct { ID *ID } -type queueEntry struct { - ctx context.Context - r *Request - size int64 -} - // Handler is an option you can pass to NewConn to handle incoming requests. // If the request returns false from IsNotify then the Handler must eventually // call Reply on the Conn with the supplied request. @@ -75,7 +83,6 @@ type Canceler func(context.Context, *Conn, ID) type rpcStats struct { server bool method string - ctx context.Context span trace.Span start time.Time received int64 @@ -87,13 +94,15 @@ type statsKeyType string const rpcStatsKey = statsKeyType("rpcStatsKey") func start(ctx context.Context, server bool, method string, id *ID) (context.Context, *rpcStats) { + if method == "" { + panic("no method in rpc stats") + } s := &rpcStats{ server: server, method: method, - ctx: ctx, start: time.Now(), } - s.ctx = context.WithValue(s.ctx, rpcStatsKey, s) + ctx = context.WithValue(ctx, rpcStatsKey, s) tags := make([]tag.Mutator, 0, 4) tags = append(tags, tag.Upsert(telemetry.KeyMethod, method)) mode := telemetry.Outbound @@ -106,10 +115,10 @@ func start(ctx context.Context, server bool, method string, id *ID) (context.Con if id != nil { tags = append(tags, tag.Upsert(telemetry.KeyRPCID, id.String())) } - s.ctx, s.span = trace.StartSpan(ctx, method, trace.WithSpanKind(spanKind)) - s.ctx, _ = tag.New(s.ctx, tags...) - stats.Record(s.ctx, telemetry.Started.M(1)) - return s.ctx, s + ctx, s.span = trace.StartSpan(ctx, method, trace.WithSpanKind(spanKind)) + ctx, _ = tag.New(ctx, tags...) + stats.Record(ctx, telemetry.Started.M(1)) + return ctx, s } func (s *rpcStats) end(ctx context.Context, err *error) { @@ -145,11 +154,11 @@ func NewConn(s Stream) *Conn { conn := &Conn{ stream: s, pending: make(map[ID]chan *wireResponse), - handling: make(map[ID]handling), + handling: make(map[ID]*Request), } // the default handler reports a method error conn.Handler = func(ctx context.Context, r *Request) { - if r.IsNotify() { + if !r.IsNotify() { r.Reply(ctx, nil, NewErrorf(CodeMethodNotFound, "method %q not found", r.Method)) } } @@ -273,28 +282,38 @@ func (r *Request) IsNotify() bool { return r.ID == nil } +// Parallel indicates that the system is now allowed to process other requests +// in parallel with this one. +// It is safe to call any number of times, but must only be called from the +// request handling go routine. +// It is implied by both reply and by the handler returning. +func (r *Request) Parallel() { + if r.state >= requestParallel { + return + } + r.state = requestParallel + close(r.nextRequest) +} + // Reply sends a reply to the given request. // It is an error to call this if request was not a call. // You must call this exactly once for any given request. +// It should only be called from the handler go routine. // If err is set then result will be ignored. func (r *Request) Reply(ctx context.Context, result interface{}, err error) error { - ctx, st := trace.StartSpan(ctx, r.Method+":reply", trace.WithSpanKind(trace.SpanKindClient)) - defer st.End() - + if r.state >= requestReplied { + return fmt.Errorf("reply invoked more than once") + } if r.IsNotify() { return fmt.Errorf("reply not invoked with a valid call") } - r.conn.handlingMu.Lock() - handling, found := r.conn.handling[*r.ID] - if found { - delete(r.conn.handling, *r.ID) - } - r.conn.handlingMu.Unlock() - if !found { - return fmt.Errorf("not a call in progress: %v", r.ID) - } + ctx, st := trace.StartSpan(ctx, r.Method+":reply", trace.WithSpanKind(trace.SpanKindClient)) + defer st.End() - elapsed := time.Since(handling.start) + r.Parallel() + r.state = requestReplied + + elapsed := time.Since(r.start) var raw *json.RawMessage if err == nil { raw, err = marshalToRaw(result) @@ -319,10 +338,9 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro v := ctx.Value(rpcStatsKey) if v != nil { - s := v.(*rpcStats) - s.sent += n + v.(*rpcStats).sent += n } else { - //panic("no stats available in reply") + panic("no stats available in reply") } if err != nil { @@ -333,10 +351,17 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro return nil } -type handling struct { - request *Request - cancel context.CancelFunc - start time.Time +func (c *Conn) setHandling(r *Request, active bool) { + if r.ID == nil { + return + } + r.conn.handlingMu.Lock() + defer r.conn.handlingMu.Unlock() + if active { + r.conn.handling[*r.ID] = r + } else { + delete(r.conn.handling, *r.ID) + } } // combined has all the fields of both Request and Response. @@ -350,40 +375,13 @@ type combined struct { Error *Error `json:"error,omitempty"` } -func (c *Conn) deliver(ctx context.Context, q chan queueEntry, request *Request, size int64) bool { - e := queueEntry{ctx: ctx, r: request, size: size} - if !c.RejectIfOverloaded { - q <- e - return true - } - select { - case q <- e: - return true - default: - return false - } -} - // Run blocks until the connection is terminated, and returns any error that // caused the termination. // It must be called exactly once for each Conn. // It returns only when the reader is closed or there is an error in the stream. func (c *Conn) Run(ctx context.Context) error { - q := make(chan queueEntry, c.Capacity) - defer close(q) - // start the queue processor - go func() { - // TODO: idle notification? - for e := range q { - if e.ctx.Err() != nil { - continue - } - ctx, rpcStats := start(ctx, true, e.r.Method, e.r.ID) - rpcStats.received += e.size - c.Handler(ctx, e.r) - rpcStats.end(ctx, nil) - } - }() + nextRequest := make(chan struct{}) + close(nextRequest) for { // get the data for a message data, n, err := c.stream.Read(ctx) @@ -403,33 +401,36 @@ func (c *Conn) Run(ctx context.Context) error { switch { case msg.Method != "": // if method is set it must be a request - request := &Request{ - conn: c, - Method: msg.Method, - Params: msg.Params, - ID: msg.ID, - } - if request.IsNotify() { - c.Logger(Receive, request.ID, -1, request.Method, request.Params, nil) - // we have a Notify, add to the processor queue - c.deliver(ctx, q, request, n) - //TODO: log when we drop a message? - } else { - // we have a Call, add to the processor queue - reqCtx, cancelReq := context.WithCancel(ctx) - c.handlingMu.Lock() - c.handling[*request.ID] = handling{ - request: request, - cancel: cancelReq, - start: time.Now(), - } - c.handlingMu.Unlock() - c.Logger(Receive, request.ID, -1, request.Method, request.Params, nil) - if !c.deliver(reqCtx, q, request, n) { - // queue is full, reject the message by directly replying - request.Reply(ctx, nil, NewErrorf(CodeServerOverloaded, "no room in queue")) - } + reqCtx, cancelReq := context.WithCancel(ctx) + reqCtx, rpcStats := start(reqCtx, true, msg.Method, msg.ID) + rpcStats.received += n + thisRequest := nextRequest + nextRequest = make(chan struct{}) + req := &Request{ + conn: c, + cancel: cancelReq, + nextRequest: nextRequest, + start: time.Now(), + Method: msg.Method, + Params: msg.Params, + ID: msg.ID, } + c.setHandling(req, true) + go func() { + <-thisRequest + req.state = requestSerial + defer func() { + c.setHandling(req, false) + if !req.IsNotify() && req.state < requestReplied { + req.Reply(reqCtx, nil, NewErrorf(CodeInternalError, "method %q did not reply", req.Method)) + } + req.Parallel() + rpcStats.end(reqCtx, nil) + cancelReq() + }() + c.Logger(Receive, req.ID, -1, req.Method, req.Params, nil) + c.Handler(reqCtx, req) + }() case msg.ID != nil: // we have a response, get the pending entry from the map c.pendingMu.Lock()