1
0
mirror of https://github.com/golang/go synced 2024-11-19 05:54:44 -07:00

net: fix netFD.Close races

Fixes #271.
Fixes #321.

R=rsc, agl, cw
CC=golang-dev
https://golang.org/cl/163052
This commit is contained in:
Devon H. O'Dell 2009-12-01 23:28:57 -08:00 committed by Russ Cox
parent ff3ea68e52
commit eb16346dac
5 changed files with 98 additions and 75 deletions

View File

@ -15,14 +15,18 @@ import (
// Network file descriptor.
type netFD struct {
// locking/lifetime of sysfd
sysmu sync.Mutex;
sysref int;
closing bool;
// immutable until Close
fd int;
sysfd int;
family int;
proto int;
file *os.File;
sysfile *os.File;
cr chan *netFD;
cw chan *netFD;
cc chan *netFD;
net string;
laddr Addr;
raddr Addr;
@ -68,13 +72,13 @@ type netFD struct {
// channel will be empty for the next process's request. A larger buffer
// might help batch requests.
//
// In order to prevent race conditions, pollServer has an additional cc channel
// that receives fds to be closed. pollServer doesn't make the close system
// call, it just sets fd.file = nil and fd.fd = -1. Because of this, pollServer
// is always in sync with the kernel's view of a given descriptor.
// To avoid races in closing, all fd operations are locked and
// refcounted. when netFD.Close() is called, it calls syscall.Shutdown
// and sets a closing flag. Only when the last reference is removed
// will the fd be closed.
type pollServer struct {
cr, cw, cc chan *netFD; // buffered >= 1
cr, cw chan *netFD; // buffered >= 1
pr, pw *os.File;
pending map[int]*netFD;
poll *pollster; // low-level OS hooks
@ -85,7 +89,6 @@ func newPollServer() (s *pollServer, err os.Error) {
s = new(pollServer);
s.cr = make(chan *netFD, 1);
s.cw = make(chan *netFD, 1);
s.cc = make(chan *netFD, 1);
if s.pr, s.pw, err = os.Pipe(); err != nil {
return nil, err
}
@ -114,16 +117,7 @@ func newPollServer() (s *pollServer, err os.Error) {
}
func (s *pollServer) AddFD(fd *netFD, mode int) {
// This check verifies that the underlying file descriptor hasn't been
// closed in the mean time. Any time a netFD is closed, the closing
// goroutine makes a round trip to the pollServer which sets file = nil
// and fd = -1. The goroutine then closes the actual file descriptor.
// Thus fd.fd mirrors the kernel's view of the file descriptor.
// TODO(rsc,agl): There is still a race in Read and Write,
// because they optimistically try to use the fd and don't
// call into the PollServer unless they get EAGAIN.
intfd := fd.fd;
intfd := fd.sysfd;
if intfd < 0 {
// fd closed underfoot
if mode == 'r' {
@ -213,10 +207,10 @@ func (s *pollServer) CheckDeadlines() {
if t <= now {
s.pending[key] = nil, false;
if mode == 'r' {
s.poll.DelFD(fd.fd, mode);
s.poll.DelFD(fd.sysfd, mode);
fd.rdeadline = -1;
} else {
s.poll.DelFD(fd.fd, mode);
s.poll.DelFD(fd.sysfd, mode);
fd.wdeadline = -1;
}
s.WakeFD(fd, mode);
@ -254,7 +248,6 @@ func (s *pollServer) Run() {
for nn, _ := s.pr.Read(&scratch); nn > 0; {
nn, _ = s.pr.Read(&scratch)
}
// Read from channels
for fd, ok := <-s.cr; ok; fd, ok = <-s.cr {
s.AddFD(fd, 'r')
@ -262,11 +255,6 @@ func (s *pollServer) Run() {
for fd, ok := <-s.cw; ok; fd, ok = <-s.cw {
s.AddFD(fd, 'w')
}
for fd, ok := <-s.cc; ok; fd, ok = <-s.cc {
fd.file = nil;
fd.fd = -1;
fd.cc <- fd;
}
} else {
netfd := s.LookupFD(fd, mode);
if netfd == nil {
@ -294,12 +282,6 @@ func (s *pollServer) WaitWrite(fd *netFD) {
<-fd.cw;
}
func (s *pollServer) WaitCloseAck(fd *netFD) {
s.cc <- fd;
s.Wakeup();
<-fd.cc;
}
// Network FD methods.
// All the network FDs use a single pollServer.
@ -319,7 +301,7 @@ func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err
return nil, &OpError{"setnonblock", net, laddr, os.Errno(e)}
}
f = &netFD{
fd: fd,
sysfd: fd,
family: family,
proto: proto,
net: net,
@ -333,13 +315,37 @@ func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err
if raddr != nil {
rs = raddr.String()
}
f.file = os.NewFile(fd, net+":"+ls+"->"+rs);
f.sysfile = os.NewFile(fd, net+":"+ls+"->"+rs);
f.cr = make(chan *netFD, 1);
f.cw = make(chan *netFD, 1);
f.cc = make(chan *netFD, 1);
return f, nil;
}
// Add a reference to this fd.
func (fd *netFD) incref() {
fd.sysmu.Lock();
fd.sysref++;
fd.sysmu.Unlock();
}
// Remove a reference to this FD and close if we've been asked to do so (and
// there are no references left.
func (fd *netFD) decref() {
fd.sysmu.Lock();
fd.sysref--;
if fd.closing && fd.sysref == 0 && fd.sysfd >= 0 {
// In case the user has set linger, switch to blocking mode so
// the close blocks. As long as this doesn't happen often, we
// can handle the extra OS processes. Otherwise we'll need to
// use the pollserver for Close too. Sigh.
syscall.SetNonblock(fd.sysfd, false);
fd.sysfile.Close();
fd.sysfile = nil;
fd.sysfd = -1;
}
fd.sysmu.Unlock();
}
func isEAGAIN(e os.Error) bool {
if e1, ok := e.(*os.PathError); ok {
return e1.Error == os.EAGAIN
@ -348,36 +354,32 @@ func isEAGAIN(e os.Error) bool {
}
func (fd *netFD) Close() os.Error {
if fd == nil || fd.file == nil {
if fd == nil || fd.sysfile == nil {
return os.EINVAL
}
// In case the user has set linger,
// switch to blocking mode so the close blocks.
// As long as this doesn't happen often,
// we can handle the extra OS processes.
// Otherwise we'll need to use the pollserver
// for Close too. Sigh.
syscall.SetNonblock(fd.file.Fd(), false);
f := fd.file;
pollserver.WaitCloseAck(fd);
return f.Close();
fd.incref();
syscall.Shutdown(fd.sysfd, syscall.SHUT_RDWR);
fd.closing = true;
fd.decref();
return nil;
}
func (fd *netFD) Read(p []byte) (n int, err os.Error) {
if fd == nil || fd.file == nil {
if fd == nil || fd.sysfile == nil {
return 0, os.EINVAL
}
fd.rio.Lock();
defer fd.rio.Unlock();
fd.incref();
defer fd.decref();
if fd.rdeadline_delta > 0 {
fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
} else {
fd.rdeadline = 0
}
for {
n, err = fd.file.Read(p);
n, err = fd.sysfile.Read(p);
if isEAGAIN(err) && fd.rdeadline >= 0 {
pollserver.WaitRead(fd);
continue;
@ -388,11 +390,13 @@ func (fd *netFD) Read(p []byte) (n int, err os.Error) {
}
func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
if fd == nil || fd.file == nil {
if fd == nil || fd.sysfile == nil {
return 0, nil, os.EINVAL
}
fd.rio.Lock();
defer fd.rio.Unlock();
fd.incref();
defer fd.decref();
if fd.rdeadline_delta > 0 {
fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
} else {
@ -400,14 +404,14 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
}
for {
var errno int;
n, sa, errno = syscall.Recvfrom(fd.fd, p, 0);
n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0);
if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
pollserver.WaitRead(fd);
continue;
}
if errno != 0 {
n = 0;
err = &os.PathError{"recvfrom", fd.file.Name(), os.Errno(errno)};
err = &os.PathError{"recvfrom", fd.sysfile.Name(), os.Errno(errno)};
}
break;
}
@ -415,11 +419,13 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
}
func (fd *netFD) Write(p []byte) (n int, err os.Error) {
if fd == nil || fd.file == nil {
if fd == nil || fd.sysfile == nil {
return 0, os.EINVAL
}
fd.wio.Lock();
defer fd.wio.Unlock();
fd.incref();
defer fd.decref();
if fd.wdeadline_delta > 0 {
fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
} else {
@ -428,7 +434,7 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) {
err = nil;
nn := 0;
for nn < len(p) {
n, err = fd.file.Write(p[nn:]);
n, err = fd.sysfile.Write(p[nn:]);
if n > 0 {
nn += n
}
@ -447,11 +453,13 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) {
}
func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
if fd == nil || fd.file == nil {
if fd == nil || fd.sysfile == nil {
return 0, os.EINVAL
}
fd.wio.Lock();
defer fd.wio.Unlock();
fd.incref();
defer fd.decref();
if fd.wdeadline_delta > 0 {
fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
} else {
@ -459,13 +467,13 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
}
err = nil;
for {
errno := syscall.Sendto(fd.fd, p, 0, sa);
errno := syscall.Sendto(fd.sysfd, p, 0, sa);
if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
pollserver.WaitWrite(fd);
continue;
}
if errno != 0 {
err = &os.PathError{"sendto", fd.file.Name(), os.Errno(errno)}
err = &os.PathError{"sendto", fd.sysfile.Name(), os.Errno(errno)}
}
break;
}
@ -476,18 +484,21 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
}
func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) {
if fd == nil || fd.file == nil {
if fd == nil || fd.sysfile == nil {
return nil, os.EINVAL
}
fd.incref();
defer fd.decref();
// See ../syscall/exec.go for description of ForkLock.
// It is okay to hold the lock across syscall.Accept
// because we have put fd.fd into non-blocking mode.
// because we have put fd.sysfd into non-blocking mode.
syscall.ForkLock.RLock();
var s, e int;
var sa syscall.Sockaddr;
for {
s, sa, e = syscall.Accept(fd.fd);
s, sa, e = syscall.Accept(fd.sysfd);
if e != syscall.EAGAIN {
break
}

View File

@ -75,11 +75,15 @@ func setsockoptNsec(fd, level, opt int, nsec int64) os.Error {
}
func setReadBuffer(fd *netFD, bytes int) os.Error {
return setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes)
fd.incref();
defer fd.decref();
return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes);
}
func setWriteBuffer(fd *netFD, bytes int) os.Error {
return setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)
fd.incref();
defer fd.decref();
return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes);
}
func setReadTimeout(fd *netFD, nsec int64) os.Error {
@ -100,7 +104,9 @@ func setTimeout(fd *netFD, nsec int64) os.Error {
}
func setReuseAddr(fd *netFD, reuse bool) os.Error {
return setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse))
fd.incref();
defer fd.decref();
return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse));
}
func bindToDevice(fd *netFD, dev string) os.Error {
@ -109,11 +115,15 @@ func bindToDevice(fd *netFD, dev string) os.Error {
}
func setDontRoute(fd *netFD, dontroute bool) os.Error {
return setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute))
fd.incref();
defer fd.decref();
return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute));
}
func setKeepAlive(fd *netFD, keepalive bool) os.Error {
return setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive))
fd.incref();
defer fd.decref();
return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive));
}
func setLinger(fd *netFD, sec int) os.Error {
@ -125,7 +135,9 @@ func setLinger(fd *netFD, sec int) os.Error {
l.Onoff = 0;
l.Linger = 0;
}
e := syscall.SetsockoptLinger(fd.fd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l);
fd.incref();
defer fd.decref();
e := syscall.SetsockoptLinger(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l);
return os.NewSyscallError("setsockopt", e);
}

View File

@ -73,7 +73,7 @@ type TCPConn struct {
func newTCPConn(fd *netFD) *TCPConn {
c := &TCPConn{fd};
setsockoptInt(fd.fd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1);
setsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1);
return c;
}
@ -234,9 +234,9 @@ func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) {
if err != nil {
return nil, err
}
errno := syscall.Listen(fd.fd, listenBacklog());
errno := syscall.Listen(fd.sysfd, listenBacklog());
if errno != 0 {
syscall.Close(fd.fd);
syscall.Close(fd.sysfd);
return nil, &OpError{"listen", "tcp", laddr, os.Errno(errno)};
}
l = new(TCPListener);
@ -247,7 +247,7 @@ func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) {
// AcceptTCP accepts the next incoming call and returns the new connection
// and the remote address.
func (l *TCPListener) AcceptTCP() (c *TCPConn, err os.Error) {
if l == nil || l.fd == nil || l.fd.fd < 0 {
if l == nil || l.fd == nil || l.fd.sysfd < 0 {
return nil, os.EINVAL
}
fd, err := l.fd.accept(sockaddrToTCP);

View File

@ -73,7 +73,7 @@ type UDPConn struct {
func newUDPConn(fd *netFD) *UDPConn {
c := &UDPConn{fd};
setsockoptInt(fd.fd, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1);
setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1);
return c;
}

View File

@ -332,9 +332,9 @@ func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) {
}
return nil, e;
}
e1 := syscall.Listen(fd.fd, 8); // listenBacklog());
e1 := syscall.Listen(fd.sysfd, 8); // listenBacklog());
if e1 != 0 {
syscall.Close(fd.fd);
syscall.Close(fd.sysfd);
return nil, &OpError{"listen", "unix", laddr, os.Errno(e1)};
}
return &UnixListener{fd, laddr.Name}, nil;
@ -343,7 +343,7 @@ func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) {
// AcceptUnix accepts the next incoming call and returns the new connection
// and the remote address.
func (l *UnixListener) AcceptUnix() (c *UnixConn, err os.Error) {
if l == nil || l.fd == nil || l.fd.fd < 0 {
if l == nil || l.fd == nil {
return nil, os.EINVAL
}
fd, e := l.fd.accept(sockaddrToUnix);