1
0
mirror of https://github.com/golang/go synced 2024-11-26 02:17:58 -07:00

[release-branch.go1.23] encoding/gob: cover missed cases when checking ignore depth

This change makes sure that we are properly checking the ignored field
recursion depth in decIgnoreOpFor consistently. This prevents stack
exhaustion when attempting to decode a message that contains an
extremely deeply nested struct which is ignored.

Thanks to Md Sakib Anwar of The Ohio State University (anwar.40@osu.edu)
for reporting this issue.

Updates #69139
Fixes #69145
Fixes CVE-2024-34156

Change-Id: Iacce06be95a5892b3064f1c40fcba2e2567862d6
Reviewed-on: https://go-internal-review.googlesource.com/c/go/+/1440
Reviewed-by: Russ Cox <rsc@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
(cherry picked from commit 9f2ea73c5f2a7056b7da5d579a485a7216f4b20a)
Reviewed-on: https://go-internal-review.googlesource.com/c/go/+/1581
Commit-Queue: Roland Shoemaker <bracewell@google.com>
Reviewed-by: Tatiana Bradley <tatianabradley@google.com>
Reviewed-on: https://go-review.googlesource.com/c/go/+/611176
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Michael Pratt <mpratt@google.com>
TryBot-Bypass: Dmitri Shuralyov <dmitshur@google.com>
This commit is contained in:
Roland Shoemaker 2024-05-03 09:21:39 -04:00 committed by Gopher Robot
parent 53487e5477
commit fa8ff1a46d
3 changed files with 27 additions and 8 deletions

View File

@ -911,8 +911,11 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
var maxIgnoreNestingDepth = 10000 var maxIgnoreNestingDepth = 10000
// decIgnoreOpFor returns the decoding op for a field that has no destination. // decIgnoreOpFor returns the decoding op for a field that has no destination.
func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp, depth int) *decOp { func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp) *decOp {
if depth > maxIgnoreNestingDepth { // Track how deep we've recursed trying to skip nested ignored fields.
dec.ignoreDepth++
defer func() { dec.ignoreDepth-- }()
if dec.ignoreDepth > maxIgnoreNestingDepth {
error_(errors.New("invalid nesting depth")) error_(errors.New("invalid nesting depth"))
} }
// If this type is already in progress, it's a recursive type (e.g. map[string]*T). // If this type is already in progress, it's a recursive type (e.g. map[string]*T).
@ -938,7 +941,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp,
errorf("bad data: undefined type %s", wireId.string()) errorf("bad data: undefined type %s", wireId.string())
case wire.ArrayT != nil: case wire.ArrayT != nil:
elemId := wire.ArrayT.Elem elemId := wire.ArrayT.Elem
elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1) elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) { op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreArray(state, *elemOp, wire.ArrayT.Len) state.dec.ignoreArray(state, *elemOp, wire.ArrayT.Len)
} }
@ -946,15 +949,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp,
case wire.MapT != nil: case wire.MapT != nil:
keyId := dec.wireType[wireId].MapT.Key keyId := dec.wireType[wireId].MapT.Key
elemId := dec.wireType[wireId].MapT.Elem elemId := dec.wireType[wireId].MapT.Elem
keyOp := dec.decIgnoreOpFor(keyId, inProgress, depth+1) keyOp := dec.decIgnoreOpFor(keyId, inProgress)
elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1) elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) { op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreMap(state, *keyOp, *elemOp) state.dec.ignoreMap(state, *keyOp, *elemOp)
} }
case wire.SliceT != nil: case wire.SliceT != nil:
elemId := wire.SliceT.Elem elemId := wire.SliceT.Elem
elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1) elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) { op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreSlice(state, *elemOp) state.dec.ignoreSlice(state, *elemOp)
} }
@ -1115,7 +1118,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *de
func (dec *Decoder) compileIgnoreSingle(remoteId typeId) *decEngine { func (dec *Decoder) compileIgnoreSingle(remoteId typeId) *decEngine {
engine := new(decEngine) engine := new(decEngine)
engine.instr = make([]decInstr, 1) // one item engine.instr = make([]decInstr, 1) // one item
op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp), 0) op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp))
ovfl := overflow(dec.typeString(remoteId)) ovfl := overflow(dec.typeString(remoteId))
engine.instr[0] = decInstr{*op, 0, nil, ovfl} engine.instr[0] = decInstr{*op, 0, nil, ovfl}
engine.numInstr = 1 engine.numInstr = 1
@ -1160,7 +1163,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn
localField, present := srt.FieldByName(wireField.Name) localField, present := srt.FieldByName(wireField.Name)
// TODO(r): anonymous names // TODO(r): anonymous names
if !present || !isExported(wireField.Name) { if !present || !isExported(wireField.Name) {
op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp), 0) op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp))
engine.instr[fieldnum] = decInstr{*op, fieldnum, nil, ovfl} engine.instr[fieldnum] = decInstr{*op, fieldnum, nil, ovfl}
continue continue
} }

View File

@ -35,6 +35,8 @@ type Decoder struct {
freeList *decoderState // list of free decoderStates; avoids reallocation freeList *decoderState // list of free decoderStates; avoids reallocation
countBuf []byte // used for decoding integers while parsing messages countBuf []byte // used for decoding integers while parsing messages
err error err error
// ignoreDepth tracks the depth of recursively parsed ignored fields
ignoreDepth int
} }
// NewDecoder returns a new decoder that reads from the [io.Reader]. // NewDecoder returns a new decoder that reads from the [io.Reader].

View File

@ -806,6 +806,8 @@ func TestIgnoreDepthLimit(t *testing.T) {
defer func() { maxIgnoreNestingDepth = oldNestingDepth }() defer func() { maxIgnoreNestingDepth = oldNestingDepth }()
b := new(bytes.Buffer) b := new(bytes.Buffer)
enc := NewEncoder(b) enc := NewEncoder(b)
// Nested slice
typ := reflect.TypeFor[int]() typ := reflect.TypeFor[int]()
nested := reflect.ArrayOf(1, typ) nested := reflect.ArrayOf(1, typ)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@ -819,4 +821,16 @@ func TestIgnoreDepthLimit(t *testing.T) {
if err := dec.Decode(&output); err == nil || err.Error() != expectedErr { if err := dec.Decode(&output); err == nil || err.Error() != expectedErr {
t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err) t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err)
} }
// Nested struct
nested = reflect.StructOf([]reflect.StructField{{Name: "F", Type: typ}})
for i := 0; i < 100; i++ {
nested = reflect.StructOf([]reflect.StructField{{Name: "F", Type: nested}})
}
badStruct = reflect.New(reflect.StructOf([]reflect.StructField{{Name: "F", Type: nested}}))
enc.Encode(badStruct.Interface())
dec = NewDecoder(b)
if err := dec.Decode(&output); err == nil || err.Error() != expectedErr {
t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err)
}
} }