mirror of
https://github.com/golang/go
synced 2024-11-26 04:17:59 -07:00
net: pass MSG_CMSG_CLOEXEC flag in ReadMsgUnix
As mentioned in #42765, calling "recvmsg" syscall on Linux should come
with "MSG_CMSG_CLOEXEC" flag.
For other systems which not supports "MSG_CMSG_CLOEXEC". ReadMsgUnix()
would check the header. If the header type is "syscall.SCM_RIGHTS",
then ReadMsgUnix() would parse the SocketControlMessage and call each
fd with "syscall.CloseOnExec"
Fixes #42765
Change-Id: I74347db72b465685d7684bf0f32415d285845ebb
GitHub-Last-Rev: ca59e2c9e0
GitHub-Pull-Request: golang/go#42768
Reviewed-on: https://go-review.googlesource.com/c/go/+/272226
Trust: Emmanuel Odeke <emmanuel@orijtech.com>
Run-TryBot: Emmanuel Odeke <emmanuel@orijtech.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Tobias Klauser <tobias.klauser@gmail.com>
This commit is contained in:
parent
bbb510ccc9
commit
e97d8eb027
@ -231,7 +231,7 @@ func (fd *FD) ReadFrom(p []byte) (int, syscall.Sockaddr, error) {
|
||||
}
|
||||
|
||||
// ReadMsg wraps the recvmsg network call.
|
||||
func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, error) {
|
||||
func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.Sockaddr, error) {
|
||||
if err := fd.readLock(); err != nil {
|
||||
return 0, 0, 0, nil, err
|
||||
}
|
||||
@ -240,7 +240,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, er
|
||||
return 0, 0, 0, nil, err
|
||||
}
|
||||
for {
|
||||
n, oobn, flags, sa, err := syscall.Recvmsg(fd.Sysfd, p, oob, 0)
|
||||
n, oobn, sysflags, sa, err := syscall.Recvmsg(fd.Sysfd, p, oob, flags)
|
||||
if err != nil {
|
||||
if err == syscall.EINTR {
|
||||
continue
|
||||
@ -253,7 +253,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, er
|
||||
}
|
||||
}
|
||||
err = fd.eofError(n, err)
|
||||
return n, oobn, flags, sa, err
|
||||
return n, oobn, sysflags, sa, err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1013,7 +1013,7 @@ func sockaddrToRaw(sa syscall.Sockaddr) (unsafe.Pointer, int32, error) {
|
||||
}
|
||||
|
||||
// ReadMsg wraps the WSARecvMsg network call.
|
||||
func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, error) {
|
||||
func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.Sockaddr, error) {
|
||||
if err := fd.readLock(); err != nil {
|
||||
return 0, 0, 0, nil, err
|
||||
}
|
||||
@ -1028,6 +1028,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, er
|
||||
o.rsa = new(syscall.RawSockaddrAny)
|
||||
o.msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa))
|
||||
o.msg.Namelen = int32(unsafe.Sizeof(*o.rsa))
|
||||
o.msg.Flags = uint32(flags)
|
||||
n, err := execIO(o, func(o *operation) error {
|
||||
return windows.WSARecvMsg(o.fd.Sysfd, &o.msg, &o.qty, &o.o, nil)
|
||||
})
|
||||
|
@ -64,10 +64,10 @@ func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
|
||||
return n, sa, wrapSyscallError(readFromSyscallName, err)
|
||||
}
|
||||
|
||||
func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
|
||||
n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob)
|
||||
func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
|
||||
n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
|
||||
runtime.KeepAlive(fd)
|
||||
return n, oobn, flags, sa, wrapSyscallError(readMsgSyscallName, err)
|
||||
return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
|
||||
}
|
||||
|
||||
func (fd *netFD) Write(p []byte) (nn int, err error) {
|
||||
|
@ -75,7 +75,7 @@ func stripIPv4Header(n int, b []byte) int {
|
||||
|
||||
func (c *IPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) {
|
||||
var sa syscall.Sockaddr
|
||||
n, oobn, flags, sa, err = c.fd.readMsg(b, oob)
|
||||
n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0)
|
||||
switch sa := sa.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
addr = &IPAddr{IP: sa.Addr[0:]}
|
||||
|
@ -268,7 +268,7 @@ func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
|
||||
return 0, nil, syscall.ENOSYS
|
||||
}
|
||||
|
||||
func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
|
||||
func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
|
||||
return 0, 0, 0, nil, syscall.ENOSYS
|
||||
}
|
||||
|
||||
|
@ -56,7 +56,7 @@ func (c *UDPConn) readFrom(b []byte, addr *UDPAddr) (int, *UDPAddr, error) {
|
||||
|
||||
func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) {
|
||||
var sa syscall.Sockaddr
|
||||
n, oobn, flags, sa, err = c.fd.readMsg(b, oob)
|
||||
n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0)
|
||||
switch sa := sa.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port}
|
||||
|
@ -113,7 +113,11 @@ func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) {
|
||||
|
||||
func (c *UnixConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
|
||||
var sa syscall.Sockaddr
|
||||
n, oobn, flags, sa, err = c.fd.readMsg(b, oob)
|
||||
n, oobn, flags, sa, err = c.fd.readMsg(b, oob, readMsgFlags)
|
||||
if oobn > 0 {
|
||||
setReadMsgCloseOnExec(oob[:oobn])
|
||||
}
|
||||
|
||||
switch sa := sa.(type) {
|
||||
case *syscall.SockaddrUnix:
|
||||
if sa.Name != "" {
|
||||
|
17
src/net/unixsock_readmsg_linux.go
Normal file
17
src/net/unixsock_readmsg_linux.go
Normal file
@ -0,0 +1,17 @@
|
||||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const readMsgFlags = syscall.MSG_CMSG_CLOEXEC
|
||||
|
||||
func setReadMsgCloseOnExec(oob []byte) {
|
||||
}
|
13
src/net/unixsock_readmsg_other.go
Normal file
13
src/net/unixsock_readmsg_other.go
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build (js && wasm) || windows
|
||||
// +build js,wasm windows
|
||||
|
||||
package net
|
||||
|
||||
const readMsgFlags = 0
|
||||
|
||||
func setReadMsgCloseOnExec(oob []byte) {
|
||||
}
|
33
src/net/unixsock_readmsg_posix.go
Normal file
33
src/net/unixsock_readmsg_posix.go
Normal file
@ -0,0 +1,33 @@
|
||||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd netbsd openbsd solaris
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const readMsgFlags = 0
|
||||
|
||||
func setReadMsgCloseOnExec(oob []byte) {
|
||||
scms, err := syscall.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, scm := range scms {
|
||||
if scm.Header.Level == syscall.SOL_SOCKET && scm.Header.Type == syscall.SCM_RIGHTS {
|
||||
fds, err := syscall.ParseUnixRights(&scm)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, fd := range fds {
|
||||
syscall.CloseOnExec(fd)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
105
src/net/unixsock_readmsg_test.go
Normal file
105
src/net/unixsock_readmsg_test.go
Normal file
@ -0,0 +1,105 @@
|
||||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUnixConnReadMsgUnixSCMRightsCloseOnExec(t *testing.T) {
|
||||
if !testableNetwork("unix") {
|
||||
t.Skip("not unix system")
|
||||
}
|
||||
|
||||
scmFile, err := os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
t.Fatalf("file open: %v", err)
|
||||
}
|
||||
defer scmFile.Close()
|
||||
|
||||
rights := syscall.UnixRights(int(scmFile.Fd()))
|
||||
fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Socketpair: %v", err)
|
||||
}
|
||||
|
||||
writeFile := os.NewFile(uintptr(fds[0]), "write-socket")
|
||||
defer writeFile.Close()
|
||||
readFile := os.NewFile(uintptr(fds[1]), "read-socket")
|
||||
defer readFile.Close()
|
||||
|
||||
cw, err := FileConn(writeFile)
|
||||
if err != nil {
|
||||
t.Fatalf("FileConn: %v", err)
|
||||
}
|
||||
defer cw.Close()
|
||||
cr, err := FileConn(readFile)
|
||||
if err != nil {
|
||||
t.Fatalf("FileConn: %v", err)
|
||||
}
|
||||
defer cr.Close()
|
||||
|
||||
ucw, ok := cw.(*UnixConn)
|
||||
if !ok {
|
||||
t.Fatalf("got %T; want UnixConn", cw)
|
||||
}
|
||||
ucr, ok := cr.(*UnixConn)
|
||||
if !ok {
|
||||
t.Fatalf("got %T; want UnixConn", cr)
|
||||
}
|
||||
|
||||
oob := make([]byte, syscall.CmsgSpace(4))
|
||||
err = ucw.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("Can't set unix connection timeout: %v", err)
|
||||
}
|
||||
_, _, err = ucw.WriteMsgUnix(nil, rights, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("UnixConn readMsg: %v", err)
|
||||
}
|
||||
err = ucr.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("Can't set unix connection timeout: %v", err)
|
||||
}
|
||||
_, oobn, _, _, err := ucr.ReadMsgUnix(nil, oob)
|
||||
if err != nil {
|
||||
t.Fatalf("UnixConn readMsg: %v", err)
|
||||
}
|
||||
|
||||
scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
|
||||
if err != nil {
|
||||
t.Fatalf("ParseSocketControlMessage: %v", err)
|
||||
}
|
||||
if len(scms) != 1 {
|
||||
t.Fatalf("got scms = %#v; expected 1 SocketControlMessage", scms)
|
||||
}
|
||||
scm := scms[0]
|
||||
gotFds, err := syscall.ParseUnixRights(&scm)
|
||||
if err != nil {
|
||||
t.Fatalf("syscall.ParseUnixRights: %v", err)
|
||||
}
|
||||
if len(gotFds) != 1 {
|
||||
t.Fatalf("got FDs %#v: wanted only 1 fd", gotFds)
|
||||
}
|
||||
defer func() {
|
||||
if err := syscall.Close(int(gotFds[0])); err != nil {
|
||||
t.Fatalf("fail to close gotFds: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
flags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(gotFds[0]), uintptr(syscall.F_GETFD), 0)
|
||||
if errno != 0 {
|
||||
t.Fatalf("Can't get flags of fd:%#v, with err:%v", gotFds[0], errno)
|
||||
}
|
||||
if flags&syscall.FD_CLOEXEC == 0 {
|
||||
t.Fatalf("got flags %#x, want %#x (FD_CLOEXEC) set", flags, syscall.FD_CLOEXEC)
|
||||
}
|
||||
}
|
@ -105,8 +105,8 @@ func TestSCMCredentials(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMsgUnix: %v", err)
|
||||
}
|
||||
if flags != 0 {
|
||||
t.Fatalf("ReadMsgUnix flags = 0x%x, want 0", flags)
|
||||
if flags != syscall.MSG_CMSG_CLOEXEC {
|
||||
t.Fatalf("ReadMsgUnix flags = %#x, want %#x (MSG_CMSG_CLOEXEC)", flags, syscall.MSG_CMSG_CLOEXEC)
|
||||
}
|
||||
if n != tt.dataLen {
|
||||
t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen)
|
||||
|
Loading…
Reference in New Issue
Block a user