diff --git a/src/os/user/listgroups_unix.go b/src/os/user/listgroups_unix.go index b620ad36525..ef366fa2800 100644 --- a/src/os/user/listgroups_unix.go +++ b/src/os/user/listgroups_unix.go @@ -9,13 +9,11 @@ package user import ( "bufio" "bytes" - "context" "errors" "fmt" "io" "os" "strconv" - "time" ) func listGroupsFromReader(u *User, r io.Reader) ([]string, error) { @@ -101,13 +99,6 @@ func listGroupsFromReader(u *User, r io.Reader) ([]string, error) { } func listGroups(u *User) ([]string, error) { - if defaultUserdbClient.isUsable() { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) - defer cancel() - if ids, ok, err := defaultUserdbClient.lookupGroupIds(ctx, u.Username); ok { - return ids, err - } - } f, err := os.Open(groupFile) if err != nil { return nil, err diff --git a/src/os/user/lookup_unix.go b/src/os/user/lookup_unix.go index 0ee2ad35ef6..608d9b2140f 100644 --- a/src/os/user/lookup_unix.go +++ b/src/os/user/lookup_unix.go @@ -9,13 +9,11 @@ package user import ( "bufio" "bytes" - "context" "errors" "io" "os" "strconv" "strings" - "time" ) // lineFunc returns a value, an error, or (nil, nil) to skip the row. @@ -200,13 +198,6 @@ func findUsername(name string, r io.Reader) (*User, error) { } func lookupGroup(groupname string) (*Group, error) { - if defaultUserdbClient.isUsable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if g, ok, err := defaultUserdbClient.lookupGroup(ctx, groupname); ok { - return g, err - } - } f, err := os.Open(groupFile) if err != nil { return nil, err @@ -216,13 +207,6 @@ func lookupGroup(groupname string) (*Group, error) { } func lookupGroupId(id string) (*Group, error) { - if defaultUserdbClient.isUsable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if g, ok, err := defaultUserdbClient.lookupGroupId(ctx, id); ok { - return g, err - } - } f, err := os.Open(groupFile) if err != nil { return nil, err @@ -232,13 +216,6 @@ func lookupGroupId(id string) (*Group, error) { } func lookupUser(username string) (*User, error) { - if defaultUserdbClient.isUsable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if u, ok, err := defaultUserdbClient.lookupUser(ctx, username); ok { - return u, err - } - } f, err := os.Open(userFile) if err != nil { return nil, err @@ -248,13 +225,6 @@ func lookupUser(username string) (*User, error) { } func lookupUserId(uid string) (*User, error) { - if defaultUserdbClient.isUsable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if u, ok, err := defaultUserdbClient.lookupUserId(ctx, uid); ok { - return u, err - } - } f, err := os.Open(userFile) if err != nil { return nil, err diff --git a/src/os/user/user.go b/src/os/user/user.go index 4cf5b7c5156..0307d2ad6a1 100644 --- a/src/os/user/user.go +++ b/src/os/user/user.go @@ -11,10 +11,6 @@ One is written in pure Go and parses /etc/passwd and /etc/group. The other is cgo-based and relies on the standard C library (libc) routines such as getpwuid_r, getgrnam_r, and getgrouplist. -For Linux, the pure Go implementation queries the systemd-userdb service first. -If the service is not available, it falls back to parsing /etc/passwd and -/etc/group. - When cgo is available, and the required routines are implemented in libc for a particular platform, cgo-based (libc-backed) code is used. This can be overridden by using osusergo build tag, which enforces diff --git a/src/os/user/userdbclient.go b/src/os/user/userdbclient.go deleted file mode 100644 index b0f3895ed4a..00000000000 --- a/src/os/user/userdbclient.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2023 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package user - -// userdbClient queries the io.systemd.UserDatabase VARLINK interface provided by -// systemd-userdbd.service(8) on Linux for obtaining full user/group details -// even when cgo is not available. -// VARLINK protocol: https://varlink.org -// Systemd userdb VARLINK interface https://systemd.io/USER_GROUP_API -// dir contains multiple varlink service sockets implementing the userdb interface. -type userdbClient struct { - dir string -} - -// IsUsable checks if the client can be used to make queries. -func (cl userdbClient) isUsable() bool { - return len(cl.dir) != 0 -} - -var defaultUserdbClient userdbClient diff --git a/src/os/user/userdbclient_linux.go b/src/os/user/userdbclient_linux.go deleted file mode 100644 index e585b7f3c3b..00000000000 --- a/src/os/user/userdbclient_linux.go +++ /dev/null @@ -1,772 +0,0 @@ -// Copyright 2023 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build linux - -package user - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "io/fs" - "os" - "strconv" - "strings" - "sync" - "syscall" - "unicode/utf16" - "unicode/utf8" -) - -const ( - // Well known multiplexer service. - svcMultiplexer = "io.systemd.Multiplexer" - - userdbNamespace = "io.systemd.UserDatabase" - - // io.systemd.UserDatabase VARLINK interface methods. - mGetGroupRecord = userdbNamespace + ".GetGroupRecord" - mGetUserRecord = userdbNamespace + ".GetUserRecord" - mGetMemberships = userdbNamespace + ".GetMemberships" - - // io.systemd.UserDatabase VARLINK interface errors. - errNoRecordFound = userdbNamespace + ".NoRecordFound" - errServiceNotAvailable = userdbNamespace + ".ServiceNotAvailable" -) - -func init() { - defaultUserdbClient.dir = "/run/systemd/userdb" -} - -// userdbCall represents a VARLINK service call sent to systemd-userdb. -// method is the VARLINK method to call. -// parameters are the VARLINK parameters to pass. -// more indicates if more responses are expected. -// fastest indicates if only the fastest response should be returned. -type userdbCall struct { - method string - parameters callParameters - more bool - fastest bool -} - -func (u userdbCall) marshalJSON(service string) ([]byte, error) { - params, err := u.parameters.marshalJSON(service) - if err != nil { - return nil, err - } - var data bytes.Buffer - data.WriteString(`{"method":"`) - data.WriteString(u.method) - data.WriteString(`","parameters":`) - data.Write(params) - if u.more { - data.WriteString(`,"more":true`) - } - data.WriteString(`}`) - return data.Bytes(), nil -} - -type callParameters struct { - uid *int64 - userName string - gid *int64 - groupName string -} - -func (c callParameters) marshalJSON(service string) ([]byte, error) { - var data bytes.Buffer - data.WriteString(`{"service":"`) - data.WriteString(service) - data.WriteString(`"`) - if c.uid != nil { - data.WriteString(`,"uid":`) - data.WriteString(strconv.FormatInt(*c.uid, 10)) - } - if c.userName != "" { - data.WriteString(`,"userName":"`) - data.WriteString(c.userName) - data.WriteString(`"`) - } - if c.gid != nil { - data.WriteString(`,"gid":`) - data.WriteString(strconv.FormatInt(*c.gid, 10)) - } - if c.groupName != "" { - data.WriteString(`,"groupName":"`) - data.WriteString(c.groupName) - data.WriteString(`"`) - } - data.WriteString(`}`) - return data.Bytes(), nil -} - -type userdbReply struct { - continues bool - errorStr string -} - -func (u *userdbReply) unmarshalJSON(data []byte) error { - var ( - kContinues = []byte(`"continues"`) - kError = []byte(`"error"`) - ) - if i := bytes.Index(data, kContinues); i != -1 { - continues, err := parseJSONBoolean(data[i+len(kContinues):]) - if err != nil { - return err - } - u.continues = continues - } - if i := bytes.Index(data, kError); i != -1 { - errStr, err := parseJSONString(data[i+len(kError):]) - if err != nil { - return err - } - u.errorStr = errStr - } - return nil -} - -// response is the parsed reply from a method call to systemd-userdb. -// data is one or more VARLINK response parameters separated by 0. -// handled indicates if the call was handled by systemd-userdb. -// err is any error encountered. -type response struct { - data []byte - handled bool - err error -} - -// querySocket calls the io.systemd.UserDatabase VARLINK interface at sock with request. -// Multiple replies can be read by setting more to true in the request. -// Reply parameters are accumulated separated by 0, if there are many. -// Replies with io.systemd.UserDatabase.NoRecordFound errors are skipped. -// Other UserDatabase errors are returned as is. -// If the socket does not exist, or if the io.systemd.UserDatabase.ServiceNotAvailable -// error is seen in a response, the query is considered unhandled. -func querySocket(ctx context.Context, sock string, request []byte) response { - sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - return response{err: err} - } - defer syscall.Close(sockFd) - if err := syscall.Connect(sockFd, &syscall.SockaddrUnix{Name: sock}); err != nil { - if errors.Is(err, os.ErrNotExist) { - return response{err: err} - } - return response{handled: true, err: err} - } - - // Null terminate request. - if request[len(request)-1] != 0 { - request = append(request, 0) - } - - // Write request to socket. - written := 0 - for written < len(request) { - if ctx.Err() != nil { - return response{handled: true, err: ctx.Err()} - } - if n, err := syscall.Write(sockFd, request[written:]); err != nil { - return response{handled: true, err: err} - } else { - written += n - } - } - - // Read response. - var resp bytes.Buffer - for { - if ctx.Err() != nil { - return response{handled: true, err: ctx.Err()} - } - buf := make([]byte, 4096) - if n, err := syscall.Read(sockFd, buf); err != nil { - return response{handled: true, err: err} - } else if n > 0 { - resp.Write(buf[:n]) - if buf[n-1] == 0 { - break - } - } else { - // EOF - break - } - } - - if resp.Len() == 0 { - return response{handled: true} - } - - buf := resp.Bytes() - // Remove trailing 0. - buf = buf[:len(buf)-1] - // Split into VARLINK messages. - msgs := bytes.Split(buf, []byte{0}) - - // Parse VARLINK messages. - for _, m := range msgs { - var resp userdbReply - if err := resp.unmarshalJSON(m); err != nil { - return response{handled: true, err: err} - } - // Handle VARLINK message errors. - switch e := resp.errorStr; e { - case "": - case errNoRecordFound: // Ignore not found error. - continue - case errServiceNotAvailable: - return response{} - default: - return response{handled: true, err: errors.New(e)} - } - if !resp.continues { - break - } - } - return response{data: buf, handled: true, err: ctx.Err()} -} - -// queryMany calls the io.systemd.UserDatabase VARLINK interface on many services at once. -// ss is a slice of userdb services to call. Each service must have a socket in cl.dir. -// c is sent to all services in ss. If c.fastest is true, only the fastest reply is read. -// Otherwise all replies are aggregated. um is called with aggregated reply parameters. -// queryMany returns the first error encountered. The first result is false if no userdb -// socket is available or if all requests time out. -func (cl userdbClient) queryMany(ctx context.Context, ss []string, c *userdbCall, um jsonUnmarshaler) (bool, error) { - responseCh := make(chan response, len(ss)) - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - // Query all services in parallel. - var workers sync.WaitGroup - for _, svc := range ss { - data, err := c.marshalJSON(svc) - if err != nil { - return true, err - } - // Spawn worker to query service. - workers.Add(1) - go func(sock string, data []byte) { - defer workers.Done() - responseCh <- querySocket(ctx, sock, data) - }(cl.dir+"/"+svc, data) - } - - go func() { - // Clean up workers. - workers.Wait() - close(responseCh) - }() - - var result bytes.Buffer - var notOk int -RecvResponses: - for { - select { - case resp, ok := <-responseCh: - if !ok { - // Responses channel is closed so stop reading. - break RecvResponses - } - if resp.err != nil { - // querySocket only returns unrecoverable errors, - // so return the first one received. - return true, resp.err - } - if !resp.handled { - notOk++ - continue - } - - first := result.Len() == 0 - result.Write(resp.data) - if first && c.fastest { - // Return the fastest response. - break RecvResponses - } - case <-ctx.Done(): - // If requests time out, userdb is unavailable. - return ctx.Err() != context.DeadlineExceeded, nil - } - } - // If all sockets are not ok, userdb is unavailable. - if notOk == len(ss) { - return false, nil - } - return true, um.unmarshalJSON(result.Bytes()) -} - -// services enumerates userdb service sockets in dir. -// If ok is false, io.systemd.UserDatabase service does not exist. -func (cl userdbClient) services() (s []string, ok bool, err error) { - var entries []fs.DirEntry - if entries, err = os.ReadDir(cl.dir); err != nil { - ok = !os.IsNotExist(err) - return - } - ok = true - for _, ent := range entries { - s = append(s, ent.Name()) - } - return -} - -// query looks up users/groups on the io.systemd.UserDatabase VARLINK interface. -// If the multiplexer service is available, the call is sent only to it. -// Otherwise, the call is sent simultaneously to all UserDatabase services in cl.dir. -// The fastest reply is read and parsed. All other requests are cancelled. -// If the service is unavailable, the first result is false. -// The service is considered unavailable if the requests time-out as well. -func (cl userdbClient) query(ctx context.Context, call *userdbCall, um jsonUnmarshaler) (bool, error) { - services := []string{svcMultiplexer} - if _, err := os.Stat(cl.dir + "/" + svcMultiplexer); err != nil { - // No mux service so call all available services. - var ok bool - if services, ok, err = cl.services(); !ok || err != nil { - return ok, err - } - } - call.fastest = true - if ok, err := cl.queryMany(ctx, services, call, um); !ok || err != nil { - return ok, err - } - return true, nil -} - -type jsonUnmarshaler interface { - unmarshalJSON([]byte) error -} - -func isSpace(c byte) bool { - return c == ' ' || c == '\t' || c == '\r' || c == '\n' -} - -// findElementStart returns a slice of r that starts at the next JSON element. -// It skips over valid JSON space characters and checks for the colon separator. -func findElementStart(r []byte) ([]byte, error) { - var idx int - var b byte - colon := byte(':') - var seenColon bool - for idx, b = range r { - if isSpace(b) { - continue - } - if !seenColon && b == colon { - seenColon = true - continue - } - // Spotted colon and b is not a space, so value starts here. - if seenColon { - break - } - return nil, errors.New("expected colon, got invalid character: " + string(b)) - } - if !seenColon { - return nil, errors.New("expected colon, got end of input") - } - return r[idx:], nil -} - -// parseJSONString reads a JSON string from r. -func parseJSONString(r []byte) (string, error) { - r, err := findElementStart(r) - if err != nil { - return "", err - } - // Smallest valid string is `""`. - if l := len(r); l < 2 { - return "", errors.New("unexpected end of input") - } else if l == 2 { - if bytes.Equal(r, []byte(`""`)) { - return "", nil - } - return "", errors.New("invalid string") - } - - if c := r[0]; c != '"' { - return "", errors.New(`expected " got ` + string(c)) - } - // Advance over opening quote. - r = r[1:] - - var value strings.Builder - var inEsc bool - var inUEsc bool - var strEnds bool - reader := bytes.NewReader(r) - for { - if value.Len() > 4096 { - return "", errors.New("string too large") - } - - // Parse unicode escape sequences. - if inUEsc { - maybeRune := make([]byte, 4) - n, err := reader.Read(maybeRune) - if err != nil || n != 4 { - return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune)) - } - prn, err := strconv.ParseUint(string(maybeRune), 16, 32) - if err != nil { - return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune)) - } - rn := rune(prn) - if !utf16.IsSurrogate(rn) { - value.WriteRune(rn) - inUEsc = false - continue - } - // rn maybe a high surrogate; read the low surrogate. - maybeRune = make([]byte, 6) - n, err = reader.Read(maybeRune) - if err != nil || n != 6 || maybeRune[0] != '\\' || maybeRune[1] != 'u' { - // Not a valid UTF-16 surrogate pair. - if _, err := reader.Seek(int64(-n), io.SeekCurrent); err != nil { - return "", err - } - // Invalid low surrogate; write the replacement character. - value.WriteRune(utf8.RuneError) - } else { - rn1, err := strconv.ParseUint(string(maybeRune[2:]), 16, 32) - if err != nil { - return "", fmt.Errorf("invalid unicode escape sequence %s", string(maybeRune)) - } - // Check if rn and rn1 are valid UTF-16 surrogate pairs. - if dec := utf16.DecodeRune(rn, rune(rn1)); dec != utf8.RuneError { - n = utf8.EncodeRune(maybeRune, dec) - // Write the decoded rune. - value.Write(maybeRune[:n]) - } - } - inUEsc = false - continue - } - - if inEsc { - b, err := reader.ReadByte() - if err != nil { - return "", err - } - switch b { - case 'b': - value.WriteByte('\b') - case 'f': - value.WriteByte('\f') - case 'n': - value.WriteByte('\n') - case 'r': - value.WriteByte('\r') - case 't': - value.WriteByte('\t') - case 'u': - inUEsc = true - case '/': - value.WriteByte('/') - case '\\': - value.WriteByte('\\') - case '"': - value.WriteByte('"') - default: - return "", errors.New("unexpected character in escape sequence " + string(b)) - } - inEsc = false - continue - } else { - rn, _, err := reader.ReadRune() - if err != nil { - if err == io.EOF { - break - } - return "", err - } - if rn == '\\' { - inEsc = true - continue - } - if rn == '"' { - // String ends on un-escaped quote. - strEnds = true - break - } - value.WriteRune(rn) - } - } - if !strEnds { - return "", errors.New("unexpected end of input") - } - return value.String(), nil -} - -// parseJSONInt64 reads a 64 bit integer from r. -func parseJSONInt64(r []byte) (int64, error) { - r, err := findElementStart(r) - if err != nil { - return 0, err - } - var num strings.Builder - for _, b := range r { - // int64 max is 19 digits long. - if num.Len() == 20 { - return 0, errors.New("number too large") - } - if strings.ContainsRune("0123456789", rune(b)) { - num.WriteByte(b) - } else { - break - } - } - n, err := strconv.ParseInt(num.String(), 10, 64) - return int64(n), err -} - -// parseJSONBoolean reads a boolean from r. -func parseJSONBoolean(r []byte) (bool, error) { - r, err := findElementStart(r) - if err != nil { - return false, err - } - if bytes.HasPrefix(r, []byte("true")) { - return true, nil - } - if bytes.HasPrefix(r, []byte("false")) { - return false, nil - } - return false, errors.New("unable to parse boolean value") -} - -type groupRecord struct { - groupName string - gid int64 -} - -func (g *groupRecord) unmarshalJSON(data []byte) error { - var ( - kGroupName = []byte(`"groupName"`) - kGid = []byte(`"gid"`) - ) - if i := bytes.Index(data, kGroupName); i != -1 { - groupname, err := parseJSONString(data[i+len(kGroupName):]) - if err != nil { - return err - } - g.groupName = groupname - } - if i := bytes.Index(data, kGid); i != -1 { - gid, err := parseJSONInt64(data[i+len(kGid):]) - if err != nil { - return err - } - g.gid = gid - } - return nil -} - -// queryGroupDb queries the userdb interface for a gid, groupname, or both. -func (cl userdbClient) queryGroupDb(ctx context.Context, gid *int64, groupname string) (*Group, bool, error) { - group := groupRecord{} - request := userdbCall{ - method: mGetGroupRecord, - parameters: callParameters{gid: gid, groupName: groupname}, - } - if ok, err := cl.query(ctx, &request, &group); !ok || err != nil { - return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err) - } - return &Group{ - Name: group.groupName, - Gid: strconv.FormatInt(group.gid, 10), - }, true, nil -} - -type userRecord struct { - userName string - realName string - uid int64 - gid int64 - homeDirectory string -} - -func (u *userRecord) unmarshalJSON(data []byte) error { - var ( - kUserName = []byte(`"userName"`) - kRealName = []byte(`"realName"`) - kUid = []byte(`"uid"`) - kGid = []byte(`"gid"`) - kHomeDirectory = []byte(`"homeDirectory"`) - ) - if i := bytes.Index(data, kUserName); i != -1 { - username, err := parseJSONString(data[i+len(kUserName):]) - if err != nil { - return err - } - u.userName = username - } - if i := bytes.Index(data, kRealName); i != -1 { - realname, err := parseJSONString(data[i+len(kRealName):]) - if err != nil { - return err - } - u.realName = realname - } - if i := bytes.Index(data, kUid); i != -1 { - uid, err := parseJSONInt64(data[i+len(kUid):]) - if err != nil { - return err - } - u.uid = uid - } - if i := bytes.Index(data, kGid); i != -1 { - gid, err := parseJSONInt64(data[i+len(kGid):]) - if err != nil { - return err - } - u.gid = gid - } - if i := bytes.Index(data, kHomeDirectory); i != -1 { - homedir, err := parseJSONString(data[i+len(kHomeDirectory):]) - if err != nil { - return err - } - u.homeDirectory = homedir - } - return nil -} - -// queryUserDb queries the userdb interface for a uid, username, or both. -func (cl userdbClient) queryUserDb(ctx context.Context, uid *int64, username string) (*User, bool, error) { - user := userRecord{} - request := userdbCall{ - method: mGetUserRecord, - parameters: callParameters{ - uid: uid, - userName: username, - }, - } - if ok, err := cl.query(ctx, &request, &user); !ok || err != nil { - return nil, ok, fmt.Errorf("error querying systemd-userdb user record: %s", err) - } - return &User{ - Uid: strconv.FormatInt(user.uid, 10), - Gid: strconv.FormatInt(user.gid, 10), - Username: user.userName, - Name: user.realName, - HomeDir: user.homeDirectory, - }, true, nil -} - -func (cl userdbClient) lookupGroup(ctx context.Context, groupname string) (*Group, bool, error) { - return cl.queryGroupDb(ctx, nil, groupname) -} - -func (cl userdbClient) lookupGroupId(ctx context.Context, id string) (*Group, bool, error) { - gid, err := strconv.ParseInt(id, 10, 64) - if err != nil { - return nil, true, err - } - return cl.queryGroupDb(ctx, &gid, "") -} - -func (cl userdbClient) lookupUser(ctx context.Context, username string) (*User, bool, error) { - return cl.queryUserDb(ctx, nil, username) -} - -func (cl userdbClient) lookupUserId(ctx context.Context, id string) (*User, bool, error) { - uid, err := strconv.ParseInt(id, 10, 64) - if err != nil { - return nil, true, err - } - return cl.queryUserDb(ctx, &uid, "") -} - -type memberships struct { - // Keys are groupNames and values are sets of userNames. - groupUsers map[string]map[string]struct{} -} - -// unmarshalJSON expects many (userName, groupName) records separated by a null byte. -// This is used to build a membership map. -func (m *memberships) unmarshalJSON(data []byte) error { - if m.groupUsers == nil { - m.groupUsers = make(map[string]map[string]struct{}) - } - var ( - kUserName = []byte(`"userName"`) - kGroupName = []byte(`"groupName"`) - ) - // Split records by null terminator. - records := bytes.Split(data, []byte{byte(0)}) - for _, rec := range records { - if len(rec) == 0 { - continue - } - var groupName string - var userName string - var err error - if i := bytes.Index(rec, kGroupName); i != -1 { - if groupName, err = parseJSONString(rec[i+len(kGroupName):]); err != nil { - return err - } - } - if i := bytes.Index(rec, kUserName); i != -1 { - if userName, err = parseJSONString(rec[i+len(kUserName):]); err != nil { - return err - } - } - // Associate userName with groupName. - if groupName != "" && userName != "" { - if _, ok := m.groupUsers[groupName]; ok { - m.groupUsers[groupName][userName] = struct{}{} - } else { - m.groupUsers[groupName] = map[string]struct{}{userName: {}} - } - } - } - return nil -} - -func (cl userdbClient) lookupGroupIds(ctx context.Context, username string) ([]string, bool, error) { - services, ok, err := cl.services() - if !ok || err != nil { - return nil, ok, err - } - // Fetch group memberships for username. - var ms memberships - request := userdbCall{ - method: mGetMemberships, - parameters: callParameters{userName: username}, - more: true, - } - if ok, err := cl.queryMany(ctx, services, &request, &ms); !ok || err != nil { - return nil, ok, fmt.Errorf("error querying systemd-userdb memberships record: %s", err) - } - // Fetch user group gid. - var group groupRecord - request = userdbCall{ - method: mGetGroupRecord, - parameters: callParameters{groupName: username}, - } - if ok, err := cl.query(ctx, &request, &group); !ok || err != nil { - return nil, ok, err - } - gids := []string{strconv.FormatInt(group.gid, 10)} - - // Fetch group records for each group. - for g := range ms.groupUsers { - var group groupRecord - request.parameters.groupName = g - // Query group for gid. - if ok, err := cl.query(ctx, &request, &group); !ok || err != nil { - return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err) - } - gids = append(gids, strconv.FormatInt(group.gid, 10)) - } - return gids, true, nil -} diff --git a/src/os/user/userdbclient_linux_test.go b/src/os/user/userdbclient_linux_test.go deleted file mode 100644 index 1b9a336f720..00000000000 --- a/src/os/user/userdbclient_linux_test.go +++ /dev/null @@ -1,504 +0,0 @@ -// Copyright 2023 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build linux - -package user - -import ( - "bytes" - "context" - "errors" - "reflect" - "sort" - "strconv" - "strings" - "sync" - "syscall" - "testing" - "time" - "unicode/utf8" -) - -func TestQueryNoUserdb(t *testing.T) { - cl := &userdbClient{dir: "/non/existent"} - if _, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib"); ok { - t.Fatalf("should fail but lookup has been handled or error is nil: %v", err) - } -} - -type userdbTestData map[string]udbResponse - -type udbResponse struct { - data []byte - delay time.Duration -} - -func userdbServer(t *testing.T, sockFn string, data userdbTestData) { - ready := make(chan struct{}) - go func() { - if err := serveUserdb(ready, sockFn, data); err != nil { - t.Error(err) - } - }() - <-ready -} - -func (u userdbTestData) String() string { - var s strings.Builder - for k, v := range u { - s.WriteString("Request:\n") - s.WriteString(k) - s.WriteString("\nResponse:\n") - if v.delay > 0 { - s.WriteString("Delay: ") - s.WriteString(v.delay.String()) - s.WriteString("\n") - } - s.WriteString("Data:\n") - s.Write(v.data) - s.WriteString("\n") - } - return s.String() -} - -// serverUserdb is a simple userdb server that replies to VARLINK method calls. -// A message is sent on the ready channel when the server is ready to accept calls. -// The server will reply to each request in the data map. If a request is not -// found in the map, the server will return an error. -func serveUserdb(ready chan<- struct{}, sockFn string, data userdbTestData) error { - sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - return err - } - defer syscall.Close(sockFd) - if err := syscall.Bind(sockFd, &syscall.SockaddrUnix{Name: sockFn}); err != nil { - return err - } - if err := syscall.Listen(sockFd, 1); err != nil { - return err - } - - // Send ready signal. - ready <- struct{}{} - - var srvGroup sync.WaitGroup - - srvErrs := make(chan error, len(data)) - for len(data) != 0 { - nfd, _, err := syscall.Accept(sockFd) - if err != nil { - syscall.Close(nfd) - return err - } - - // Read request. - buf := make([]byte, 4096) - n, err := syscall.Read(nfd, buf) - if err != nil { - syscall.Close(nfd) - return err - } - if n == 0 { - // Client went away. - continue - } - if buf[n-1] != 0 { - syscall.Close(nfd) - return errors.New("request not null terminated") - } - // Remove null terminator. - buf = buf[:n-1] - got := string(buf) - - // Fetch response for request. - response, ok := data[got] - if !ok { - syscall.Close(nfd) - msg := "unexpected request:\n" + got + "\n\ndata:\n" + data.String() - return errors.New(msg) - } - delete(data, got) - - srvGroup.Add(1) - go func() { - defer srvGroup.Done() - if err := serveClient(nfd, response); err != nil { - srvErrs <- err - } - }() - } - - srvGroup.Wait() - // Combine serve errors if any. - if len(srvErrs) > 0 { - var errs []error - for err := range srvErrs { - errs = append(errs, err) - } - return errors.Join(errs...) - } - - return nil -} - -func serveClient(fd int, response udbResponse) error { - defer syscall.Close(fd) - time.Sleep(response.delay) - data := response.data - if len(data) != 0 && data[len(data)-1] != 0 { - data = append(data, 0) - } - written := 0 - for written < len(data) { - if n, err := syscall.Write(fd, data[written:]); err != nil { - return err - } else { - written += n - } - } - return nil -} - -func TestSlowUserdbLookup(t *testing.T) { - tmpdir := t.TempDir() - data := userdbTestData{ - `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{ - delay: time.Hour, - }, - } - userdbServer(t, tmpdir+"/"+svcMultiplexer, data) - cl := &userdbClient{dir: tmpdir} - // Lookup should timeout. - ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond) - defer cancel() - if _, ok, _ := cl.lookupGroup(ctx, "stdlibcontrib"); ok { - t.Fatalf("lookup should not be handled but was") - } -} - -func TestFastestUserdbLookup(t *testing.T) { - tmpdir := t.TempDir() - fastData := userdbTestData{ - `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"fast","groupName":"stdlibcontrib"}}`: udbResponse{ - data: []byte( - `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, - ), - }, - } - slowData := userdbTestData{ - `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"slow","groupName":"stdlibcontrib"}}`: udbResponse{ - delay: 50 * time.Millisecond, - data: []byte( - `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":182,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, - ), - }, - } - userdbServer(t, tmpdir+"/"+"fast", fastData) - userdbServer(t, tmpdir+"/"+"slow", slowData) - cl := &userdbClient{dir: tmpdir} - group, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib") - if !ok { - t.Fatalf("lookup should be handled but was not") - } - if err != nil { - t.Fatalf("lookup should not fail but did: %v", err) - } - if group.Gid != "181" { - t.Fatalf("lookup should return group 181 but returned %s", group.Gid) - } -} - -func TestUserdbLookupGroup(t *testing.T) { - tmpdir := t.TempDir() - data := userdbTestData{ - `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{ - data: []byte( - `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, - ), - }, - } - userdbServer(t, tmpdir+"/"+svcMultiplexer, data) - - groupname := "stdlibcontrib" - want := &Group{ - Name: "stdlibcontrib", - Gid: "181", - } - cl := &userdbClient{dir: tmpdir} - got, ok, err := cl.lookupGroup(context.Background(), groupname) - if !ok { - t.Fatal("lookup should have been handled") - } - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, want) { - t.Fatalf("lookupGroup(%s) = %v, want %v", groupname, got, want) - } -} - -func TestUserdbLookupUser(t *testing.T) { - tmpdir := t.TempDir() - data := userdbTestData{ - `{"method":"io.systemd.UserDatabase.GetUserRecord","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"}}`: udbResponse{ - data: []byte( - `{"parameters":{"record":{"userName":"stdlibcontrib","uid":181,"gid":181,"realName":"Stdlib Contrib","homeDirectory":"/home/stdlibcontrib","status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, - ), - }, - } - userdbServer(t, tmpdir+"/"+svcMultiplexer, data) - - username := "stdlibcontrib" - want := &User{ - Uid: "181", - Gid: "181", - Username: "stdlibcontrib", - Name: "Stdlib Contrib", - HomeDir: "/home/stdlibcontrib", - } - cl := &userdbClient{dir: tmpdir} - got, ok, err := cl.lookupUser(context.Background(), username) - if !ok { - t.Fatal("lookup should have been handled") - } - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, want) { - t.Fatalf("lookupUser(%s) = %v, want %v", username, got, want) - } -} - -func TestUserdbLookupGroupIds(t *testing.T) { - tmpdir := t.TempDir() - data := userdbTestData{ - `{"method":"io.systemd.UserDatabase.GetMemberships","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"},"more":true}`: udbResponse{ - data: []byte( - `{"parameters":{"userName":"stdlibcontrib","groupName":"stdlib"},"continues":true}` + "\x00" + `{"parameters":{"userName":"stdlibcontrib","groupName":"contrib"}}`, - ), - }, - // group records - `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{ - data: []byte( - `{"parameters":{"record":{"groupName":"stdlibcontrib","members":["stdlibcontrib"],"gid":181,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, - ), - }, - `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlib"}}`: udbResponse{ - data: []byte( - `{"parameters":{"record":{"groupName":"stdlib","members":["stdlibcontrib"],"gid":182,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, - ), - }, - `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"contrib"}}`: udbResponse{ - data: []byte( - `{"parameters":{"record":{"groupName":"contrib","members":["stdlibcontrib"],"gid":183,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`, - ), - }, - } - userdbServer(t, tmpdir+"/"+svcMultiplexer, data) - - username := "stdlibcontrib" - want := []string{"181", "182", "183"} - cl := &userdbClient{dir: tmpdir} - got, ok, err := cl.lookupGroupIds(context.Background(), username) - if !ok { - t.Fatal("lookup should have been handled") - } - if err != nil { - t.Fatal(err) - } - // Result order is not specified so sort it. - sort.Strings(got) - if !reflect.DeepEqual(got, want) { - t.Fatalf("lookupGroupIds(%s) = %v, want %v", username, got, want) - } -} - -var findElementStartTestCases = []struct { - in []byte - want []byte - err bool -}{ - {in: []byte(`:`), want: []byte(``)}, - {in: []byte(`: `), want: []byte(``)}, - {in: []byte(`:"foo"`), want: []byte(`"foo"`)}, - {in: []byte(` :"foo"`), want: []byte(`"foo"`)}, - {in: []byte(` 1231 :"foo"`), err: true}, - {in: []byte(``), err: true}, - {in: []byte(`"foo"`), err: true}, - {in: []byte(`foo`), err: true}, -} - -func TestFindElementStart(t *testing.T) { - for i, tc := range findElementStartTestCases { - t.Run("#"+strconv.Itoa(i), func(t *testing.T) { - got, err := findElementStart(tc.in) - if tc.err && err == nil { - t.Errorf("want err for findElementStart(%s), got nil", tc.in) - } - if !tc.err { - if err != nil { - t.Errorf("findElementStart(%s) unexpected error: %s", tc.in, err.Error()) - } - if !bytes.Contains(tc.in, got) { - t.Errorf("%s should contain %s but does not", tc.in, got) - } - } - }) - } -} - -func FuzzFindElementStart(f *testing.F) { - for _, tc := range findElementStartTestCases { - if !tc.err { - f.Add(tc.in) - } - } - f.Fuzz(func(t *testing.T, b []byte) { - if out, err := findElementStart(b); err == nil && !bytes.Contains(b, out) { - t.Errorf("%s, %v", out, err) - } - }) -} - -var parseJSONStringTestCases = []struct { - in []byte - want string - err bool -}{ - {in: []byte(`:""`)}, - {in: []byte(`:"\n"`), want: "\n"}, - {in: []byte(`: "\""`), want: "\""}, - {in: []byte(`:"\t \\"`), want: "\t \\"}, - {in: []byte(`:"\\\\"`), want: `\\`}, - {in: []byte(`::`), err: true}, - {in: []byte(`""`), err: true}, - {in: []byte(`"`), err: true}, - {in: []byte(":\"0\xE5"), err: true}, - {in: []byte{':', '"', 0xFE, 0xFE, 0xFF, 0xFF, '"'}, want: "\uFFFD\uFFFD\uFFFD\uFFFD"}, - {in: []byte(`:"\u0061a"`), want: "aa"}, - {in: []byte(`:"\u0159\u0170"`), want: "řŰ"}, - {in: []byte(`:"\uD800\uDC00"`), want: "\U00010000"}, - {in: []byte(`:"\uD800"`), want: "\uFFFD"}, - {in: []byte(`:"\u000"`), err: true}, - {in: []byte(`:"\u00MF"`), err: true}, - {in: []byte(`:"\uD800\uDC0"`), err: true}, -} - -func TestParseJSONString(t *testing.T) { - for i, tc := range parseJSONStringTestCases { - t.Run("#"+strconv.Itoa(i), func(t *testing.T) { - got, err := parseJSONString(tc.in) - if tc.err && err == nil { - t.Errorf("want err for parseJSONString(%s), got nil", tc.in) - } - if !tc.err { - if err != nil { - t.Errorf("parseJSONString(%s) unexpected error: %s", tc.in, err.Error()) - } - if tc.want != got { - t.Errorf("parseJSONString(%s) = %s, want %s", tc.in, got, tc.want) - } - } - }) - } -} - -func FuzzParseJSONString(f *testing.F) { - for _, tc := range parseJSONStringTestCases { - f.Add(tc.in) - } - f.Fuzz(func(t *testing.T, b []byte) { - if out, err := parseJSONString(b); err == nil && !utf8.ValidString(out) { - t.Errorf("parseJSONString(%s) = %s, invalid string", b, out) - } - }) -} - -var parseJSONInt64TestCases = []struct { - in []byte - want int64 - err bool -}{ - {in: []byte(":1235"), want: 1235}, - {in: []byte(": 123"), want: 123}, - {in: []byte(":0")}, - {in: []byte(":5012313123131231"), want: 5012313123131231}, - {in: []byte("1231"), err: true}, -} - -func TestParseJSONInt64(t *testing.T) { - for i, tc := range parseJSONInt64TestCases { - t.Run("#"+strconv.Itoa(i), func(t *testing.T) { - got, err := parseJSONInt64(tc.in) - if tc.err && err == nil { - t.Errorf("want err for parseJSONInt64(%s), got nil", tc.in) - } - if !tc.err { - if err != nil { - t.Errorf("parseJSONInt64(%s) unexpected error: %s", tc.in, err.Error()) - } - if tc.want != got { - t.Errorf("parseJSONInt64(%s) = %d, want %d", tc.in, got, tc.want) - } - } - }) - } -} - -func FuzzParseJSONInt64(f *testing.F) { - for _, tc := range parseJSONInt64TestCases { - f.Add(tc.in) - } - f.Fuzz(func(t *testing.T, b []byte) { - if out, err := parseJSONInt64(b); err == nil && - !bytes.Contains(b, []byte(strconv.FormatInt(out, 10))) { - t.Errorf("parseJSONInt64(%s) = %d, %v", b, out, err) - } - }) -} - -var parseJSONBooleanTestCases = []struct { - in []byte - want bool - err bool -}{ - {in: []byte(": true "), want: true}, - {in: []byte(":true "), want: true}, - {in: []byte(": false "), want: false}, - {in: []byte(":false "), want: false}, - {in: []byte("true"), err: true}, - {in: []byte("false"), err: true}, - {in: []byte("foo"), err: true}, -} - -func TestParseJSONBoolean(t *testing.T) { - for i, tc := range parseJSONBooleanTestCases { - t.Run("#"+strconv.Itoa(i), func(t *testing.T) { - got, err := parseJSONBoolean(tc.in) - if tc.err && err == nil { - t.Errorf("want err for parseJSONBoolean(%s), got nil", tc.in) - } - if !tc.err { - if err != nil { - t.Errorf("parseJSONBoolean(%s) unexpected error: %s", tc.in, err.Error()) - } - if tc.want != got { - t.Errorf("parseJSONBoolean(%s) = %t, want %t", tc.in, got, tc.want) - } - } - }) - } -} - -func FuzzParseJSONBoolean(f *testing.F) { - for _, tc := range parseJSONBooleanTestCases { - f.Add(tc.in) - } - f.Fuzz(func(t *testing.T, b []byte) { - if out, err := parseJSONBoolean(b); err == nil && !bytes.Contains(b, []byte(strconv.FormatBool(out))) { - t.Errorf("parseJSONBoolean(%s) = %t, %v", b, out, err) - } - }) -} diff --git a/src/os/user/userdbclient_stub.go b/src/os/user/userdbclient_stub.go deleted file mode 100644 index d31f065c3a8..00000000000 --- a/src/os/user/userdbclient_stub.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2023 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !linux - -package user - -import "context" - -func (cl userdbClient) lookupGroup(_ context.Context, _ string) (*Group, bool, error) { - return nil, false, nil -} - -func (cl userdbClient) lookupGroupId(_ context.Context, _ string) (*Group, bool, error) { - return nil, false, nil -} - -func (cl userdbClient) lookupUser(_ context.Context, _ string) (*User, bool, error) { - return nil, false, nil -} - -func (cl userdbClient) lookupUserId(_ context.Context, _ string) (*User, bool, error) { - return nil, false, nil -} - -func (cl userdbClient) lookupGroupIds(_ context.Context, _ string) ([]string, bool, error) { - return nil, false, nil -}