mirror of
https://github.com/golang/go
synced 2024-11-21 23:34:42 -07:00
rpc: avoid infinite loop on input error
Fixes #1828. Fixes #2179. R=golang-dev, r CC=golang-dev https://golang.org/cl/5305084
This commit is contained in:
parent
7b04471dfa
commit
2e79e8e549
@ -6,6 +6,7 @@ package jsonrpc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"json"
|
"json"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@ -154,3 +155,67 @@ func TestClient(t *testing.T) {
|
|||||||
t.Error("Div: expected divide by zero error; got", err)
|
t.Error("Div: expected divide by zero error; got", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMalformedInput(t *testing.T) {
|
||||||
|
cli, srv := net.Pipe()
|
||||||
|
go cli.Write([]byte(`{id:1}`)) // invalid json
|
||||||
|
ServeConn(srv) // must return, not loop
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnexpectedError(t *testing.T) {
|
||||||
|
cli, srv := myPipe()
|
||||||
|
go cli.PipeWriter.CloseWithError(os.NewError("unexpected error!")) // reader will get this error
|
||||||
|
ServeConn(srv) // must return, not loop
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copied from package net.
|
||||||
|
func myPipe() (*pipe, *pipe) {
|
||||||
|
r1, w1 := io.Pipe()
|
||||||
|
r2, w2 := io.Pipe()
|
||||||
|
|
||||||
|
return &pipe{r1, w2}, &pipe{r2, w1}
|
||||||
|
}
|
||||||
|
|
||||||
|
type pipe struct {
|
||||||
|
*io.PipeReader
|
||||||
|
*io.PipeWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type pipeAddr int
|
||||||
|
|
||||||
|
func (pipeAddr) Network() string {
|
||||||
|
return "pipe"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pipeAddr) String() string {
|
||||||
|
return "pipe"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipe) Close() os.Error {
|
||||||
|
err := p.PipeReader.Close()
|
||||||
|
err1 := p.PipeWriter.Close()
|
||||||
|
if err == nil {
|
||||||
|
err = err1
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipe) LocalAddr() net.Addr {
|
||||||
|
return pipeAddr(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipe) RemoteAddr() net.Addr {
|
||||||
|
return pipeAddr(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipe) SetTimeout(nsec int64) os.Error {
|
||||||
|
return os.NewError("net.Pipe does not support timeouts")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipe) SetReadTimeout(nsec int64) os.Error {
|
||||||
|
return os.NewError("net.Pipe does not support timeouts")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipe) SetWriteTimeout(nsec int64) os.Error {
|
||||||
|
return os.NewError("net.Pipe does not support timeouts")
|
||||||
|
}
|
||||||
|
@ -394,12 +394,12 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
|
|||||||
func (server *Server) ServeCodec(codec ServerCodec) {
|
func (server *Server) ServeCodec(codec ServerCodec) {
|
||||||
sending := new(sync.Mutex)
|
sending := new(sync.Mutex)
|
||||||
for {
|
for {
|
||||||
service, mtype, req, argv, replyv, err := server.readRequest(codec)
|
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != os.EOF {
|
if err != os.EOF {
|
||||||
log.Println("rpc:", err)
|
log.Println("rpc:", err)
|
||||||
}
|
}
|
||||||
if err == os.EOF || err == io.ErrUnexpectedEOF {
|
if !keepReading {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
// send a response if we actually managed to read a header.
|
// send a response if we actually managed to read a header.
|
||||||
@ -418,9 +418,9 @@ func (server *Server) ServeCodec(codec ServerCodec) {
|
|||||||
// It does not close the codec upon completion.
|
// It does not close the codec upon completion.
|
||||||
func (server *Server) ServeRequest(codec ServerCodec) os.Error {
|
func (server *Server) ServeRequest(codec ServerCodec) os.Error {
|
||||||
sending := new(sync.Mutex)
|
sending := new(sync.Mutex)
|
||||||
service, mtype, req, argv, replyv, err := server.readRequest(codec)
|
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == os.EOF || err == io.ErrUnexpectedEOF {
|
if !keepReading {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// send a response if we actually managed to read a header.
|
// send a response if we actually managed to read a header.
|
||||||
@ -474,10 +474,10 @@ func (server *Server) freeResponse(resp *Response) {
|
|||||||
server.respLock.Unlock()
|
server.respLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, err os.Error) {
|
func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err os.Error) {
|
||||||
service, mtype, req, err = server.readRequestHeader(codec)
|
service, mtype, req, keepReading, err = server.readRequestHeader(codec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == os.EOF || err == io.ErrUnexpectedEOF {
|
if !keepReading {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// discard body
|
// discard body
|
||||||
@ -505,7 +505,7 @@ func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *m
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, err os.Error) {
|
func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, keepReading bool, err os.Error) {
|
||||||
// Grab the request header.
|
// Grab the request header.
|
||||||
req = server.getRequest()
|
req = server.getRequest()
|
||||||
err = codec.ReadRequestHeader(req)
|
err = codec.ReadRequestHeader(req)
|
||||||
@ -518,6 +518,10 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We read the header successfully. If we see an error now,
|
||||||
|
// we can still recover and move on to the next request.
|
||||||
|
keepReading = true
|
||||||
|
|
||||||
serviceMethod := strings.Split(req.ServiceMethod, ".")
|
serviceMethod := strings.Split(req.ServiceMethod, ".")
|
||||||
if len(serviceMethod) != 2 {
|
if len(serviceMethod) != 2 {
|
||||||
err = os.NewError("rpc: service/method request ill-formed: " + req.ServiceMethod)
|
err = os.NewError("rpc: service/method request ill-formed: " + req.ServiceMethod)
|
||||||
|
@ -311,8 +311,9 @@ func (codec *CodecEmulator) ReadRequestBody(argv interface{}) os.Error {
|
|||||||
func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) os.Error {
|
func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) os.Error {
|
||||||
if resp.Error != "" {
|
if resp.Error != "" {
|
||||||
codec.err = os.NewError(resp.Error)
|
codec.err = os.NewError(resp.Error)
|
||||||
|
} else {
|
||||||
|
*codec.reply = *(reply.(*Reply))
|
||||||
}
|
}
|
||||||
*codec.reply = *(reply.(*Reply))
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user