diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go index 086457963a..acadeec37f 100644 --- a/src/pkg/rpc/server.go +++ b/src/pkg/rpc/server.go @@ -13,8 +13,11 @@ Only methods that satisfy these criteria will be made available for remote access; other methods will be ignored: - - the method receiver and name are exported, that is, begin with an upper case letter. - - the method has two arguments, both pointers to exported types. + - the method name is exported, that is, begins with an upper case letter. + - the method receiver is exported or local (defined in the package + registering the service). + - the method has two arguments, both exported or local types. + - the method's second argument is a pointer. - the method has return type os.Error. The method's first argument represents the arguments provided by the caller; the @@ -193,6 +196,14 @@ func isExported(name string) bool { return unicode.IsUpper(rune) } +// Is this type exported or local to this package? +func isExportedOrLocalType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t.PkgPath() == "" || isExported(t.Name()) +} + // Register publishes in the server the set of methods of the // receiver value that satisfy the following conditions: // - exported method @@ -252,23 +263,20 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) continue } + // First arg need not be a pointer. argType := mtype.In(1) - ok := argType.Kind() == reflect.Ptr - if !ok { - log.Println(mname, "arg type not a pointer:", mtype.In(1)) + if !isExportedOrLocalType(argType) { + log.Println(mname, "argument type not exported or local:", argType) continue } + // Second arg must be a pointer. replyType := mtype.In(2) if replyType.Kind() != reflect.Ptr { - log.Println(mname, "reply type not a pointer:", mtype.In(2)) + log.Println("method", mname, "reply type not a pointer:", replyType) continue } - if argType.Elem().PkgPath() != "" && !isExported(argType.Elem().Name()) { - log.Println(mname, "argument type not exported:", argType) - continue - } - if replyType.Elem().PkgPath() != "" && !isExported(replyType.Elem().Name()) { - log.Println(mname, "reply type not exported:", replyType) + if !isExportedOrLocalType(replyType) { + log.Println("method", mname, "reply type not exported or local:", replyType) continue } // Method needs one out: os.Error. @@ -405,7 +413,15 @@ func (server *Server) ServeCodec(codec ServerCodec) { } // Decode the argument value. - argv := reflect.New(mtype.ArgType.Elem()) + var argv reflect.Value + argIsValue := false // if true, need to indirect before calling. + if mtype.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(mtype.ArgType.Elem()) + } else { + argv = reflect.New(mtype.ArgType) + argIsValue = true + } + // argv guaranteed to be a pointer now. replyv := reflect.New(mtype.ReplyType.Elem()) err = codec.ReadRequestBody(argv.Interface()) if err != nil { @@ -418,6 +434,9 @@ func (server *Server) ServeCodec(codec ServerCodec) { server.sendResponse(sending, req, replyv.Interface(), codec, err.String()) continue } + if argIsValue { + argv = argv.Elem() + } go service.call(server, sending, mtype, req, argv, replyv, codec) } codec.Close() diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go index d4041ae70c..eb7b673d66 100644 --- a/src/pkg/rpc/server_test.go +++ b/src/pkg/rpc/server_test.go @@ -38,7 +38,9 @@ type Reply struct { type Arith int -func (t *Arith) Add(args *Args, reply *Reply) os.Error { +// Some of Arith's methods have value args, some have pointer args. That's deliberate. + +func (t *Arith) Add(args Args, reply *Reply) os.Error { reply.C = args.A + args.B return nil } @@ -48,7 +50,7 @@ func (t *Arith) Mul(args *Args, reply *Reply) os.Error { return nil } -func (t *Arith) Div(args *Args, reply *Reply) os.Error { +func (t *Arith) Div(args Args, reply *Reply) os.Error { if args.B == 0 { return os.ErrorString("divide by zero") } @@ -61,8 +63,8 @@ func (t *Arith) String(args *Args, reply *string) os.Error { return nil } -func (t *Arith) Scan(args *string, reply *Reply) (err os.Error) { - _, err = fmt.Sscan(*args, &reply.C) +func (t *Arith) Scan(args string, reply *Reply) (err os.Error) { + _, err = fmt.Sscan(args, &reply.C) return } @@ -262,16 +264,11 @@ func testHTTPRPC(t *testing.T, path string) { } } -type ArgNotPointer int type ReplyNotPointer int type ArgNotPublic int type ReplyNotPublic int type local struct{} -func (t *ArgNotPointer) ArgNotPointer(args Args, reply *Reply) os.Error { - return nil -} - func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) os.Error { return nil } @@ -286,11 +283,7 @@ func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) os.Error { // Check that registration handles lots of bad methods and a type with no suitable methods. func TestRegistrationError(t *testing.T) { - err := Register(new(ArgNotPointer)) - if err == nil { - t.Errorf("expected error registering ArgNotPointer") - } - err = Register(new(ReplyNotPointer)) + err := Register(new(ReplyNotPointer)) if err == nil { t.Errorf("expected error registering ReplyNotPointer") }