diff --git a/internal/jsonrpc2/handler.go b/internal/jsonrpc2/handler.go index 598c79b570..bf8bfb3aa6 100644 --- a/internal/jsonrpc2/handler.go +++ b/internal/jsonrpc2/handler.go @@ -38,9 +38,9 @@ type Handler interface { // response // Request is called near the start of processing any request. - Request(ctx context.Context, direction Direction, r *WireRequest) context.Context + Request(ctx context.Context, conn *Conn, direction Direction, r *WireRequest) context.Context // Response is called near the start of processing any response. - Response(ctx context.Context, direction Direction, r *WireResponse) context.Context + Response(ctx context.Context, conn *Conn, direction Direction, r *WireResponse) context.Context // Done is called when any request is fully processed. // For calls, this means the response has also been processed, for notifies // this is as soon as the message has been written to the stream. @@ -90,11 +90,11 @@ func (EmptyHandler) Cancel(ctx context.Context, conn *Conn, id ID, cancelled boo return false } -func (EmptyHandler) Request(ctx context.Context, direction Direction, r *WireRequest) context.Context { +func (EmptyHandler) Request(ctx context.Context, conn *Conn, direction Direction, r *WireRequest) context.Context { return ctx } -func (EmptyHandler) Response(ctx context.Context, direction Direction, r *WireResponse) context.Context { +func (EmptyHandler) Response(ctx context.Context, conn *Conn, direction Direction, r *WireResponse) context.Context { return ctx } diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index f9c402168d..62c8141b44 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -110,7 +110,7 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e return fmt.Errorf("marshalling notify request: %v", err) } for _, h := range c.handlers { - ctx = h.Request(ctx, Send, request) + ctx = h.Request(ctx, c, Send, request) } defer func() { for _, h := range c.handlers { @@ -145,7 +145,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface return fmt.Errorf("marshalling call request: %v", err) } for _, h := range c.handlers { - ctx = h.Request(ctx, Send, request) + ctx = h.Request(ctx, c, Send, request) } // we have to add ourselves to the pending map before we send, otherwise we // are racing the response @@ -175,7 +175,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface select { case response := <-rchan: for _, h := range c.handlers { - ctx = h.Response(ctx, Receive, response) + ctx = h.Response(ctx, c, Receive, response) } // is it an error response? if response.Error != nil { @@ -262,7 +262,7 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro return err } for _, h := range r.conn.handlers { - ctx = h.Response(ctx, Send, response) + ctx = h.Response(ctx, r.conn, Send, response) } n, err := r.conn.stream.Write(ctx, data) for _, h := range r.conn.handlers { @@ -347,7 +347,7 @@ func (c *Conn) Run(runCtx context.Context) error { }, } for _, h := range c.handlers { - reqCtx = h.Request(reqCtx, Receive, &req.WireRequest) + reqCtx = h.Request(reqCtx, c, Receive, &req.WireRequest) reqCtx = h.Read(reqCtx, n) } c.setHandling(req, true) diff --git a/internal/jsonrpc2/jsonrpc2_test.go b/internal/jsonrpc2/jsonrpc2_test.go index 89252fd10f..192a5e805d 100644 --- a/internal/jsonrpc2/jsonrpc2_test.go +++ b/internal/jsonrpc2/jsonrpc2_test.go @@ -164,7 +164,7 @@ func (h *handle) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID return false } -func (h *handle) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context { +func (h *handle) Request(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context { if h.log { if r.ID != nil { log.Printf("%v call [%v] %s %v", direction, r.ID, r.Method, r.Params) @@ -177,7 +177,7 @@ func (h *handle) Request(ctx context.Context, direction jsonrpc2.Direction, r *j return ctx } -func (h *handle) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context { +func (h *handle) Response(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context { if h.log { method := ctx.Value("method") elapsed := time.Since(ctx.Value("start").(time.Time)) diff --git a/internal/lsp/cmd/serve.go b/internal/lsp/cmd/serve.go index 70370afefd..ceca1f1d11 100644 --- a/internal/lsp/cmd/serve.go +++ b/internal/lsp/cmd/serve.go @@ -149,7 +149,7 @@ func (h *handler) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.I return false } -func (h *handler) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context { +func (h *handler) Request(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context { if r.Method == "" { panic("no method in rpc stats") } @@ -174,7 +174,7 @@ func (h *handler) Request(ctx context.Context, direction jsonrpc2.Direction, r * return ctx } -func (h *handler) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context { +func (h *handler) Response(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context { return ctx } diff --git a/internal/lsp/protocol/protocol.go b/internal/lsp/protocol/protocol.go index 80be75d650..8aaa3ef2f2 100644 --- a/internal/lsp/protocol/protocol.go +++ b/internal/lsp/protocol/protocol.go @@ -6,6 +6,7 @@ package protocol import ( "context" + "encoding/json" "golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/telemetry/log" @@ -13,6 +14,11 @@ import ( "golang.org/x/tools/internal/xcontext" ) +const ( + // RequestCancelledError should be used when a request is cancelled early. + RequestCancelledError = -32800 +) + type DocumentUri = string type canceller struct{ jsonrpc2.EmptyHandler } @@ -27,6 +33,18 @@ type serverHandler struct { server Server } +func (canceller) Request(ctx context.Context, conn *jsonrpc2.Conn, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context { + if direction == jsonrpc2.Receive && r.Method == "$/cancelRequest" { + var params CancelParams + if err := json.Unmarshal(*r.Params, ¶ms); err != nil { + log.Error(ctx, "", err) + } else { + conn.Cancel(params.ID) + } + } + return ctx +} + func (canceller) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID, cancelled bool) bool { if cancelled { return false diff --git a/internal/lsp/protocol/tsclient.go b/internal/lsp/protocol/tsclient.go index 113989ff07..969e05507f 100644 --- a/internal/lsp/protocol/tsclient.go +++ b/internal/lsp/protocol/tsclient.go @@ -8,6 +8,7 @@ import ( "golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/telemetry/log" + "golang.org/x/tools/internal/xcontext" ) type Client interface { @@ -27,15 +28,12 @@ func (h clientHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, deliver if delivered { return false } - switch r.Method { - case "$/cancelRequest": - var params CancelParams - if err := json.Unmarshal(*r.Params, ¶ms); err != nil { - sendParseError(ctx, r, err) - return true - } - r.Conn().Cancel(params.ID) + if ctx.Err() != nil { + ctx := xcontext.Detach(ctx) + r.Reply(ctx, nil, jsonrpc2.NewErrorf(RequestCancelledError, "")) return true + } + switch r.Method { case "window/showMessage": // notif var params ShowMessageParams if err := json.Unmarshal(*r.Params, ¶ms); err != nil { diff --git a/internal/lsp/protocol/tsserver.go b/internal/lsp/protocol/tsserver.go index 1882adecaa..9ee423b5c7 100644 --- a/internal/lsp/protocol/tsserver.go +++ b/internal/lsp/protocol/tsserver.go @@ -8,6 +8,7 @@ import ( "golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/telemetry/log" + "golang.org/x/tools/internal/xcontext" ) type Server interface { @@ -46,13 +47,13 @@ type Server interface { Symbol(context.Context, *WorkspaceSymbolParams) ([]SymbolInformation, error) CodeLens(context.Context, *CodeLensParams) ([]CodeLens, error) ResolveCodeLens(context.Context, *CodeLens) (*CodeLens, error) + DocumentLink(context.Context, *DocumentLinkParams) ([]DocumentLink, error) + ResolveDocumentLink(context.Context, *DocumentLink) (*DocumentLink, error) Formatting(context.Context, *DocumentFormattingParams) ([]TextEdit, error) RangeFormatting(context.Context, *DocumentRangeFormattingParams) ([]TextEdit, error) OnTypeFormatting(context.Context, *DocumentOnTypeFormattingParams) ([]TextEdit, error) Rename(context.Context, *RenameParams) (*WorkspaceEdit, error) PrepareRename(context.Context, *PrepareRenameParams) (*Range, error) - DocumentLink(context.Context, *DocumentLinkParams) ([]DocumentLink, error) - ResolveDocumentLink(context.Context, *DocumentLink) (*DocumentLink, error) ExecuteCommand(context.Context, *ExecuteCommandParams) (interface{}, error) } @@ -60,15 +61,12 @@ func (h serverHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, deliver if delivered { return false } - switch r.Method { - case "$/cancelRequest": - var params CancelParams - if err := json.Unmarshal(*r.Params, ¶ms); err != nil { - sendParseError(ctx, r, err) - return true - } - r.Conn().Cancel(params.ID) + if ctx.Err() != nil { + ctx := xcontext.Detach(ctx) + r.Reply(ctx, nil, jsonrpc2.NewErrorf(RequestCancelledError, "")) return true + } + switch r.Method { case "workspace/didChangeWorkspaceFolders": // notif var params DidChangeWorkspaceFoldersParams if err := json.Unmarshal(*r.Params, ¶ms); err != nil { @@ -435,6 +433,28 @@ func (h serverHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, deliver log.Error(ctx, "", err) } return true + case "textDocument/documentLink": // req + var params DocumentLinkParams + if err := json.Unmarshal(*r.Params, ¶ms); err != nil { + sendParseError(ctx, r, err) + return true + } + resp, err := h.server.DocumentLink(ctx, ¶ms) + if err := r.Reply(ctx, resp, err); err != nil { + log.Error(ctx, "", err) + } + return true + case "documentLink/resolve": // req + var params DocumentLink + if err := json.Unmarshal(*r.Params, ¶ms); err != nil { + sendParseError(ctx, r, err) + return true + } + resp, err := h.server.ResolveDocumentLink(ctx, ¶ms) + if err := r.Reply(ctx, resp, err); err != nil { + log.Error(ctx, "", err) + } + return true case "textDocument/formatting": // req var params DocumentFormattingParams if err := json.Unmarshal(*r.Params, ¶ms); err != nil { @@ -490,28 +510,6 @@ func (h serverHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, deliver log.Error(ctx, "", err) } return true - case "textDocument/documentLink": // req - var params DocumentLinkParams - if err := json.Unmarshal(*r.Params, ¶ms); err != nil { - sendParseError(ctx, r, err) - return true - } - resp, err := h.server.DocumentLink(ctx, ¶ms) - if err := r.Reply(ctx, resp, err); err != nil { - log.Error(ctx, "", err) - } - return true - case "documentLink/resolve": // req - var params DocumentLink - if err := json.Unmarshal(*r.Params, ¶ms); err != nil { - sendParseError(ctx, r, err) - return true - } - resp, err := h.server.ResolveDocumentLink(ctx, ¶ms) - if err := r.Reply(ctx, resp, err); err != nil { - log.Error(ctx, "", err) - } - return true case "workspace/executeCommand": // req var params ExecuteCommandParams if err := json.Unmarshal(*r.Params, ¶ms); err != nil { @@ -756,6 +754,22 @@ func (s *serverDispatcher) ResolveCodeLens(ctx context.Context, params *CodeLens return &result, nil } +func (s *serverDispatcher) DocumentLink(ctx context.Context, params *DocumentLinkParams) ([]DocumentLink, error) { + var result []DocumentLink + if err := s.Conn.Call(ctx, "textDocument/documentLink", params, &result); err != nil { + return nil, err + } + return result, nil +} + +func (s *serverDispatcher) ResolveDocumentLink(ctx context.Context, params *DocumentLink) (*DocumentLink, error) { + var result DocumentLink + if err := s.Conn.Call(ctx, "documentLink/resolve", params, &result); err != nil { + return nil, err + } + return &result, nil +} + func (s *serverDispatcher) Formatting(ctx context.Context, params *DocumentFormattingParams) ([]TextEdit, error) { var result []TextEdit if err := s.Conn.Call(ctx, "textDocument/formatting", params, &result); err != nil { @@ -796,22 +810,6 @@ func (s *serverDispatcher) PrepareRename(ctx context.Context, params *PrepareRen return &result, nil } -func (s *serverDispatcher) DocumentLink(ctx context.Context, params *DocumentLinkParams) ([]DocumentLink, error) { - var result []DocumentLink - if err := s.Conn.Call(ctx, "textDocument/documentLink", params, &result); err != nil { - return nil, err - } - return result, nil -} - -func (s *serverDispatcher) ResolveDocumentLink(ctx context.Context, params *DocumentLink) (*DocumentLink, error) { - var result DocumentLink - if err := s.Conn.Call(ctx, "documentLink/resolve", params, &result); err != nil { - return nil, err - } - return &result, nil -} - func (s *serverDispatcher) ExecuteCommand(ctx context.Context, params *ExecuteCommandParams) (interface{}, error) { var result interface{} if err := s.Conn.Call(ctx, "workspace/executeCommand", params, &result); err != nil { diff --git a/internal/lsp/protocol/typescript/requests.ts b/internal/lsp/protocol/typescript/requests.ts index 0115568f70..373b925619 100644 --- a/internal/lsp/protocol/typescript/requests.ts +++ b/internal/lsp/protocol/typescript/requests.ts @@ -224,7 +224,8 @@ function output(side: side) { "golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/telemetry/log" - ) + "golang.org/x/tools/internal/xcontext" + ) `); const a = side.name[0].toUpperCase() + side.name.substring(1) f(`type ${a} interface {`); @@ -235,15 +236,12 @@ function output(side: side) { if delivered { return false } - switch r.Method { - case "$/cancelRequest": - var params CancelParams - if err := json.Unmarshal(*r.Params, ¶ms); err != nil { - sendParseError(ctx, r, err) - return true - } - r.Conn().Cancel(params.ID) - return true`); + if ctx.Err() != nil { + ctx := xcontext.Detach(ctx) + r.Reply(ctx, nil, jsonrpc2.NewErrorf(RequestCancelledError, "")) + return true + } + switch r.Method {`); side.cases.forEach((v) => {f(v)}); f(` default: