From b07af158a49be49951cd0d5219a31a491ad7b8ca Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Tue, 14 Jul 2009 13:23:14 -0700 Subject: [PATCH] improve rpc code. more robust. more tests. R=rsc DELTA=186 (133 added, 20 deleted, 33 changed) OCL=31611 CL=31616 --- src/pkg/gob/decoder.go | 2 +- src/pkg/rpc/client.go | 40 ++++++++++++----- src/pkg/rpc/server.go | 78 ++++++++++++++++++++++++--------- src/pkg/rpc/server_test.go | 89 +++++++++++++++++++++++++++++++------- 4 files changed, 161 insertions(+), 48 deletions(-) diff --git a/src/pkg/gob/decoder.go b/src/pkg/gob/decoder.go index 8676533e62..ef5481e109 100644 --- a/src/pkg/gob/decoder.go +++ b/src/pkg/gob/decoder.go @@ -73,7 +73,7 @@ func (dec *Decoder) Decode(e interface{}) os.Error { // TODO(r): need to make the decoder work correctly if the wire type is compatible // but not equal to the local type (e.g, extra fields). if info.wire.name() != dec.seen[id].name() { - dec.state.err = os.ErrorString("gob decode: incorrect type for wire value"); + dec.state.err = os.ErrorString("gob decode: incorrect type for wire value: want " + info.wire.name() + "; received " + dec.seen[id].name()); return dec.state.err } diff --git a/src/pkg/rpc/client.go b/src/pkg/rpc/client.go index 725add1a54..196118834b 100644 --- a/src/pkg/rpc/client.go +++ b/src/pkg/rpc/client.go @@ -7,6 +7,7 @@ package rpc import ( "gob"; "io"; + "log"; "os"; "rpc"; "sync"; @@ -25,6 +26,7 @@ type Call struct { // Client represents an RPC Client. type Client struct { sync.Mutex; // protects pending, seq + closed bool; sending sync.Mutex; seq uint64; conn io.ReadWriteCloser; @@ -49,29 +51,35 @@ func (client *Client) send(c *Call) { client.enc.Encode(request); err := client.enc.Encode(c.Args); if err != nil { - panicln("client encode error:", err) + panicln("rpc: client encode error:", err); } client.sending.Unlock(); } func (client *Client) serve() { - for { + var err os.Error; + for err == nil { response := new(Response); - err := client.dec.Decode(response); + err = client.dec.Decode(response); + if err != nil { + if err == os.EOF { + break; + } + break; + } seq := response.Seq; client.Lock(); c := client.pending[seq]; client.pending[seq] = c, false; client.Unlock(); - client.dec.Decode(c.Reply); - if err != nil { - panicln("client decode error:", err) - } + err = client.dec.Decode(c.Reply); c.Error = os.ErrorString(response.Error); - // We don't want to block here, it is the caller's responsibility to make - // sure the channel has enough buffer space. See comment in Start(). + // We don't want to block here. It is the caller's responsibility to make + // sure the channel has enough buffer space. See comment in Go(). doNotBlock := c.Done <- c; } + client.closed = true; + log.Stderr("client protocol error:", err); } // NewClient returns a new Client to handle requests to the @@ -86,9 +94,9 @@ func NewClient(conn io.ReadWriteCloser) *Client { return client; } -// Start invokes the function asynchronously. It returns the Call structure representing +// Go invokes the function asynchronously. It returns the Call structure representing // the invocation. -func (client *Client) Start(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call { +func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call { c := new(Call); c.ServiceMethod = serviceMethod; c.Args = args; @@ -102,12 +110,20 @@ func (client *Client) Start(serviceMethod string, args interface{}, reply interf // RPCs that will be using that channel. } c.Done = done; + if client.closed { + c.Error = os.EOF; + doNotBlock := c.Done <- c; + return c; + } client.send(c); return c; } // Call invokes the named function, waits for it to complete, and returns its error status. func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error { - call := <-client.Start(serviceMethod, args, reply, nil).Done; + if client.closed { + return os.EOF + } + call := <-client.Go(serviceMethod, args, reply, nil).Done; return call.Error; } diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go index 3b7a5df707..304f1e2df2 100644 --- a/src/pkg/rpc/server.go +++ b/src/pkg/rpc/server.go @@ -8,6 +8,7 @@ import ( "gob"; "log"; "io"; + "net"; "os"; "reflect"; "strings"; @@ -16,6 +17,8 @@ import ( "utf8"; ) +import "fmt" // TODO DELETE + // Precompute the reflect type for os.Error. Can't use os.Error directly // because Typeof takes an empty interface value. This is annoying. var unusedError *os.Error; @@ -137,29 +140,41 @@ func (server *Server) Add(rcvr interface{}) os.Error { return nil; } +// A value to be sent as a placeholder for the response when we receive invalid request. +type InvalidRequest struct { + marker int +} +var invalidRequest = InvalidRequest{1} + func _new(t *reflect.PtrType) *reflect.PtrValue { v := reflect.MakeZero(t).(*reflect.PtrValue); v.PointTo(reflect.MakeZero(t.Elem())); return v; } +func (s *service) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob.Encoder, errmsg string) { + resp := new(Response); + // Encode the response header + sending.Lock(); + resp.ServiceMethod = req.ServiceMethod; + resp.Error = errmsg; + resp.Seq = req.Seq; + enc.Encode(resp); + // Encode the reply value. + enc.Encode(reply); + sending.Unlock(); +} + func (s *service) call(sending *sync.Mutex, function *reflect.FuncValue, req *Request, argv, replyv reflect.Value, enc *gob.Encoder) { // Invoke the method, providing a new value for the reply. returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv}); // The return value for the method is an os.Error. - err := returnValues[0].Interface(); - resp := new(Response); - if err != nil { - resp.Error = err.(os.Error).String(); + errInter := returnValues[0].Interface(); + errmsg := ""; + if errInter != nil { + errmsg = errInter.(os.Error).String(); } - // Encode the response header - sending.Lock(); - resp.ServiceMethod = req.ServiceMethod; - resp.Seq = req.Seq; - enc.Encode(resp); - // Encode the reply value. - enc.Encode(replyv.Interface()); - sending.Unlock(); + s.sendResponse(sending, req, replyv.Interface(), enc, errmsg); } func (server *Server) serve(conn io.ReadWriteCloser) { @@ -171,33 +186,56 @@ func (server *Server) serve(conn io.ReadWriteCloser) { req := new(Request); err := dec.Decode(req); if err != nil { - panicln("can't handle decode error yet", err.String()); + log.Stderr("rpc: server cannot decode request:", err); + break; } serviceMethod := strings.Split(req.ServiceMethod, ".", 0); if len(serviceMethod) != 2 { - panicln("service/Method request ill-formed:", req.ServiceMethod); + log.Stderr("rpc: service/Method request ill-formed:", req.ServiceMethod); + break; } // Look up the request. service, ok := server.serviceMap[serviceMethod[0]]; if !ok { - panicln("can't find service", serviceMethod[0]); + s := "rpc: can't find service " + req.ServiceMethod; + service.sendResponse(sending, req, invalidRequest, enc, s); + break; } mtype, ok := service.method[serviceMethod[1]]; if !ok { - panicln("can't find method", serviceMethod[1]); + s := "rpc: can't find method " + req.ServiceMethod; + service.sendResponse(sending, req, invalidRequest, enc, s); + break; } method := mtype.method; // Decode the argument value. argv := _new(mtype.argType); + replyv := _new(mtype.replyType); err = dec.Decode(argv.Interface()); if err != nil { - panicln("can't handle payload decode error yet", err.String()); + log.Stderr("tearing down connection:", err); + service.sendResponse(sending, req, replyv.Interface(), enc, err.String()); + break; } - go service.call(sending, method.Func, req, argv, _new(mtype.replyType), enc); + go service.call(sending, method.Func, req, argv, replyv, enc); } + conn.Close(); } -// Serve runs the server on the connection. -func (server *Server) Serve(conn io.ReadWriteCloser) { +// ServeConn runs the server on a single connection. When the connection +// completes, service terminates. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { go server.serve(conn) } + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. +func (server *Server) Accept(lis net.Listener) { + for { + conn, addr, err := lis.Accept(); + if err != nil { + log.Exit("rpc.Serve: accept:", err.String()); // TODO(r): exit? + } + go server.ServeConn(conn); + } +} diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go index 0a1ec64be4..51d024d76b 100644 --- a/src/pkg/rpc/server_test.go +++ b/src/pkg/rpc/server_test.go @@ -11,8 +11,10 @@ import ( "io"; "log"; "net"; + "once"; "os"; "rpc"; + "strings"; "testing"; ) @@ -53,15 +55,6 @@ func (t *Arith) Error(args *Args, reply *Reply) os.Error { panicln("ERROR"); } -func run(server *Server, l net.Listener) { - conn, addr, err := l.Accept(); - if err != nil { - println("accept:", err.String()); - os.Exit(1); - } - server.Serve(conn); -} - func startServer() { server := new(Server); server.Add(new(Arith)); @@ -72,14 +65,13 @@ func startServer() { } serverAddr = l.Addr(); log.Stderr("Test RPC server listening on ", serverAddr); -// go http.Serve(l, nil); - go run(server, l); + go server.Accept(l); } func TestRPC(t *testing.T) { var i int; - startServer(); + once.Do(startServer); conn, err := net.Dial("tcp", "", serverAddr); if err != nil { @@ -106,9 +98,9 @@ func TestRPC(t *testing.T) { // Out of order. args = &Args{7,8}; mulReply := new(Reply); - mulCall := client.Start("Arith.Mul", args, mulReply, nil); + mulCall := client.Go("Arith.Mul", args, mulReply, nil); addReply := new(Reply); - addCall := client.Start("Arith.Add", args, addReply, nil); + addCall := client.Go("Arith.Add", args, addReply, nil); <-addCall.Done; if addReply.C != args.A + args.B { @@ -126,6 +118,73 @@ func TestRPC(t *testing.T) { err = client.Call("Arith.Div", args, reply); // expect an error: zero divide if err == nil { - t.Errorf("Div: expected error"); + t.Error("Div: expected error"); + } else if err.String() != "divide by zero" { + t.Error("Div: expected divide by zero error; got", err); + } +} + +func TestCheckUnknownService(t *testing.T) { + var i int; + + once.Do(startServer); + + conn, err := net.Dial("tcp", "", serverAddr); + if err != nil { + t.Fatal("dialing:", err) + } + + client := NewClient(conn); + + args := &Args{7,8}; + reply := new(Reply); + err = client.Call("Unknown.Add", args, reply); + if err == nil { + t.Error("expected error calling unknown service"); + } else if strings.Index(err.String(), "service") < 0 { + t.Error("expected error about service; got", err); + } +} + +func TestCheckUnknownMethod(t *testing.T) { + var i int; + + once.Do(startServer); + + conn, err := net.Dial("tcp", "", serverAddr); + if err != nil { + t.Fatal("dialing:", err) + } + + client := NewClient(conn); + + args := &Args{7,8}; + reply := new(Reply); + err = client.Call("Arith.Unknown", args, reply); + if err == nil { + t.Error("expected error calling unknown service"); + } else if strings.Index(err.String(), "method") < 0 { + t.Error("expected error about method; got", err); + } +} + +func TestCheckBadType(t *testing.T) { + var i int; + + once.Do(startServer); + + conn, err := net.Dial("tcp", "", serverAddr); + if err != nil { + t.Fatal("dialing:", err) + } + + client := NewClient(conn); + + reply := new(Reply); + err = client.Call("Arith.Add", reply, reply); // args, reply would be the correct thing to use + if err == nil { + t.Error("expected error calling Arith.Add with wrong arg type"); + } else if strings.Index(err.String(), "type") < 0 { + t.Error("expected error about type; got", err); } }