1
0
mirror of https://github.com/golang/go synced 2024-11-25 02:57:57 -07:00

rpc: make more tolerant of errors.

Add Error type to enable clients to distinguish
between local and remote errors.
Also return "connection shut down error" after
the first error return rather than returning the
same error each time.

R=r
CC=golang-dev
https://golang.org/cl/4080058
This commit is contained in:
Roger Peppe 2011-02-09 10:57:59 -08:00 committed by Rob Pike
parent d916cca327
commit d40ae94993
3 changed files with 131 additions and 120 deletions

View File

@ -15,6 +15,16 @@ import (
"sync"
)
// ServerError represents an error that has been returned from
// the remote side of the RPC connection.
type ServerError string
func (e ServerError) String() string {
return string(e)
}
const ErrShutdown = os.ErrorString("connection is shut down")
// Call represents an active RPC.
type Call struct {
ServiceMethod string // The name of the service and method to call.
@ -30,12 +40,12 @@ type Call struct {
// with a single Client.
type Client struct {
mutex sync.Mutex // protects pending, seq
shutdown os.Error // non-nil if the client is shut down
sending sync.Mutex
seq uint64
codec ClientCodec
pending map[uint64]*Call
closing bool
shutdown bool
}
// A ClientCodec implements writing of RPC requests and
@ -55,8 +65,8 @@ type ClientCodec interface {
func (client *Client) send(c *Call) {
// Register this call.
client.mutex.Lock()
if client.shutdown != nil {
c.Error = client.shutdown
if client.shutdown {
c.Error = ErrShutdown
client.mutex.Unlock()
c.done()
return
@ -79,6 +89,7 @@ func (client *Client) send(c *Call) {
func (client *Client) input() {
var err os.Error
var marker struct{}
for err == nil {
response := new(Response)
err = client.codec.ReadResponseHeader(response)
@ -93,20 +104,27 @@ func (client *Client) input() {
c := client.pending[seq]
client.pending[seq] = c, false
client.mutex.Unlock()
if response.Error == "" {
err = client.codec.ReadResponseBody(c.Reply)
if response.Error != "" {
c.Error = os.ErrorString(response.Error)
} else if err != nil {
c.Error = err
if err != nil {
c.Error = os.ErrorString("reading body " + err.String())
}
} else {
// Empty strings should turn into nil os.Errors
c.Error = nil
// We've got an error response. Give this to the request;
// any subsequent requests will get the ReadResponseBody
// error if there is one.
c.Error = ServerError(response.Error)
err = client.codec.ReadResponseBody(&marker)
if err != nil {
err = os.ErrorString("reading error body: " + err.String())
}
}
c.done()
}
// Terminate pending calls.
client.mutex.Lock()
client.shutdown = err
client.shutdown = true
for _, call := range client.pending {
call.Error = err
call.done()
@ -209,10 +227,11 @@ func Dial(network, address string) (*Client, os.Error) {
}
func (client *Client) Close() os.Error {
if client.shutdown != nil || client.closing {
return os.ErrorString("rpc: already closed")
}
client.mutex.Lock()
if client.shutdown || client.closing {
client.mutex.Unlock()
return ErrShutdown
}
client.closing = true
client.mutex.Unlock()
return client.codec.Close()
@ -239,8 +258,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
}
}
c.Done = done
if client.shutdown != nil {
c.Error = client.shutdown
if client.shutdown {
c.Error = ErrShutdown
c.done()
return c
}
@ -250,8 +269,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
// Call invokes the named function, waits for it to complete, and returns its error status.
func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error {
if client.shutdown != nil {
return client.shutdown
if client.shutdown {
return ErrShutdown
}
call := <-client.Go(serviceMethod, args, reply, nil).Done
return call.Error

View File

@ -299,7 +299,7 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E
// A value sent as a placeholder for the response when the server receives an invalid request.
type InvalidRequest struct {
marker int
Marker int
}
var invalidRequest = InvalidRequest{1}
@ -316,6 +316,7 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec Se
resp.ServiceMethod = req.ServiceMethod
if errmsg != "" {
resp.Error = errmsg
reply = invalidRequest
}
resp.Seq = req.Seq
sending.Lock()
@ -389,9 +390,28 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
func (server *Server) ServeCodec(codec ServerCodec) {
sending := new(sync.Mutex)
for {
// Grab the request header.
req := new(Request)
err := codec.ReadRequestHeader(req)
req, service, mtype, err := server.readRequest(codec)
if err != nil {
if err != os.EOF {
log.Println("rpc:", err)
}
if err == os.EOF || err == io.ErrUnexpectedEOF {
break
}
// discard body
codec.ReadRequestBody(new(interface{}))
// send a response if we actually managed to read a header.
if req != nil {
sendResponse(sending, req, invalidRequest, codec, err.String())
}
continue
}
// Decode the argument value.
argv := _new(mtype.ArgType)
replyv := _new(mtype.ReplyType)
err = codec.ReadRequestBody(argv.Interface())
if err != nil {
if err == os.EOF || err == io.ErrUnexpectedEOF {
if err == io.ErrUnexpectedEOF {
@ -399,44 +419,45 @@ func (server *Server) ServeCodec(codec ServerCodec) {
}
break
}
s := "rpc: server cannot decode request: " + err.String()
sendResponse(sending, req, invalidRequest, codec, s)
break
}
serviceMethod := strings.Split(req.ServiceMethod, ".", -1)
if len(serviceMethod) != 2 {
s := "rpc: service/method request ill-formed: " + req.ServiceMethod
sendResponse(sending, req, invalidRequest, codec, s)
continue
}
// Look up the request.
server.Lock()
service, ok := server.serviceMap[serviceMethod[0]]
server.Unlock()
if !ok {
s := "rpc: can't find service " + req.ServiceMethod
sendResponse(sending, req, invalidRequest, codec, s)
continue
}
mtype, ok := service.method[serviceMethod[1]]
if !ok {
s := "rpc: can't find method " + req.ServiceMethod
sendResponse(sending, req, invalidRequest, codec, s)
continue
}
// Decode the argument value.
argv := _new(mtype.ArgType)
replyv := _new(mtype.ReplyType)
err = codec.ReadRequestBody(argv.Interface())
if err != nil {
log.Println("rpc: tearing down", serviceMethod[0], "connection:", err)
sendResponse(sending, req, replyv.Interface(), codec, err.String())
break
continue
}
go service.call(sending, mtype, req, argv, replyv, codec)
}
codec.Close()
}
func (server *Server) readRequest(codec ServerCodec) (req *Request, service *service, mtype *methodType, err os.Error) {
// Grab the request header.
req = new(Request)
err = codec.ReadRequestHeader(req)
if err != nil {
req = nil
if err == os.EOF || err == io.ErrUnexpectedEOF {
return
}
err = os.ErrorString("rpc: server cannot decode request: " + err.String())
return
}
serviceMethod := strings.Split(req.ServiceMethod, ".", -1)
if len(serviceMethod) != 2 {
err = os.ErrorString("rpc: service/method request ill-formed: " + req.ServiceMethod)
return
}
// Look up the request.
server.Lock()
service = server.serviceMap[serviceMethod[0]]
server.Unlock()
if service == nil {
err = os.ErrorString("rpc: can't find service " + req.ServiceMethod)
return
}
mtype = service.method[serviceMethod[1]]
if mtype == nil {
err = os.ErrorString("rpc: can't find method " + req.ServiceMethod)
}
return
}
// Accept accepts connections on the listener and serves requests
// for each incoming connection. Accept blocks; the caller typically

View File

@ -134,14 +134,25 @@ func testRPC(t *testing.T, addr string) {
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
}
// Nonexistent method
args = &Args{7, 0}
reply = new(Reply)
err = client.Call("Arith.BadOperation", args, reply)
// expect an error
if err == nil {
t.Error("BadOperation: expected error")
} else if !strings.HasPrefix(err.String(), "rpc: can't find method ") {
t.Errorf("BadOperation: expected can't find method error; got %q", err)
}
// Unknown service
args = &Args{7, 8}
reply = new(Reply)
err = client.Call("Arith.Mul", args, reply)
if err != nil {
t.Errorf("Mul: expected no error but got string %q", err.String())
}
if reply.C != args.A*args.B {
t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
err = client.Call("Arith.Unknown", args, reply)
if err == nil {
t.Error("expected error calling unknown service")
} else if strings.Index(err.String(), "method") < 0 {
t.Error("expected error about method; got", err)
}
// Out of order.
@ -178,6 +189,15 @@ func testRPC(t *testing.T, addr string) {
t.Error("Div: expected divide by zero error; got", err)
}
// Bad type.
reply = new(Reply)
err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
if err == nil {
t.Error("expected error calling Arith.Add with wrong arg type")
} else if strings.Index(err.String(), "type") < 0 {
t.Error("expected error about type; got", err)
}
// Non-struct argument
const Val = 12345
str := fmt.Sprint(Val)
@ -200,9 +220,19 @@ func testRPC(t *testing.T, addr string) {
if str != expect {
t.Errorf("String: expected %s got %s", expect, str)
}
args = &Args{7, 8}
reply = new(Reply)
err = client.Call("Arith.Mul", args, reply)
if err != nil {
t.Errorf("Mul: expected no error but got string %q", err.String())
}
if reply.C != args.A*args.B {
t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
}
}
func TestHTTPRPC(t *testing.T) {
func TestHTTP(t *testing.T) {
once.Do(startServer)
testHTTPRPC(t, "")
newOnce.Do(startNewServer)
@ -233,65 +263,6 @@ func testHTTPRPC(t *testing.T, path string) {
}
}
func TestCheckUnknownService(t *testing.T) {
once.Do(startServer)
conn, err := net.Dial("tcp", "", serverAddr)
if err != nil {
t.Fatal("dialing:", err)
}
client := NewClient(conn)
args := &Args{7, 8}
reply := new(Reply)
err = client.Call("Unknown.Add", args, reply)
if err == nil {
t.Error("expected error calling unknown service")
} else if strings.Index(err.String(), "service") < 0 {
t.Error("expected error about service; got", err)
}
}
func TestCheckUnknownMethod(t *testing.T) {
once.Do(startServer)
conn, err := net.Dial("tcp", "", serverAddr)
if err != nil {
t.Fatal("dialing:", err)
}
client := NewClient(conn)
args := &Args{7, 8}
reply := new(Reply)
err = client.Call("Arith.Unknown", args, reply)
if err == nil {
t.Error("expected error calling unknown service")
} else if strings.Index(err.String(), "method") < 0 {
t.Error("expected error about method; got", err)
}
}
func TestCheckBadType(t *testing.T) {
once.Do(startServer)
conn, err := net.Dial("tcp", "", serverAddr)
if err != nil {
t.Fatal("dialing:", err)
}
client := NewClient(conn)
reply := new(Reply)
err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
if err == nil {
t.Error("expected error calling Arith.Add with wrong arg type")
} else if strings.Index(err.String(), "type") < 0 {
t.Error("expected error about type; got", err)
}
}
type ArgNotPointer int
type ReplyNotPointer int
type ArgNotPublic int