mirror of
https://github.com/golang/go
synced 2024-11-23 10:40:08 -07:00
Revert "os/user: lookup Linux users and groups via systemd userdb"
This reverts CL 459455. Reason for revert: breaks tests on various platforms, see https://go-review.googlesource.com/c/go/+/459455/74#message-3d9462b24872f6e0b12b4abf5ea3983e1588f91a Change-Id: I4c79b28f750c2369909688f86616d76d7eaf0ab4 Reviewed-on: https://go-review.googlesource.com/c/go/+/479135 Run-TryBot: Heschi Kreinick <heschi@google.com> Auto-Submit: Heschi Kreinick <heschi@google.com> Reviewed-by: Cherry Mui <cherryyz@google.com> TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
parent
b0dfcb7465
commit
e0c69587c4
@ -9,13 +9,11 @@ package user
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func listGroupsFromReader(u *User, r io.Reader) ([]string, error) {
|
||||
@ -101,13 +99,6 @@ func listGroupsFromReader(u *User, r io.Reader) ([]string, error) {
|
||||
}
|
||||
|
||||
func listGroups(u *User) ([]string, error) {
|
||||
if defaultUserdbClient.isUsable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
if ids, ok, err := defaultUserdbClient.lookupGroupIds(ctx, u.Username); ok {
|
||||
return ids, err
|
||||
}
|
||||
}
|
||||
f, err := os.Open(groupFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -9,13 +9,11 @@ package user
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// lineFunc returns a value, an error, or (nil, nil) to skip the row.
|
||||
@ -200,13 +198,6 @@ func findUsername(name string, r io.Reader) (*User, error) {
|
||||
}
|
||||
|
||||
func lookupGroup(groupname string) (*Group, error) {
|
||||
if defaultUserdbClient.isUsable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if g, ok, err := defaultUserdbClient.lookupGroup(ctx, groupname); ok {
|
||||
return g, err
|
||||
}
|
||||
}
|
||||
f, err := os.Open(groupFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -216,13 +207,6 @@ func lookupGroup(groupname string) (*Group, error) {
|
||||
}
|
||||
|
||||
func lookupGroupId(id string) (*Group, error) {
|
||||
if defaultUserdbClient.isUsable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if g, ok, err := defaultUserdbClient.lookupGroupId(ctx, id); ok {
|
||||
return g, err
|
||||
}
|
||||
}
|
||||
f, err := os.Open(groupFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -232,13 +216,6 @@ func lookupGroupId(id string) (*Group, error) {
|
||||
}
|
||||
|
||||
func lookupUser(username string) (*User, error) {
|
||||
if defaultUserdbClient.isUsable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if u, ok, err := defaultUserdbClient.lookupUser(ctx, username); ok {
|
||||
return u, err
|
||||
}
|
||||
}
|
||||
f, err := os.Open(userFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -248,13 +225,6 @@ func lookupUser(username string) (*User, error) {
|
||||
}
|
||||
|
||||
func lookupUserId(uid string) (*User, error) {
|
||||
if defaultUserdbClient.isUsable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if u, ok, err := defaultUserdbClient.lookupUserId(ctx, uid); ok {
|
||||
return u, err
|
||||
}
|
||||
}
|
||||
f, err := os.Open(userFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -11,10 +11,6 @@ One is written in pure Go and parses /etc/passwd and /etc/group. The other
|
||||
is cgo-based and relies on the standard C library (libc) routines such as
|
||||
getpwuid_r, getgrnam_r, and getgrouplist.
|
||||
|
||||
For Linux, the pure Go implementation queries the systemd-userdb service first.
|
||||
If the service is not available, it falls back to parsing /etc/passwd and
|
||||
/etc/group.
|
||||
|
||||
When cgo is available, and the required routines are implemented in libc
|
||||
for a particular platform, cgo-based (libc-backed) code is used.
|
||||
This can be overridden by using osusergo build tag, which enforces
|
||||
|
@ -1,22 +0,0 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package user
|
||||
|
||||
// userdbClient queries the io.systemd.UserDatabase VARLINK interface provided by
|
||||
// systemd-userdbd.service(8) on Linux for obtaining full user/group details
|
||||
// even when cgo is not available.
|
||||
// VARLINK protocol: https://varlink.org
|
||||
// Systemd userdb VARLINK interface https://systemd.io/USER_GROUP_API
|
||||
// dir contains multiple varlink service sockets implementing the userdb interface.
|
||||
type userdbClient struct {
|
||||
dir string
|
||||
}
|
||||
|
||||
// IsUsable checks if the client can be used to make queries.
|
||||
func (cl userdbClient) isUsable() bool {
|
||||
return len(cl.dir) != 0
|
||||
}
|
||||
|
||||
var defaultUserdbClient userdbClient
|
@ -1,772 +0,0 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unicode/utf16"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
// Well known multiplexer service.
|
||||
svcMultiplexer = "io.systemd.Multiplexer"
|
||||
|
||||
userdbNamespace = "io.systemd.UserDatabase"
|
||||
|
||||
// io.systemd.UserDatabase VARLINK interface methods.
|
||||
mGetGroupRecord = userdbNamespace + ".GetGroupRecord"
|
||||
mGetUserRecord = userdbNamespace + ".GetUserRecord"
|
||||
mGetMemberships = userdbNamespace + ".GetMemberships"
|
||||
|
||||
// io.systemd.UserDatabase VARLINK interface errors.
|
||||
errNoRecordFound = userdbNamespace + ".NoRecordFound"
|
||||
errServiceNotAvailable = userdbNamespace + ".ServiceNotAvailable"
|
||||
)
|
||||
|
||||
func init() {
|
||||
defaultUserdbClient.dir = "/run/systemd/userdb"
|
||||
}
|
||||
|
||||
// userdbCall represents a VARLINK service call sent to systemd-userdb.
|
||||
// method is the VARLINK method to call.
|
||||
// parameters are the VARLINK parameters to pass.
|
||||
// more indicates if more responses are expected.
|
||||
// fastest indicates if only the fastest response should be returned.
|
||||
type userdbCall struct {
|
||||
method string
|
||||
parameters callParameters
|
||||
more bool
|
||||
fastest bool
|
||||
}
|
||||
|
||||
func (u userdbCall) marshalJSON(service string) ([]byte, error) {
|
||||
params, err := u.parameters.marshalJSON(service)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data bytes.Buffer
|
||||
data.WriteString(`{"method":"`)
|
||||
data.WriteString(u.method)
|
||||
data.WriteString(`","parameters":`)
|
||||
data.Write(params)
|
||||
if u.more {
|
||||
data.WriteString(`,"more":true`)
|
||||
}
|
||||
data.WriteString(`}`)
|
||||
return data.Bytes(), nil
|
||||
}
|
||||
|
||||
type callParameters struct {
|
||||
uid *int64
|
||||
userName string
|
||||
gid *int64
|
||||
groupName string
|
||||
}
|
||||
|
||||
func (c callParameters) marshalJSON(service string) ([]byte, error) {
|
||||
var data bytes.Buffer
|
||||
data.WriteString(`{"service":"`)
|
||||
data.WriteString(service)
|
||||
data.WriteString(`"`)
|
||||
if c.uid != nil {
|
||||
data.WriteString(`,"uid":`)
|
||||
data.WriteString(strconv.FormatInt(*c.uid, 10))
|
||||
}
|
||||
if c.userName != "" {
|
||||
data.WriteString(`,"userName":"`)
|
||||
data.WriteString(c.userName)
|
||||
data.WriteString(`"`)
|
||||
}
|
||||
if c.gid != nil {
|
||||
data.WriteString(`,"gid":`)
|
||||
data.WriteString(strconv.FormatInt(*c.gid, 10))
|
||||
}
|
||||
if c.groupName != "" {
|
||||
data.WriteString(`,"groupName":"`)
|
||||
data.WriteString(c.groupName)
|
||||
data.WriteString(`"`)
|
||||
}
|
||||
data.WriteString(`}`)
|
||||
return data.Bytes(), nil
|
||||
}
|
||||
|
||||
type userdbReply struct {
|
||||
continues bool
|
||||
errorStr string
|
||||
}
|
||||
|
||||
func (u *userdbReply) unmarshalJSON(data []byte) error {
|
||||
var (
|
||||
kContinues = []byte(`"continues"`)
|
||||
kError = []byte(`"error"`)
|
||||
)
|
||||
if i := bytes.Index(data, kContinues); i != -1 {
|
||||
continues, err := parseJSONBoolean(data[i+len(kContinues):])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.continues = continues
|
||||
}
|
||||
if i := bytes.Index(data, kError); i != -1 {
|
||||
errStr, err := parseJSONString(data[i+len(kError):])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.errorStr = errStr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// response is the parsed reply from a method call to systemd-userdb.
|
||||
// data is one or more VARLINK response parameters separated by 0.
|
||||
// handled indicates if the call was handled by systemd-userdb.
|
||||
// err is any error encountered.
|
||||
type response struct {
|
||||
data []byte
|
||||
handled bool
|
||||
err error
|
||||
}
|
||||
|
||||
// querySocket calls the io.systemd.UserDatabase VARLINK interface at sock with request.
|
||||
// Multiple replies can be read by setting more to true in the request.
|
||||
// Reply parameters are accumulated separated by 0, if there are many.
|
||||
// Replies with io.systemd.UserDatabase.NoRecordFound errors are skipped.
|
||||
// Other UserDatabase errors are returned as is.
|
||||
// If the socket does not exist, or if the io.systemd.UserDatabase.ServiceNotAvailable
|
||||
// error is seen in a response, the query is considered unhandled.
|
||||
func querySocket(ctx context.Context, sock string, request []byte) response {
|
||||
sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
|
||||
if err != nil {
|
||||
return response{err: err}
|
||||
}
|
||||
defer syscall.Close(sockFd)
|
||||
if err := syscall.Connect(sockFd, &syscall.SockaddrUnix{Name: sock}); err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return response{err: err}
|
||||
}
|
||||
return response{handled: true, err: err}
|
||||
}
|
||||
|
||||
// Null terminate request.
|
||||
if request[len(request)-1] != 0 {
|
||||
request = append(request, 0)
|
||||
}
|
||||
|
||||
// Write request to socket.
|
||||
written := 0
|
||||
for written < len(request) {
|
||||
if ctx.Err() != nil {
|
||||
return response{handled: true, err: ctx.Err()}
|
||||
}
|
||||
if n, err := syscall.Write(sockFd, request[written:]); err != nil {
|
||||
return response{handled: true, err: err}
|
||||
} else {
|
||||
written += n
|
||||
}
|
||||
}
|
||||
|
||||
// Read response.
|
||||
var resp bytes.Buffer
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return response{handled: true, err: ctx.Err()}
|
||||
}
|
||||
buf := make([]byte, 4096)
|
||||
if n, err := syscall.Read(sockFd, buf); err != nil {
|
||||
return response{handled: true, err: err}
|
||||
} else if n > 0 {
|
||||
resp.Write(buf[:n])
|
||||
if buf[n-1] == 0 {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// EOF
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if resp.Len() == 0 {
|
||||
return response{handled: true}
|
||||
}
|
||||
|
||||
buf := resp.Bytes()
|
||||
// Remove trailing 0.
|
||||
buf = buf[:len(buf)-1]
|
||||
// Split into VARLINK messages.
|
||||
msgs := bytes.Split(buf, []byte{0})
|
||||
|
||||
// Parse VARLINK messages.
|
||||
for _, m := range msgs {
|
||||
var resp userdbReply
|
||||
if err := resp.unmarshalJSON(m); err != nil {
|
||||
return response{handled: true, err: err}
|
||||
}
|
||||
// Handle VARLINK message errors.
|
||||
switch e := resp.errorStr; e {
|
||||
case "":
|
||||
case errNoRecordFound: // Ignore not found error.
|
||||
continue
|
||||
case errServiceNotAvailable:
|
||||
return response{}
|
||||
default:
|
||||
return response{handled: true, err: errors.New(e)}
|
||||
}
|
||||
if !resp.continues {
|
||||
break
|
||||
}
|
||||
}
|
||||
return response{data: buf, handled: true, err: ctx.Err()}
|
||||
}
|
||||
|
||||
// queryMany calls the io.systemd.UserDatabase VARLINK interface on many services at once.
|
||||
// ss is a slice of userdb services to call. Each service must have a socket in cl.dir.
|
||||
// c is sent to all services in ss. If c.fastest is true, only the fastest reply is read.
|
||||
// Otherwise all replies are aggregated. um is called with aggregated reply parameters.
|
||||
// queryMany returns the first error encountered. The first result is false if no userdb
|
||||
// socket is available or if all requests time out.
|
||||
func (cl userdbClient) queryMany(ctx context.Context, ss []string, c *userdbCall, um jsonUnmarshaler) (bool, error) {
|
||||
responseCh := make(chan response, len(ss))
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Query all services in parallel.
|
||||
var workers sync.WaitGroup
|
||||
for _, svc := range ss {
|
||||
data, err := c.marshalJSON(svc)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
// Spawn worker to query service.
|
||||
workers.Add(1)
|
||||
go func(sock string, data []byte) {
|
||||
defer workers.Done()
|
||||
responseCh <- querySocket(ctx, sock, data)
|
||||
}(cl.dir+"/"+svc, data)
|
||||
}
|
||||
|
||||
go func() {
|
||||
// Clean up workers.
|
||||
workers.Wait()
|
||||
close(responseCh)
|
||||
}()
|
||||
|
||||
var result bytes.Buffer
|
||||
var notOk int
|
||||
RecvResponses:
|
||||
for {
|
||||
select {
|
||||
case resp, ok := <-responseCh:
|
||||
if !ok {
|
||||
// Responses channel is closed so stop reading.
|
||||
break RecvResponses
|
||||
}
|
||||
if resp.err != nil {
|
||||
// querySocket only returns unrecoverable errors,
|
||||
// so return the first one received.
|
||||
return true, resp.err
|
||||
}
|
||||
if !resp.handled {
|
||||
notOk++
|
||||
continue
|
||||
}
|
||||
|
||||
first := result.Len() == 0
|
||||
result.Write(resp.data)
|
||||
if first && c.fastest {
|
||||
// Return the fastest response.
|
||||
break RecvResponses
|
||||
}
|
||||
case <-ctx.Done():
|
||||
// If requests time out, userdb is unavailable.
|
||||
return ctx.Err() != context.DeadlineExceeded, nil
|
||||
}
|
||||
}
|
||||
// If all sockets are not ok, userdb is unavailable.
|
||||
if notOk == len(ss) {
|
||||
return false, nil
|
||||
}
|
||||
return true, um.unmarshalJSON(result.Bytes())
|
||||
}
|
||||
|
||||
// services enumerates userdb service sockets in dir.
|
||||
// If ok is false, io.systemd.UserDatabase service does not exist.
|
||||
func (cl userdbClient) services() (s []string, ok bool, err error) {
|
||||
var entries []fs.DirEntry
|
||||
if entries, err = os.ReadDir(cl.dir); err != nil {
|
||||
ok = !os.IsNotExist(err)
|
||||
return
|
||||
}
|
||||
ok = true
|
||||
for _, ent := range entries {
|
||||
s = append(s, ent.Name())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// query looks up users/groups on the io.systemd.UserDatabase VARLINK interface.
|
||||
// If the multiplexer service is available, the call is sent only to it.
|
||||
// Otherwise, the call is sent simultaneously to all UserDatabase services in cl.dir.
|
||||
// The fastest reply is read and parsed. All other requests are cancelled.
|
||||
// If the service is unavailable, the first result is false.
|
||||
// The service is considered unavailable if the requests time-out as well.
|
||||
func (cl userdbClient) query(ctx context.Context, call *userdbCall, um jsonUnmarshaler) (bool, error) {
|
||||
services := []string{svcMultiplexer}
|
||||
if _, err := os.Stat(cl.dir + "/" + svcMultiplexer); err != nil {
|
||||
// No mux service so call all available services.
|
||||
var ok bool
|
||||
if services, ok, err = cl.services(); !ok || err != nil {
|
||||
return ok, err
|
||||
}
|
||||
}
|
||||
call.fastest = true
|
||||
if ok, err := cl.queryMany(ctx, services, call, um); !ok || err != nil {
|
||||
return ok, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type jsonUnmarshaler interface {
|
||||
unmarshalJSON([]byte) error
|
||||
}
|
||||
|
||||
func isSpace(c byte) bool {
|
||||
return c == ' ' || c == '\t' || c == '\r' || c == '\n'
|
||||
}
|
||||
|
||||
// findElementStart returns a slice of r that starts at the next JSON element.
|
||||
// It skips over valid JSON space characters and checks for the colon separator.
|
||||
func findElementStart(r []byte) ([]byte, error) {
|
||||
var idx int
|
||||
var b byte
|
||||
colon := byte(':')
|
||||
var seenColon bool
|
||||
for idx, b = range r {
|
||||
if isSpace(b) {
|
||||
continue
|
||||
}
|
||||
if !seenColon && b == colon {
|
||||
seenColon = true
|
||||
continue
|
||||
}
|
||||
// Spotted colon and b is not a space, so value starts here.
|
||||
if seenColon {
|
||||
break
|
||||
}
|
||||
return nil, errors.New("expected colon, got invalid character: " + string(b))
|
||||
}
|
||||
if !seenColon {
|
||||
return nil, errors.New("expected colon, got end of input")
|
||||
}
|
||||
return r[idx:], nil
|
||||
}
|
||||
|
||||
// parseJSONString reads a JSON string from r.
|
||||
func parseJSONString(r []byte) (string, error) {
|
||||
r, err := findElementStart(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Smallest valid string is `""`.
|
||||
if l := len(r); l < 2 {
|
||||
return "", errors.New("unexpected end of input")
|
||||
} else if l == 2 {
|
||||
if bytes.Equal(r, []byte(`""`)) {
|
||||
return "", nil
|
||||
}
|
||||
return "", errors.New("invalid string")
|
||||
}
|
||||
|
||||
if c := r[0]; c != '"' {
|
||||
return "", errors.New(`expected " got ` + string(c))
|
||||
}
|
||||
// Advance over opening quote.
|
||||
r = r[1:]
|
||||
|
||||
var value strings.Builder
|
||||
var inEsc bool
|
||||
var inUEsc bool
|
||||
var strEnds bool
|
||||
reader := bytes.NewReader(r)
|
||||
for {
|
||||
if value.Len() > 4096 {
|
||||
return "", errors.New("string too large")
|
||||
}
|
||||
|
||||
// Parse unicode escape sequences.
|
||||
if inUEsc {
|
||||
maybeRune := make([]byte, 4)
|
||||
n, err := reader.Read(maybeRune)
|
||||
if err != nil || n != 4 {
|
||||
return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune))
|
||||
}
|
||||
prn, err := strconv.ParseUint(string(maybeRune), 16, 32)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune))
|
||||
}
|
||||
rn := rune(prn)
|
||||
if !utf16.IsSurrogate(rn) {
|
||||
value.WriteRune(rn)
|
||||
inUEsc = false
|
||||
continue
|
||||
}
|
||||
// rn maybe a high surrogate; read the low surrogate.
|
||||
maybeRune = make([]byte, 6)
|
||||
n, err = reader.Read(maybeRune)
|
||||
if err != nil || n != 6 || maybeRune[0] != '\\' || maybeRune[1] != 'u' {
|
||||
// Not a valid UTF-16 surrogate pair.
|
||||
if _, err := reader.Seek(int64(-n), io.SeekCurrent); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Invalid low surrogate; write the replacement character.
|
||||
value.WriteRune(utf8.RuneError)
|
||||
} else {
|
||||
rn1, err := strconv.ParseUint(string(maybeRune[2:]), 16, 32)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid unicode escape sequence %s", string(maybeRune))
|
||||
}
|
||||
// Check if rn and rn1 are valid UTF-16 surrogate pairs.
|
||||
if dec := utf16.DecodeRune(rn, rune(rn1)); dec != utf8.RuneError {
|
||||
n = utf8.EncodeRune(maybeRune, dec)
|
||||
// Write the decoded rune.
|
||||
value.Write(maybeRune[:n])
|
||||
}
|
||||
}
|
||||
inUEsc = false
|
||||
continue
|
||||
}
|
||||
|
||||
if inEsc {
|
||||
b, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
switch b {
|
||||
case 'b':
|
||||
value.WriteByte('\b')
|
||||
case 'f':
|
||||
value.WriteByte('\f')
|
||||
case 'n':
|
||||
value.WriteByte('\n')
|
||||
case 'r':
|
||||
value.WriteByte('\r')
|
||||
case 't':
|
||||
value.WriteByte('\t')
|
||||
case 'u':
|
||||
inUEsc = true
|
||||
case '/':
|
||||
value.WriteByte('/')
|
||||
case '\\':
|
||||
value.WriteByte('\\')
|
||||
case '"':
|
||||
value.WriteByte('"')
|
||||
default:
|
||||
return "", errors.New("unexpected character in escape sequence " + string(b))
|
||||
}
|
||||
inEsc = false
|
||||
continue
|
||||
} else {
|
||||
rn, _, err := reader.ReadRune()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if rn == '\\' {
|
||||
inEsc = true
|
||||
continue
|
||||
}
|
||||
if rn == '"' {
|
||||
// String ends on un-escaped quote.
|
||||
strEnds = true
|
||||
break
|
||||
}
|
||||
value.WriteRune(rn)
|
||||
}
|
||||
}
|
||||
if !strEnds {
|
||||
return "", errors.New("unexpected end of input")
|
||||
}
|
||||
return value.String(), nil
|
||||
}
|
||||
|
||||
// parseJSONInt64 reads a 64 bit integer from r.
|
||||
func parseJSONInt64(r []byte) (int64, error) {
|
||||
r, err := findElementStart(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var num strings.Builder
|
||||
for _, b := range r {
|
||||
// int64 max is 19 digits long.
|
||||
if num.Len() == 20 {
|
||||
return 0, errors.New("number too large")
|
||||
}
|
||||
if strings.ContainsRune("0123456789", rune(b)) {
|
||||
num.WriteByte(b)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
n, err := strconv.ParseInt(num.String(), 10, 64)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// parseJSONBoolean reads a boolean from r.
|
||||
func parseJSONBoolean(r []byte) (bool, error) {
|
||||
r, err := findElementStart(r)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if bytes.HasPrefix(r, []byte("true")) {
|
||||
return true, nil
|
||||
}
|
||||
if bytes.HasPrefix(r, []byte("false")) {
|
||||
return false, nil
|
||||
}
|
||||
return false, errors.New("unable to parse boolean value")
|
||||
}
|
||||
|
||||
type groupRecord struct {
|
||||
groupName string
|
||||
gid int64
|
||||
}
|
||||
|
||||
func (g *groupRecord) unmarshalJSON(data []byte) error {
|
||||
var (
|
||||
kGroupName = []byte(`"groupName"`)
|
||||
kGid = []byte(`"gid"`)
|
||||
)
|
||||
if i := bytes.Index(data, kGroupName); i != -1 {
|
||||
groupname, err := parseJSONString(data[i+len(kGroupName):])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.groupName = groupname
|
||||
}
|
||||
if i := bytes.Index(data, kGid); i != -1 {
|
||||
gid, err := parseJSONInt64(data[i+len(kGid):])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.gid = gid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// queryGroupDb queries the userdb interface for a gid, groupname, or both.
|
||||
func (cl userdbClient) queryGroupDb(ctx context.Context, gid *int64, groupname string) (*Group, bool, error) {
|
||||
group := groupRecord{}
|
||||
request := userdbCall{
|
||||
method: mGetGroupRecord,
|
||||
parameters: callParameters{gid: gid, groupName: groupname},
|
||||
}
|
||||
if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
|
||||
return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err)
|
||||
}
|
||||
return &Group{
|
||||
Name: group.groupName,
|
||||
Gid: strconv.FormatInt(group.gid, 10),
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
type userRecord struct {
|
||||
userName string
|
||||
realName string
|
||||
uid int64
|
||||
gid int64
|
||||
homeDirectory string
|
||||
}
|
||||
|
||||
func (u *userRecord) unmarshalJSON(data []byte) error {
|
||||
var (
|
||||
kUserName = []byte(`"userName"`)
|
||||
kRealName = []byte(`"realName"`)
|
||||
kUid = []byte(`"uid"`)
|
||||
kGid = []byte(`"gid"`)
|
||||
kHomeDirectory = []byte(`"homeDirectory"`)
|
||||
)
|
||||
if i := bytes.Index(data, kUserName); i != -1 {
|
||||
username, err := parseJSONString(data[i+len(kUserName):])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.userName = username
|
||||
}
|
||||
if i := bytes.Index(data, kRealName); i != -1 {
|
||||
realname, err := parseJSONString(data[i+len(kRealName):])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.realName = realname
|
||||
}
|
||||
if i := bytes.Index(data, kUid); i != -1 {
|
||||
uid, err := parseJSONInt64(data[i+len(kUid):])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.uid = uid
|
||||
}
|
||||
if i := bytes.Index(data, kGid); i != -1 {
|
||||
gid, err := parseJSONInt64(data[i+len(kGid):])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.gid = gid
|
||||
}
|
||||
if i := bytes.Index(data, kHomeDirectory); i != -1 {
|
||||
homedir, err := parseJSONString(data[i+len(kHomeDirectory):])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.homeDirectory = homedir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// queryUserDb queries the userdb interface for a uid, username, or both.
|
||||
func (cl userdbClient) queryUserDb(ctx context.Context, uid *int64, username string) (*User, bool, error) {
|
||||
user := userRecord{}
|
||||
request := userdbCall{
|
||||
method: mGetUserRecord,
|
||||
parameters: callParameters{
|
||||
uid: uid,
|
||||
userName: username,
|
||||
},
|
||||
}
|
||||
if ok, err := cl.query(ctx, &request, &user); !ok || err != nil {
|
||||
return nil, ok, fmt.Errorf("error querying systemd-userdb user record: %s", err)
|
||||
}
|
||||
return &User{
|
||||
Uid: strconv.FormatInt(user.uid, 10),
|
||||
Gid: strconv.FormatInt(user.gid, 10),
|
||||
Username: user.userName,
|
||||
Name: user.realName,
|
||||
HomeDir: user.homeDirectory,
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func (cl userdbClient) lookupGroup(ctx context.Context, groupname string) (*Group, bool, error) {
|
||||
return cl.queryGroupDb(ctx, nil, groupname)
|
||||
}
|
||||
|
||||
func (cl userdbClient) lookupGroupId(ctx context.Context, id string) (*Group, bool, error) {
|
||||
gid, err := strconv.ParseInt(id, 10, 64)
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return cl.queryGroupDb(ctx, &gid, "")
|
||||
}
|
||||
|
||||
func (cl userdbClient) lookupUser(ctx context.Context, username string) (*User, bool, error) {
|
||||
return cl.queryUserDb(ctx, nil, username)
|
||||
}
|
||||
|
||||
func (cl userdbClient) lookupUserId(ctx context.Context, id string) (*User, bool, error) {
|
||||
uid, err := strconv.ParseInt(id, 10, 64)
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return cl.queryUserDb(ctx, &uid, "")
|
||||
}
|
||||
|
||||
type memberships struct {
|
||||
// Keys are groupNames and values are sets of userNames.
|
||||
groupUsers map[string]map[string]struct{}
|
||||
}
|
||||
|
||||
// unmarshalJSON expects many (userName, groupName) records separated by a null byte.
|
||||
// This is used to build a membership map.
|
||||
func (m *memberships) unmarshalJSON(data []byte) error {
|
||||
if m.groupUsers == nil {
|
||||
m.groupUsers = make(map[string]map[string]struct{})
|
||||
}
|
||||
var (
|
||||
kUserName = []byte(`"userName"`)
|
||||
kGroupName = []byte(`"groupName"`)
|
||||
)
|
||||
// Split records by null terminator.
|
||||
records := bytes.Split(data, []byte{byte(0)})
|
||||
for _, rec := range records {
|
||||
if len(rec) == 0 {
|
||||
continue
|
||||
}
|
||||
var groupName string
|
||||
var userName string
|
||||
var err error
|
||||
if i := bytes.Index(rec, kGroupName); i != -1 {
|
||||
if groupName, err = parseJSONString(rec[i+len(kGroupName):]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if i := bytes.Index(rec, kUserName); i != -1 {
|
||||
if userName, err = parseJSONString(rec[i+len(kUserName):]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Associate userName with groupName.
|
||||
if groupName != "" && userName != "" {
|
||||
if _, ok := m.groupUsers[groupName]; ok {
|
||||
m.groupUsers[groupName][userName] = struct{}{}
|
||||
} else {
|
||||
m.groupUsers[groupName] = map[string]struct{}{userName: {}}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cl userdbClient) lookupGroupIds(ctx context.Context, username string) ([]string, bool, error) {
|
||||
services, ok, err := cl.services()
|
||||
if !ok || err != nil {
|
||||
return nil, ok, err
|
||||
}
|
||||
// Fetch group memberships for username.
|
||||
var ms memberships
|
||||
request := userdbCall{
|
||||
method: mGetMemberships,
|
||||
parameters: callParameters{userName: username},
|
||||
more: true,
|
||||
}
|
||||
if ok, err := cl.queryMany(ctx, services, &request, &ms); !ok || err != nil {
|
||||
return nil, ok, fmt.Errorf("error querying systemd-userdb memberships record: %s", err)
|
||||
}
|
||||
// Fetch user group gid.
|
||||
var group groupRecord
|
||||
request = userdbCall{
|
||||
method: mGetGroupRecord,
|
||||
parameters: callParameters{groupName: username},
|
||||
}
|
||||
if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
|
||||
return nil, ok, err
|
||||
}
|
||||
gids := []string{strconv.FormatInt(group.gid, 10)}
|
||||
|
||||
// Fetch group records for each group.
|
||||
for g := range ms.groupUsers {
|
||||
var group groupRecord
|
||||
request.parameters.groupName = g
|
||||
// Query group for gid.
|
||||
if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
|
||||
return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err)
|
||||
}
|
||||
gids = append(gids, strconv.FormatInt(group.gid, 10))
|
||||
}
|
||||
return gids, true, nil
|
||||
}
|
@ -1,504 +0,0 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func TestQueryNoUserdb(t *testing.T) {
|
||||
cl := &userdbClient{dir: "/non/existent"}
|
||||
if _, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib"); ok {
|
||||
t.Fatalf("should fail but lookup has been handled or error is nil: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type userdbTestData map[string]udbResponse
|
||||
|
||||
type udbResponse struct {
|
||||
data []byte
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func userdbServer(t *testing.T, sockFn string, data userdbTestData) {
|
||||
ready := make(chan struct{})
|
||||
go func() {
|
||||
if err := serveUserdb(ready, sockFn, data); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
<-ready
|
||||
}
|
||||
|
||||
func (u userdbTestData) String() string {
|
||||
var s strings.Builder
|
||||
for k, v := range u {
|
||||
s.WriteString("Request:\n")
|
||||
s.WriteString(k)
|
||||
s.WriteString("\nResponse:\n")
|
||||
if v.delay > 0 {
|
||||
s.WriteString("Delay: ")
|
||||
s.WriteString(v.delay.String())
|
||||
s.WriteString("\n")
|
||||
}
|
||||
s.WriteString("Data:\n")
|
||||
s.Write(v.data)
|
||||
s.WriteString("\n")
|
||||
}
|
||||
return s.String()
|
||||
}
|
||||
|
||||
// serverUserdb is a simple userdb server that replies to VARLINK method calls.
|
||||
// A message is sent on the ready channel when the server is ready to accept calls.
|
||||
// The server will reply to each request in the data map. If a request is not
|
||||
// found in the map, the server will return an error.
|
||||
func serveUserdb(ready chan<- struct{}, sockFn string, data userdbTestData) error {
|
||||
sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(sockFd)
|
||||
if err := syscall.Bind(sockFd, &syscall.SockaddrUnix{Name: sockFn}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := syscall.Listen(sockFd, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send ready signal.
|
||||
ready <- struct{}{}
|
||||
|
||||
var srvGroup sync.WaitGroup
|
||||
|
||||
srvErrs := make(chan error, len(data))
|
||||
for len(data) != 0 {
|
||||
nfd, _, err := syscall.Accept(sockFd)
|
||||
if err != nil {
|
||||
syscall.Close(nfd)
|
||||
return err
|
||||
}
|
||||
|
||||
// Read request.
|
||||
buf := make([]byte, 4096)
|
||||
n, err := syscall.Read(nfd, buf)
|
||||
if err != nil {
|
||||
syscall.Close(nfd)
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
// Client went away.
|
||||
continue
|
||||
}
|
||||
if buf[n-1] != 0 {
|
||||
syscall.Close(nfd)
|
||||
return errors.New("request not null terminated")
|
||||
}
|
||||
// Remove null terminator.
|
||||
buf = buf[:n-1]
|
||||
got := string(buf)
|
||||
|
||||
// Fetch response for request.
|
||||
response, ok := data[got]
|
||||
if !ok {
|
||||
syscall.Close(nfd)
|
||||
msg := "unexpected request:\n" + got + "\n\ndata:\n" + data.String()
|
||||
return errors.New(msg)
|
||||
}
|
||||
delete(data, got)
|
||||
|
||||
srvGroup.Add(1)
|
||||
go func() {
|
||||
defer srvGroup.Done()
|
||||
if err := serveClient(nfd, response); err != nil {
|
||||
srvErrs <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
srvGroup.Wait()
|
||||
// Combine serve errors if any.
|
||||
if len(srvErrs) > 0 {
|
||||
var errs []error
|
||||
for err := range srvErrs {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func serveClient(fd int, response udbResponse) error {
|
||||
defer syscall.Close(fd)
|
||||
time.Sleep(response.delay)
|
||||
data := response.data
|
||||
if len(data) != 0 && data[len(data)-1] != 0 {
|
||||
data = append(data, 0)
|
||||
}
|
||||
written := 0
|
||||
for written < len(data) {
|
||||
if n, err := syscall.Write(fd, data[written:]); err != nil {
|
||||
return err
|
||||
} else {
|
||||
written += n
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSlowUserdbLookup(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
data := userdbTestData{
|
||||
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
|
||||
delay: time.Hour,
|
||||
},
|
||||
}
|
||||
userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
|
||||
cl := &userdbClient{dir: tmpdir}
|
||||
// Lookup should timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
|
||||
defer cancel()
|
||||
if _, ok, _ := cl.lookupGroup(ctx, "stdlibcontrib"); ok {
|
||||
t.Fatalf("lookup should not be handled but was")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFastestUserdbLookup(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
fastData := userdbTestData{
|
||||
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"fast","groupName":"stdlibcontrib"}}`: udbResponse{
|
||||
data: []byte(
|
||||
`{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
|
||||
),
|
||||
},
|
||||
}
|
||||
slowData := userdbTestData{
|
||||
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"slow","groupName":"stdlibcontrib"}}`: udbResponse{
|
||||
delay: 50 * time.Millisecond,
|
||||
data: []byte(
|
||||
`{"parameters":{"record":{"groupName":"stdlibcontrib","gid":182,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
|
||||
),
|
||||
},
|
||||
}
|
||||
userdbServer(t, tmpdir+"/"+"fast", fastData)
|
||||
userdbServer(t, tmpdir+"/"+"slow", slowData)
|
||||
cl := &userdbClient{dir: tmpdir}
|
||||
group, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib")
|
||||
if !ok {
|
||||
t.Fatalf("lookup should be handled but was not")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("lookup should not fail but did: %v", err)
|
||||
}
|
||||
if group.Gid != "181" {
|
||||
t.Fatalf("lookup should return group 181 but returned %s", group.Gid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserdbLookupGroup(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
data := userdbTestData{
|
||||
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
|
||||
data: []byte(
|
||||
`{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
|
||||
),
|
||||
},
|
||||
}
|
||||
userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
|
||||
|
||||
groupname := "stdlibcontrib"
|
||||
want := &Group{
|
||||
Name: "stdlibcontrib",
|
||||
Gid: "181",
|
||||
}
|
||||
cl := &userdbClient{dir: tmpdir}
|
||||
got, ok, err := cl.lookupGroup(context.Background(), groupname)
|
||||
if !ok {
|
||||
t.Fatal("lookup should have been handled")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("lookupGroup(%s) = %v, want %v", groupname, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserdbLookupUser(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
data := userdbTestData{
|
||||
`{"method":"io.systemd.UserDatabase.GetUserRecord","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"}}`: udbResponse{
|
||||
data: []byte(
|
||||
`{"parameters":{"record":{"userName":"stdlibcontrib","uid":181,"gid":181,"realName":"Stdlib Contrib","homeDirectory":"/home/stdlibcontrib","status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
|
||||
),
|
||||
},
|
||||
}
|
||||
userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
|
||||
|
||||
username := "stdlibcontrib"
|
||||
want := &User{
|
||||
Uid: "181",
|
||||
Gid: "181",
|
||||
Username: "stdlibcontrib",
|
||||
Name: "Stdlib Contrib",
|
||||
HomeDir: "/home/stdlibcontrib",
|
||||
}
|
||||
cl := &userdbClient{dir: tmpdir}
|
||||
got, ok, err := cl.lookupUser(context.Background(), username)
|
||||
if !ok {
|
||||
t.Fatal("lookup should have been handled")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("lookupUser(%s) = %v, want %v", username, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserdbLookupGroupIds(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
data := userdbTestData{
|
||||
`{"method":"io.systemd.UserDatabase.GetMemberships","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"},"more":true}`: udbResponse{
|
||||
data: []byte(
|
||||
`{"parameters":{"userName":"stdlibcontrib","groupName":"stdlib"},"continues":true}` + "\x00" + `{"parameters":{"userName":"stdlibcontrib","groupName":"contrib"}}`,
|
||||
),
|
||||
},
|
||||
// group records
|
||||
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
|
||||
data: []byte(
|
||||
`{"parameters":{"record":{"groupName":"stdlibcontrib","members":["stdlibcontrib"],"gid":181,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
|
||||
),
|
||||
},
|
||||
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlib"}}`: udbResponse{
|
||||
data: []byte(
|
||||
`{"parameters":{"record":{"groupName":"stdlib","members":["stdlibcontrib"],"gid":182,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
|
||||
),
|
||||
},
|
||||
`{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"contrib"}}`: udbResponse{
|
||||
data: []byte(
|
||||
`{"parameters":{"record":{"groupName":"contrib","members":["stdlibcontrib"],"gid":183,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
|
||||
),
|
||||
},
|
||||
}
|
||||
userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
|
||||
|
||||
username := "stdlibcontrib"
|
||||
want := []string{"181", "182", "183"}
|
||||
cl := &userdbClient{dir: tmpdir}
|
||||
got, ok, err := cl.lookupGroupIds(context.Background(), username)
|
||||
if !ok {
|
||||
t.Fatal("lookup should have been handled")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Result order is not specified so sort it.
|
||||
sort.Strings(got)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("lookupGroupIds(%s) = %v, want %v", username, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
var findElementStartTestCases = []struct {
|
||||
in []byte
|
||||
want []byte
|
||||
err bool
|
||||
}{
|
||||
{in: []byte(`:`), want: []byte(``)},
|
||||
{in: []byte(`: `), want: []byte(``)},
|
||||
{in: []byte(`:"foo"`), want: []byte(`"foo"`)},
|
||||
{in: []byte(` :"foo"`), want: []byte(`"foo"`)},
|
||||
{in: []byte(` 1231 :"foo"`), err: true},
|
||||
{in: []byte(``), err: true},
|
||||
{in: []byte(`"foo"`), err: true},
|
||||
{in: []byte(`foo`), err: true},
|
||||
}
|
||||
|
||||
func TestFindElementStart(t *testing.T) {
|
||||
for i, tc := range findElementStartTestCases {
|
||||
t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
|
||||
got, err := findElementStart(tc.in)
|
||||
if tc.err && err == nil {
|
||||
t.Errorf("want err for findElementStart(%s), got nil", tc.in)
|
||||
}
|
||||
if !tc.err {
|
||||
if err != nil {
|
||||
t.Errorf("findElementStart(%s) unexpected error: %s", tc.in, err.Error())
|
||||
}
|
||||
if !bytes.Contains(tc.in, got) {
|
||||
t.Errorf("%s should contain %s but does not", tc.in, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzFindElementStart(f *testing.F) {
|
||||
for _, tc := range findElementStartTestCases {
|
||||
if !tc.err {
|
||||
f.Add(tc.in)
|
||||
}
|
||||
}
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
if out, err := findElementStart(b); err == nil && !bytes.Contains(b, out) {
|
||||
t.Errorf("%s, %v", out, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var parseJSONStringTestCases = []struct {
|
||||
in []byte
|
||||
want string
|
||||
err bool
|
||||
}{
|
||||
{in: []byte(`:""`)},
|
||||
{in: []byte(`:"\n"`), want: "\n"},
|
||||
{in: []byte(`: "\""`), want: "\""},
|
||||
{in: []byte(`:"\t \\"`), want: "\t \\"},
|
||||
{in: []byte(`:"\\\\"`), want: `\\`},
|
||||
{in: []byte(`::`), err: true},
|
||||
{in: []byte(`""`), err: true},
|
||||
{in: []byte(`"`), err: true},
|
||||
{in: []byte(":\"0\xE5"), err: true},
|
||||
{in: []byte{':', '"', 0xFE, 0xFE, 0xFF, 0xFF, '"'}, want: "\uFFFD\uFFFD\uFFFD\uFFFD"},
|
||||
{in: []byte(`:"\u0061a"`), want: "aa"},
|
||||
{in: []byte(`:"\u0159\u0170"`), want: "řŰ"},
|
||||
{in: []byte(`:"\uD800\uDC00"`), want: "\U00010000"},
|
||||
{in: []byte(`:"\uD800"`), want: "\uFFFD"},
|
||||
{in: []byte(`:"\u000"`), err: true},
|
||||
{in: []byte(`:"\u00MF"`), err: true},
|
||||
{in: []byte(`:"\uD800\uDC0"`), err: true},
|
||||
}
|
||||
|
||||
func TestParseJSONString(t *testing.T) {
|
||||
for i, tc := range parseJSONStringTestCases {
|
||||
t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
|
||||
got, err := parseJSONString(tc.in)
|
||||
if tc.err && err == nil {
|
||||
t.Errorf("want err for parseJSONString(%s), got nil", tc.in)
|
||||
}
|
||||
if !tc.err {
|
||||
if err != nil {
|
||||
t.Errorf("parseJSONString(%s) unexpected error: %s", tc.in, err.Error())
|
||||
}
|
||||
if tc.want != got {
|
||||
t.Errorf("parseJSONString(%s) = %s, want %s", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzParseJSONString(f *testing.F) {
|
||||
for _, tc := range parseJSONStringTestCases {
|
||||
f.Add(tc.in)
|
||||
}
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
if out, err := parseJSONString(b); err == nil && !utf8.ValidString(out) {
|
||||
t.Errorf("parseJSONString(%s) = %s, invalid string", b, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var parseJSONInt64TestCases = []struct {
|
||||
in []byte
|
||||
want int64
|
||||
err bool
|
||||
}{
|
||||
{in: []byte(":1235"), want: 1235},
|
||||
{in: []byte(": 123"), want: 123},
|
||||
{in: []byte(":0")},
|
||||
{in: []byte(":5012313123131231"), want: 5012313123131231},
|
||||
{in: []byte("1231"), err: true},
|
||||
}
|
||||
|
||||
func TestParseJSONInt64(t *testing.T) {
|
||||
for i, tc := range parseJSONInt64TestCases {
|
||||
t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
|
||||
got, err := parseJSONInt64(tc.in)
|
||||
if tc.err && err == nil {
|
||||
t.Errorf("want err for parseJSONInt64(%s), got nil", tc.in)
|
||||
}
|
||||
if !tc.err {
|
||||
if err != nil {
|
||||
t.Errorf("parseJSONInt64(%s) unexpected error: %s", tc.in, err.Error())
|
||||
}
|
||||
if tc.want != got {
|
||||
t.Errorf("parseJSONInt64(%s) = %d, want %d", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzParseJSONInt64(f *testing.F) {
|
||||
for _, tc := range parseJSONInt64TestCases {
|
||||
f.Add(tc.in)
|
||||
}
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
if out, err := parseJSONInt64(b); err == nil &&
|
||||
!bytes.Contains(b, []byte(strconv.FormatInt(out, 10))) {
|
||||
t.Errorf("parseJSONInt64(%s) = %d, %v", b, out, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var parseJSONBooleanTestCases = []struct {
|
||||
in []byte
|
||||
want bool
|
||||
err bool
|
||||
}{
|
||||
{in: []byte(": true "), want: true},
|
||||
{in: []byte(":true "), want: true},
|
||||
{in: []byte(": false "), want: false},
|
||||
{in: []byte(":false "), want: false},
|
||||
{in: []byte("true"), err: true},
|
||||
{in: []byte("false"), err: true},
|
||||
{in: []byte("foo"), err: true},
|
||||
}
|
||||
|
||||
func TestParseJSONBoolean(t *testing.T) {
|
||||
for i, tc := range parseJSONBooleanTestCases {
|
||||
t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
|
||||
got, err := parseJSONBoolean(tc.in)
|
||||
if tc.err && err == nil {
|
||||
t.Errorf("want err for parseJSONBoolean(%s), got nil", tc.in)
|
||||
}
|
||||
if !tc.err {
|
||||
if err != nil {
|
||||
t.Errorf("parseJSONBoolean(%s) unexpected error: %s", tc.in, err.Error())
|
||||
}
|
||||
if tc.want != got {
|
||||
t.Errorf("parseJSONBoolean(%s) = %t, want %t", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzParseJSONBoolean(f *testing.F) {
|
||||
for _, tc := range parseJSONBooleanTestCases {
|
||||
f.Add(tc.in)
|
||||
}
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
if out, err := parseJSONBoolean(b); err == nil && !bytes.Contains(b, []byte(strconv.FormatBool(out))) {
|
||||
t.Errorf("parseJSONBoolean(%s) = %t, %v", b, out, err)
|
||||
}
|
||||
})
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !linux
|
||||
|
||||
package user
|
||||
|
||||
import "context"
|
||||
|
||||
func (cl userdbClient) lookupGroup(_ context.Context, _ string) (*Group, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (cl userdbClient) lookupGroupId(_ context.Context, _ string) (*Group, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (cl userdbClient) lookupUser(_ context.Context, _ string) (*User, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (cl userdbClient) lookupUserId(_ context.Context, _ string) (*User, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (cl userdbClient) lookupGroupIds(_ context.Context, _ string) ([]string, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user