From dcff89057bc0e0d7cb14cf414f2df6f5fb1a41ec Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Tue, 27 Apr 2010 13:51:25 -0700 Subject: [PATCH] rpc: abstract client and server encodings R=r CC=golang-dev, rog https://golang.org/cl/811046 --- src/pkg/rpc/client.go | 70 +++++++++++++++++++++++------ src/pkg/rpc/server.go | 102 +++++++++++++++++++++++++++++++----------- 2 files changed, 131 insertions(+), 41 deletions(-) diff --git a/src/pkg/rpc/client.go b/src/pkg/rpc/client.go index 6b2ddd6f0a..d742d099fb 100644 --- a/src/pkg/rpc/client.go +++ b/src/pkg/rpc/client.go @@ -33,13 +33,25 @@ type Client struct { shutdown os.Error // non-nil if the client is shut down sending sync.Mutex seq uint64 - conn io.ReadWriteCloser - enc *gob.Encoder - dec *gob.Decoder + codec ClientCodec pending map[uint64]*Call closing bool } +// A ClientCodec implements writing of RPC requests and +// reading of RPC responses for the client side of an RPC session. +// The client calls WriteRequest to write a request to the connection +// and calls ReadResponseHeader and ReadResponseBody in pairs +// to read responses. The client calls Close when finished with the +// connection. +type ClientCodec interface { + WriteRequest(*Request, interface{}) os.Error + ReadResponseHeader(*Response) os.Error + ReadResponseBody(interface{}) os.Error + + Close() os.Error +} + func (client *Client) send(c *Call) { // Register this call. client.mutex.Lock() @@ -59,9 +71,7 @@ func (client *Client) send(c *Call) { client.sending.Lock() request.Seq = c.seq request.ServiceMethod = c.ServiceMethod - client.enc.Encode(request) - err := client.enc.Encode(c.Args) - if err != nil { + if err := client.codec.WriteRequest(request, c.Args); err != nil { panic("rpc: client encode error: " + err.String()) } client.sending.Unlock() @@ -71,7 +81,7 @@ func (client *Client) input() { var err os.Error for err == nil { response := new(Response) - err = client.dec.Decode(response) + err = client.codec.ReadResponseHeader(response) if err != nil { if err == os.EOF && !client.closing { err = io.ErrUnexpectedEOF @@ -83,7 +93,7 @@ func (client *Client) input() { c := client.pending[seq] client.pending[seq] = c, false client.mutex.Unlock() - err = client.dec.Decode(c.Reply) + err = client.codec.ReadResponseBody(c.Reply) // Empty strings should turn into nil os.Errors if response.Error != "" { c.Error = os.ErrorString(response.Error) @@ -110,17 +120,49 @@ func (client *Client) input() { // NewClient returns a new Client to handle requests to the // set of services at the other end of the connection. func NewClient(conn io.ReadWriteCloser) *Client { - client := new(Client) - client.conn = conn - client.enc = gob.NewEncoder(conn) - client.dec = gob.NewDecoder(conn) - client.pending = make(map[uint64]*Call) + return NewClientWithCodec(&gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)}) +} + +// NewClientWithCodec is like NewClient but uses the specified +// codec to encode requests and decode responses. +func NewClientWithCodec(codec ClientCodec) *Client { + client := &Client{ + codec: codec, + pending: make(map[uint64]*Call), + } go client.input() return client } +type gobClientCodec struct { + rwc io.ReadWriteCloser + dec *gob.Decoder + enc *gob.Encoder +} + +func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) os.Error { + if err := c.enc.Encode(r); err != nil { + return err + } + return c.enc.Encode(body) +} + +func (c *gobClientCodec) ReadResponseHeader(r *Response) os.Error { + return c.dec.Decode(r) +} + +func (c *gobClientCodec) ReadResponseBody(body interface{}) os.Error { + return c.dec.Decode(body) +} + +func (c *gobClientCodec) Close() os.Error { + return c.rwc.Close() +} + + // DialHTTP connects to an HTTP RPC server at the specified network address. func DialHTTP(network, address string) (*Client, os.Error) { + var err os.Error conn, err := net.Dial(network, "", address) if err != nil { return nil, err @@ -156,7 +198,7 @@ func (client *Client) Close() os.Error { client.mutex.Lock() client.closing = true client.mutex.Unlock() - return client.conn.Close() + return client.codec.Close() } // Go invokes the function asynchronously. It returns the Call structure representing diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go index 413f9a59ac..4c957597bc 100644 --- a/src/pkg/rpc/server.go +++ b/src/pkg/rpc/server.go @@ -272,7 +272,7 @@ func _new(t *reflect.PtrType) *reflect.PtrValue { return v } -func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob.Encoder, errmsg string) { +func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) { resp := new(Response) // Encode the response header resp.ServiceMethod = req.ServiceMethod @@ -281,13 +281,14 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob } resp.Seq = req.Seq sending.Lock() - enc.Encode(resp) - // Encode the reply value. - enc.Encode(reply) + err := codec.WriteResponse(resp, reply) + if err != nil { + log.Stderr("rpc: writing response: ", err) + } sending.Unlock() } -func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, enc *gob.Encoder) { +func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { mtype.Lock() mtype.numCalls++ mtype.Unlock() @@ -300,17 +301,40 @@ func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, arg if errInter != nil { errmsg = errInter.(os.Error).String() } - sendResponse(sending, req, replyv.Interface(), enc, errmsg) + sendResponse(sending, req, replyv.Interface(), codec, errmsg) } -func (server *serverType) input(conn io.ReadWriteCloser) { - dec := gob.NewDecoder(conn) - enc := gob.NewEncoder(conn) +type gobServerCodec struct { + rwc io.ReadWriteCloser + dec *gob.Decoder + enc *gob.Encoder +} + +func (c *gobServerCodec) ReadRequestHeader(r *Request) os.Error { + return c.dec.Decode(r) +} + +func (c *gobServerCodec) ReadRequestBody(body interface{}) os.Error { + return c.dec.Decode(body) +} + +func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) os.Error { + if err := c.enc.Encode(r); err != nil { + return err + } + return c.enc.Encode(body) +} + +func (c *gobServerCodec) Close() os.Error { + return c.rwc.Close() +} + +func (server *serverType) input(codec ServerCodec) { sending := new(sync.Mutex) for { // Grab the request header. req := new(Request) - err := dec.Decode(req) + err := codec.ReadRequestHeader(req) if err != nil { if err == os.EOF || err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF { @@ -319,13 +343,13 @@ func (server *serverType) input(conn io.ReadWriteCloser) { break } s := "rpc: server cannot decode request: " + err.String() - sendResponse(sending, req, invalidRequest, enc, s) - continue + sendResponse(sending, req, invalidRequest, codec, s) + break } serviceMethod := strings.Split(req.ServiceMethod, ".", 0) if len(serviceMethod) != 2 { - s := "rpc: service/method request ill:formed: " + req.ServiceMethod - sendResponse(sending, req, invalidRequest, enc, s) + s := "rpc: service/method request ill-formed: " + req.ServiceMethod + sendResponse(sending, req, invalidRequest, codec, s) continue } // Look up the request. @@ -334,27 +358,27 @@ func (server *serverType) input(conn io.ReadWriteCloser) { server.Unlock() if !ok { s := "rpc: can't find service " + req.ServiceMethod - sendResponse(sending, req, invalidRequest, enc, s) + sendResponse(sending, req, invalidRequest, codec, s) continue } mtype, ok := service.method[serviceMethod[1]] if !ok { s := "rpc: can't find method " + req.ServiceMethod - sendResponse(sending, req, invalidRequest, enc, s) + sendResponse(sending, req, invalidRequest, codec, s) continue } // Decode the argument value. argv := _new(mtype.argType) replyv := _new(mtype.replyType) - err = dec.Decode(argv.Interface()) + err = codec.ReadRequestBody(argv.Interface()) if err != nil { log.Stderr("rpc: tearing down", serviceMethod[0], "connection:", err) - sendResponse(sending, req, replyv.Interface(), enc, err.String()) - continue + sendResponse(sending, req, replyv.Interface(), codec, err.String()) + break } - go service.call(sending, mtype, req, argv, replyv, enc) + go service.call(sending, mtype, req, argv, replyv, codec) } - conn.Close() + codec.Close() } func (server *serverType) accept(lis net.Listener) { @@ -363,7 +387,7 @@ func (server *serverType) accept(lis net.Listener) { if err != nil { log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit? } - go server.input(conn) + go ServeConn(conn) } } @@ -376,10 +400,34 @@ func (server *serverType) accept(lis net.Listener) { // suitable methods. func Register(rcvr interface{}) os.Error { return server.register(rcvr) } -// ServeConn runs the server on a single connection. When the connection -// completes, service terminates. ServeConn blocks; the caller typically -// invokes it in a go statement. -func ServeConn(conn io.ReadWriteCloser) { server.input(conn) } +// A ServerCodec implements reading of RPC requests and writing of +// RPC responses for the server side of an RPC session. +// The server calls ReadRequestHeader and ReadRequestBody in pairs +// to read requests from the connection, and it calls WriteResponse to +// write a response back. The server calls Close when finished with the +// connection. +type ServerCodec interface { + ReadRequestHeader(*Request) os.Error + ReadRequestBody(interface{}) os.Error + WriteResponse(*Response, interface{}) os.Error + + Close() os.Error +} + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +// ServeConn uses the gob wire format (see package gob) on the +// connection. To use an alternate codec, use ServeCodec. +func ServeConn(conn io.ReadWriteCloser) { + ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)}) +} + +// ServeCodec is like ServeConn but uses the specified codec to +// decode requests and encode responses. +func ServeCodec(codec ServerCodec) { + server.input(codec) +} // Accept accepts connections on the listener and serves requests // for each incoming connection. Accept blocks; the caller typically @@ -404,7 +452,7 @@ func serveHTTP(c *http.Conn, req *http.Request) { return } io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") - server.input(conn) + ServeConn(conn) } // HandleHTTP registers an HTTP handler for RPC messages.