From 08483defacf38d2d1f89eae0136268dfee30ec84 Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Mon, 28 Jun 2010 16:05:54 -0700 Subject: [PATCH] rpc: allow non-struct args and reply (they must still be pointers) R=rsc CC=golang-dev https://golang.org/cl/1722046 --- src/pkg/rpc/server.go | 45 ++++++++++-------------- src/pkg/rpc/server_test.go | 71 ++++++++++++++++++++++++++++++-------- 2 files changed, 75 insertions(+), 41 deletions(-) diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go index f7fce45a14b..b8a0e5ccc0e 100644 --- a/src/pkg/rpc/server.go +++ b/src/pkg/rpc/server.go @@ -13,8 +13,8 @@ Only methods that satisfy these criteria will be made available for remote access; other methods will be ignored: - - the method name is publicly visible, that is, begins with an upper case letter. - - the method has two arguments, both pointers to publicly visible structs. + - the method receiver and name are publicly visible, that is, begin with an upper case letter. + - the method has two arguments, both pointers to publicly visible types. - the method has return type os.Error. The method's first argument represents the arguments provided by the caller; the @@ -30,7 +30,7 @@ NewClient on the connection. The convenience function Dial (DialHTTP) performs both steps for a raw network connection (an HTTP connection). The resulting Client object has two methods, Call and Go, that specify the service and method to - call, a structure containing the arguments, and a structure to receive the result + call, a pointer containing the arguments, and a pointer to receive the result parameters. Call waits for the remote call to complete; Go launches the call asynchronously @@ -46,22 +46,23 @@ A, B int } - type Reply struct { - C int + type Quotient struct { + Quo, Rem int } type Arith int - func (t *Arith) Multiply(args *Args, reply *Reply) os.Error { - reply.C = args.A * args.B + func (t *Arith) Multiply(args *Args, reply *int) os.Error { + *reply = args.A * args.B return nil } - func (t *Arith) Divide(args *Args, reply *Reply) os.Error { + func (t *Arith) Divide(args *Args, quo *Quotient) os.Error { if args.B == 0 { return os.ErrorString("divide by zero") } - reply.C = args.A / args.B + quo.Quo = args.A / args.B + quo.Rem = args.A % args.B return nil } @@ -88,17 +89,18 @@ // Synchronous call args := &server.Args{7,8} - reply := new(server.Reply) - err = client.Call("Arith.Multiply", args, reply) + var reply int + err = client.Call("Arith.Multiply", args, &reply) if err != nil { log.Exit("arith error:", err) } - fmt.Printf("Arith: %d*%d=%d", args.A, args.B, reply.C) + fmt.Printf("Arith: %d*%d=%d", args.A, args.B, *reply) or // Asynchronous call - divCall := client.Go("Arith.Divide", args, reply, nil) + quotient := new(Quotient) + divCall := client.Go("Arith.Divide", args, "ient, nil) replyCall := <-divCall.Done // will be equal to divCall // check errors, print, etc. @@ -193,7 +195,7 @@ func (server *serverType) register(rcvr interface{}) os.Error { if sname == "" { log.Exit("rpc: no service name for type", s.typ.String()) } - if !isPublic(sname) { + if s.typ.PkgPath() != "" && !isPublic(sname) { s := "rpc Register: type " + sname + " is not public" log.Stderr(s) return os.ErrorString(s) @@ -209,11 +211,10 @@ func (server *serverType) register(rcvr interface{}) os.Error { method := s.typ.Method(m) mtype := method.Type mname := method.Name - if !isPublic(mname) { + if mtype.PkgPath() != "" && !isPublic(mname) { continue } // Method needs three ins: receiver, *args, *reply. - // The args and reply must be structs until gobs are more general. if mtype.NumIn() != 3 { log.Stderr("method", mname, "has wrong number of ins:", mtype.NumIn()) continue @@ -223,24 +224,16 @@ func (server *serverType) register(rcvr interface{}) os.Error { log.Stderr(mname, "arg type not a pointer:", mtype.In(1)) continue } - if _, ok := argType.Elem().(*reflect.StructType); !ok { - log.Stderr(mname, "arg type not a pointer to a struct:", argType) - continue - } replyType, ok := mtype.In(2).(*reflect.PtrType) if !ok { log.Stderr(mname, "reply type not a pointer:", mtype.In(2)) continue } - if _, ok := replyType.Elem().(*reflect.StructType); !ok { - log.Stderr(mname, "reply type not a pointer to a struct:", replyType) - continue - } - if !isPublic(argType.Elem().Name()) { + if argType.Elem().PkgPath() != "" && !isPublic(argType.Elem().Name()) { log.Stderr(mname, "argument type not public:", argType) continue } - if !isPublic(replyType.Elem().Name()) { + if replyType.Elem().PkgPath() != "" && !isPublic(replyType.Elem().Name()) { log.Stderr(mname, "reply type not public:", replyType) continue } diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go index edf35e6c9f2..e502db4e31a 100644 --- a/src/pkg/rpc/server_test.go +++ b/src/pkg/rpc/server_test.go @@ -5,6 +5,7 @@ package rpc import ( + "fmt" "http" "log" "net" @@ -48,6 +49,16 @@ func (t *Arith) Div(args *Args, reply *Reply) os.Error { return nil } +func (t *Arith) String(args *Args, reply *string) os.Error { + *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) + return nil +} + +func (t *Arith) Scan(args *string, reply *Reply) (err os.Error) { + _, err = fmt.Sscan(*args, &reply.C) + return +} + func (t *Arith) Error(args *Args, reply *Reply) os.Error { panic("ERROR") } @@ -136,6 +147,29 @@ func TestRPC(t *testing.T) { } else if err.String() != "divide by zero" { t.Error("Div: expected divide by zero error; got", err) } + + // Non-struct argument + const Val = 12345 + str := fmt.Sprint(Val) + reply = new(Reply) + err = client.Call("Arith.Scan", &str, reply) + if err != nil { + t.Errorf("Scan: expected no error but got string %q", err.String()) + } else if reply.C != Val { + t.Errorf("Scan: expected %d got %d", Val, reply.C) + } + + // Non-struct reply + args = &Args{27, 35} + str = "" + err = client.Call("Arith.String", args, &str) + if err != nil { + t.Errorf("String: expected no error but got string %q", err.String()) + } + expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) + if str != expect { + t.Errorf("String: expected %s got %s", expect, str) + } } func TestHTTPRPC(t *testing.T) { @@ -217,37 +251,44 @@ func TestCheckBadType(t *testing.T) { } } -type Bad int +type ArgNotPointer int +type ReplyNotPointer int +type ArgNotPublic int +type ReplyNotPublic int type local struct{} -func (t *Bad) ArgNotPointer(args Args, reply *Reply) os.Error { +func (t *ArgNotPointer) ArgNotPointer(args Args, reply *Reply) os.Error { return nil } -func (t *Bad) ArgNotPointerToStruct(args *int, reply *Reply) os.Error { +func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) os.Error { return nil } -func (t *Bad) ReplyNotPointer(args *Args, reply Reply) os.Error { +func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) os.Error { return nil } -func (t *Bad) ReplyNotPointerToStruct(args *Args, reply *int) os.Error { - return nil -} - -func (t *Bad) ArgNotPublic(args *local, reply *Reply) os.Error { - return nil -} - -func (t *Bad) ReplyNotPublic(args *Args, reply *local) os.Error { +func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) os.Error { return nil } // Check that registration handles lots of bad methods and a type with no suitable methods. func TestRegistrationError(t *testing.T) { - err := Register(new(Bad)) + err := Register(new(ArgNotPointer)) if err == nil { - t.Errorf("expected error registering bad type") + t.Errorf("expected error registering ArgNotPointer") + } + err = Register(new(ReplyNotPointer)) + if err == nil { + t.Errorf("expected error registering ReplyNotPointer") + } + err = Register(new(ArgNotPublic)) + if err == nil { + t.Errorf("expected error registering ArgNotPublic") + } + err = Register(new(ReplyNotPublic)) + if err == nil { + t.Errorf("expected error registering ReplyNotPublic") } }