1
0
mirror of https://github.com/golang/go synced 2024-11-19 14:04:46 -07:00

net/rpc: wait for responses to be written before closing Codec

If there are no more requests being made, wait to shut down
the response-writing codec until the pending requests are all
answered.

Fixes #17239.

Change-Id: Ie62c63ada536171df4e70b73c95f98f778069972
Reviewed-on: https://go-review.googlesource.com/79515
Run-TryBot: Russ Cox <rsc@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rob Pike <r@golang.org>
This commit is contained in:
Russ Cox 2017-11-22 15:25:59 -05:00
parent 70ee9b4a07
commit 4f6d8a59ea
2 changed files with 63 additions and 3 deletions

View File

@ -372,7 +372,10 @@ func (m *methodType) NumCalls() (n uint) {
return n
}
func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
if wg != nil {
defer wg.Done()
}
mtype.Lock()
mtype.numCalls++
mtype.Unlock()
@ -456,6 +459,7 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
// decode requests and encode responses.
func (server *Server) ServeCodec(codec ServerCodec) {
sending := new(sync.Mutex)
wg := new(sync.WaitGroup)
for {
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
if err != nil {
@ -472,8 +476,12 @@ func (server *Server) ServeCodec(codec ServerCodec) {
}
continue
}
go service.call(server, sending, mtype, req, argv, replyv, codec)
wg.Add(1)
go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
}
// We've seen that there are no more requests.
// Wait for responses to be sent before closing codec.
wg.Wait()
codec.Close()
}
@ -493,7 +501,7 @@ func (server *Server) ServeRequest(codec ServerCodec) error {
}
return err
}
service.call(server, sending, mtype, req, argv, replyv, codec)
service.call(server, sending, nil, mtype, req, argv, replyv, codec)
return nil
}

View File

@ -75,6 +75,11 @@ func (t *Arith) Error(args *Args, reply *Reply) error {
panic("ERROR")
}
func (t *Arith) SleepMilli(args *Args, reply *Reply) error {
time.Sleep(time.Duration(args.A) * time.Millisecond)
return nil
}
type hidden int
func (t *hidden) Exported(args Args, reply *Reply) error {
@ -693,6 +698,53 @@ func TestAcceptExitAfterListenerClose(t *testing.T) {
newServer.Accept(l)
}
func TestShutdown(t *testing.T) {
var l net.Listener
l, _ = listenTCP()
ch := make(chan net.Conn, 1)
go func() {
defer l.Close()
c, err := l.Accept()
if err != nil {
t.Error(err)
}
ch <- c
}()
c, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
}
c1 := <-ch
if c1 == nil {
t.Fatal(err)
}
newServer := NewServer()
newServer.Register(new(Arith))
go newServer.ServeConn(c1)
args := &Args{7, 8}
reply := new(Reply)
client := NewClient(c)
err = client.Call("Arith.Add", args, reply)
if err != nil {
t.Fatal(err)
}
// On an unloaded system 10ms is usually enough to fail 100% of the time
// with a broken server. On a loaded system, a broken server might incorrectly
// be reported as passing, but we're OK with that kind of flakiness.
// If the code is correct, this test will never fail, regardless of timeout.
args.A = 10 // 10 ms
done := make(chan *Call, 1)
call := client.Go("Arith.SleepMilli", args, reply, done)
c.(*net.TCPConn).CloseWrite()
<-done
if call.Error != nil {
t.Fatal(err)
}
}
func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
once.Do(startServer)
client, err := dial()