diff --git a/src/pkg/net/rpc/client.go b/src/pkg/net/rpc/client.go index abc1e59cd51..69c44076950 100644 --- a/src/pkg/net/rpc/client.go +++ b/src/pkg/net/rpc/client.go @@ -31,8 +31,7 @@ type Call struct { Args interface{} // The argument to the function (*struct). Reply interface{} // The reply from the function (*struct). Error error // After completion, the error status. - Done chan *Call // Strobes when call is complete; value is the error status. - seq uint64 + Done chan *Call // Strobes when call is complete. } // Client represents an RPC Client. @@ -65,28 +64,33 @@ type ClientCodec interface { Close() error } -func (client *Client) send(c *Call) { +func (client *Client) send(call *Call) { + client.sending.Lock() + defer client.sending.Unlock() + // Register this call. client.mutex.Lock() if client.shutdown { - c.Error = ErrShutdown + call.Error = ErrShutdown client.mutex.Unlock() - c.done() + call.done() return } - c.seq = client.seq + seq := client.seq client.seq++ - client.pending[c.seq] = c + client.pending[seq] = call client.mutex.Unlock() // Encode and send the request. - client.sending.Lock() - defer client.sending.Unlock() - client.request.Seq = c.seq - client.request.ServiceMethod = c.ServiceMethod - if err := client.codec.WriteRequest(&client.request, c.Args); err != nil { - c.Error = err - c.done() + client.request.Seq = seq + client.request.ServiceMethod = call.ServiceMethod + err := client.codec.WriteRequest(&client.request, call.Args) + if err != nil { + client.mutex.Lock() + delete(client.pending, seq) + client.mutex.Unlock() + call.Error = err + call.done() } } @@ -104,36 +108,39 @@ func (client *Client) input() { } seq := response.Seq client.mutex.Lock() - c := client.pending[seq] + call := client.pending[seq] delete(client.pending, seq) client.mutex.Unlock() if response.Error == "" { - err = client.codec.ReadResponseBody(c.Reply) + err = client.codec.ReadResponseBody(call.Reply) if err != nil { - c.Error = errors.New("reading body " + err.Error()) + call.Error = errors.New("reading body " + err.Error()) } } else { // We've got an error response. Give this to the request; // any subsequent requests will get the ReadResponseBody // error if there is one. - c.Error = ServerError(response.Error) + call.Error = ServerError(response.Error) err = client.codec.ReadResponseBody(nil) if err != nil { err = errors.New("reading error body: " + err.Error()) } } - c.done() + call.done() } // Terminate pending calls. + client.sending.Lock() client.mutex.Lock() client.shutdown = true + closing := client.closing for _, call := range client.pending { call.Error = err call.done() } client.mutex.Unlock() - if err != io.EOF || !client.closing { + client.sending.Unlock() + if err != io.EOF || !closing { log.Println("rpc: client protocol error:", err) } } @@ -269,20 +276,12 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface } } call.Done = done - if client.shutdown { - call.Error = ErrShutdown - call.done() - return call - } client.send(call) return call } // Call invokes the named function, waits for it to complete, and returns its error status. func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error { - if client.shutdown { - return ErrShutdown - } call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done return call.Error }