diff --git a/go/fory/array.go b/go/fory/array.go index 9b8b9c17d3..6cb1751f3f 100644 --- a/go/fory/array.go +++ b/go/fory/array.go @@ -365,8 +365,7 @@ func (s byteArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s byteArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() - err := ctx.Err() - length := buf.ReadLength(err) + length := ctx.ReadCollectionLength() if ctx.HasError() { return } diff --git a/go/fory/array_primitive.go b/go/fory/array_primitive.go index 27813060b7..06c76dc782 100644 --- a/go/fory/array_primitive.go +++ b/go/fory/array_primitive.go @@ -66,7 +66,7 @@ func (s boolArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s boolArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() - length := buf.ReadLength(err) + length := ctx.ReadBinaryLength() if ctx.HasError() { return } @@ -131,7 +131,7 @@ func (s int8ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s int8ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() - length := buf.ReadLength(err) + length := ctx.ReadBinaryLength() if ctx.HasError() { return } @@ -197,7 +197,7 @@ func (s int16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTyp func (s int16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() - size := buf.ReadLength(err) + size := ctx.ReadBinaryLength() length := size / 2 if ctx.HasError() { return @@ -269,7 +269,7 @@ func (s int32ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTyp func (s int32ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() - size := buf.ReadLength(err) + size := ctx.ReadBinaryLength() length := size / 4 if ctx.HasError() { return @@ -341,7 +341,7 @@ func (s int64ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTyp func (s int64ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() - size := buf.ReadLength(err) + size := ctx.ReadBinaryLength() length := size / 8 if ctx.HasError() { return @@ -413,7 +413,7 @@ func (s float32ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeT func (s float32ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() - size := buf.ReadLength(err) + size := ctx.ReadBinaryLength() length := size / 4 if ctx.HasError() { return @@ -485,7 +485,7 @@ func (s float64ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeT func (s float64ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() - size := buf.ReadLength(err) + size := ctx.ReadBinaryLength() length := size / 8 if ctx.HasError() { return @@ -556,7 +556,7 @@ func (s uint8ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTyp func (s uint8ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() - length := buf.ReadLength(err) + length := ctx.ReadBinaryLength() if ctx.HasError() { return } @@ -623,7 +623,7 @@ func (s uint16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTy func (s uint16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() - size := buf.ReadLength(err) + size := ctx.ReadBinaryLength() length := size / 2 if ctx.HasError() { return @@ -694,7 +694,7 @@ func (s uint32ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTy func (s uint32ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() - size := buf.ReadLength(err) + size := ctx.ReadBinaryLength() length := size / 4 if ctx.HasError() { return @@ -764,7 +764,7 @@ func (s uint64ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTy func (s uint64ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() - size := buf.ReadLength(err) + size := ctx.ReadBinaryLength() length := size / 8 if ctx.HasError() { return @@ -838,7 +838,7 @@ func (s float16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeT func (s float16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() - size := buf.ReadLength(ctxErr) + size := ctx.ReadBinaryLength() length := size / 2 if ctx.HasError() { return @@ -912,7 +912,7 @@ func (s bfloat16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, write func (s bfloat16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() - size := buf.ReadLength(ctxErr) + size := ctx.ReadBinaryLength() length := size / 2 if ctx.HasError() { return diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index d3713a2433..8e580b053d 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -168,7 +168,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\tisXlang := ctx.TypeResolver().IsXlang()\n") fmt.Fprintf(buf, "\t\tif isXlang {\n") fmt.Fprintf(buf, "\t\t\t// xlang mode: slices are not nullable, read directly without null flag\n") - fmt.Fprintf(buf, "\t\t\tsliceLen := int(buf.ReadVarUint32(err))\n") + fmt.Fprintf(buf, "\t\t\tsliceLen := ctx.ReadCollectionLength()\n") fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -187,7 +187,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\tif nullFlag == -3 {\n") // NullFlag fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") - fmt.Fprintf(buf, "\t\t\t\tsliceLen := int(buf.ReadVarUint32(err))\n") + fmt.Fprintf(buf, "\t\t\t\tsliceLen := ctx.ReadCollectionLength()\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -517,7 +517,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\tisXlang := ctx.TypeResolver().IsXlang()\n") fmt.Fprintf(buf, "\t\tif isXlang {\n") fmt.Fprintf(buf, "\t\t\t// xlang mode: slices are not nullable, read directly without null flag\n") - fmt.Fprintf(buf, "\t\t\tsliceLen := int(buf.ReadVarUint32(err))\n") + fmt.Fprintf(buf, "\t\t\tsliceLen := ctx.ReadCollectionLength()\n") fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String()) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -532,7 +532,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\t\tif nullFlag == -3 {\n") // NullFlag fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") - fmt.Fprintf(buf, "\t\t\t\tsliceLen := int(buf.ReadVarUint32(err))\n") + fmt.Fprintf(buf, "\t\t\t\tsliceLen := ctx.ReadCollectionLength()\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -555,7 +555,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi unwrappedElem := types.Unalias(elemType) if iface, ok := unwrappedElem.(*types.Interface); ok && iface.Empty() { fmt.Fprintf(buf, "%s// Dynamic slice []any handling - no null flag\n", indent) - fmt.Fprintf(buf, "%ssliceLen := int(buf.ReadVarUint32(err))\n", indent) + fmt.Fprintf(buf, "%ssliceLen := ctx.ReadCollectionLength()\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make([]any, 0)\n", indent, fieldAccess) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -573,7 +573,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi } elemIsReferencable := isReferencableType(elemType) - fmt.Fprintf(buf, "%ssliceLen := int(buf.ReadVarUint32(err))\n", indent) + fmt.Fprintf(buf, "%ssliceLen := ctx.ReadCollectionLength()\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s, 0)\n", indent, fieldAccess, sliceType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -703,7 +703,7 @@ func writePrimitiveSliceReadCall(buf *bytes.Buffer, basic *types.Basic, fieldAcc case types.Int8: fmt.Fprintf(buf, "%s%s = fory.ReadInt8Slice(buf, err)\n", indent, fieldAccess) case types.Uint8: - fmt.Fprintf(buf, "%ssizeBytes := buf.ReadLength(err)\n", indent) + fmt.Fprintf(buf, "%ssizeBytes := ctx.ReadBinaryLength()\n", indent) fmt.Fprintf(buf, "%s%s = make([]uint8, sizeBytes)\n", indent, fieldAccess) fmt.Fprintf(buf, "%sif sizeBytes > 0 {\n", indent) fmt.Fprintf(buf, "%s\traw := buf.ReadBinary(sizeBytes, err)\n", indent) @@ -925,7 +925,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\tisXlang := ctx.TypeResolver().IsXlang()\n") fmt.Fprintf(buf, "\t\tif isXlang {\n") fmt.Fprintf(buf, "\t\t\t// xlang mode: maps are not nullable, read directly without null flag\n") - fmt.Fprintf(buf, "\t\t\tmapLen := int(buf.ReadVarUint32(err))\n") + fmt.Fprintf(buf, "\t\t\tmapLen := ctx.ReadCollectionLength()\n") fmt.Fprintf(buf, "\t\t\tif mapLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String()) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -940,7 +940,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\t\tif nullFlag == -3 {\n") // NullFlag fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") - fmt.Fprintf(buf, "\t\t\t\tmapLen := int(buf.ReadVarUint32(err))\n") + fmt.Fprintf(buf, "\t\t\t\tmapLen := ctx.ReadCollectionLength()\n") fmt.Fprintf(buf, "\t\t\t\tif mapLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -972,7 +972,7 @@ func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAcc } indent := "\t\t\t" - fmt.Fprintf(buf, "%smapLen := int(buf.ReadVarUint32(err))\n", indent) + fmt.Fprintf(buf, "%smapLen := ctx.ReadCollectionLength()\n", indent) fmt.Fprintf(buf, "%sif mapLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s)\n", indent, fieldAccess, mapType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) diff --git a/go/fory/errors.go b/go/fory/errors.go index 25f7dd087a..6dc092bf2d 100644 --- a/go/fory/errors.go +++ b/go/fory/errors.go @@ -52,6 +52,10 @@ const ( ErrKindInvalidTag // ErrKindInvalidUTF16String indicates malformed UTF-16 string data ErrKindInvalidUTF16String + // ErrKindMaxCollectionSizeExceeded indicates max collection size exceeded + ErrKindMaxCollectionSizeExceeded + // ErrKindMaxBinarySizeExceeded indicates max binary size exceeded + ErrKindMaxBinarySizeExceeded ) // Error is a lightweight error type optimized for hot path performance. @@ -296,6 +300,26 @@ func InvalidUTF16StringError(byteCount int) Error { }) } +// MaxCollectionSizeExceededError creates a max collection size exceeded error +// +//go:noinline +func MaxCollectionSizeExceededError(size, limit int) Error { + return panicIfEnabled(Error{ + kind: ErrKindMaxCollectionSizeExceeded, + message: fmt.Sprintf("max collection size exceeded: size=%d, limit=%d", size, limit), + }) +} + +// MaxBinarySizeExceededError creates a max binary size exceeded error +// +//go:noinline +func MaxBinarySizeExceededError(size, limit int) Error { + return panicIfEnabled(Error{ + kind: ErrKindMaxBinarySizeExceeded, + message: fmt.Sprintf("max binary size exceeded: size=%d, limit=%d", size, limit), + }) +} + // WrapError wraps a standard error into a fory Error // //go:noinline diff --git a/go/fory/fory.go b/go/fory/fory.go index 342b0acc3a..57da20c576 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -50,18 +50,22 @@ const ( // Config holds configuration options for Fory instances type Config struct { - TrackRef bool - MaxDepth int - IsXlang bool - Compatible bool // Schema evolution compatibility mode + TrackRef bool + MaxDepth int + IsXlang bool + Compatible bool // Schema evolution compatibility mode + MaxCollectionSize int + MaxBinarySize int } // defaultConfig returns the default configuration func defaultConfig() Config { return Config{ - TrackRef: false, // Match Java's default: reference tracking disabled - MaxDepth: 20, - IsXlang: false, + TrackRef: false, // Match Java's default: reference tracking disabled + MaxDepth: 20, + IsXlang: false, + MaxCollectionSize: 1_000_000, + MaxBinarySize: 64 * 1024 * 1024, } } @@ -101,6 +105,20 @@ func WithCompatible(enabled bool) Option { } } +// WithMaxCollectionSize sets the maximum collection size limit +func WithMaxCollectionSize(size int) Option { + return func(f *Fory) { + f.config.MaxCollectionSize = size + } +} + +// WithMaxBinarySize sets the maximum binary size limit +func WithMaxBinarySize(size int) Option { + return func(f *Fory) { + f.config.MaxBinarySize = size + } +} + // ============================================================================ // Fory - Main serialization instance // ============================================================================ @@ -152,6 +170,8 @@ func New(opts ...Option) *Fory { f.writeCtx.xlang = f.config.IsXlang f.readCtx = NewReadContext(f.config.TrackRef) + f.readCtx.maxCollectionSize = f.config.MaxCollectionSize + f.readCtx.maxBinarySize = f.config.MaxBinarySize f.readCtx.typeResolver = f.typeResolver f.readCtx.refResolver = f.refResolver f.readCtx.compatible = f.config.Compatible diff --git a/go/fory/limit_test.go b/go/fory/limit_test.go new file mode 100644 index 0000000000..64d32f7a3d --- /dev/null +++ b/go/fory/limit_test.go @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestMaxCollectionSizeGuardrail(t *testing.T) { + // 1. Test slice exceeding limit + t.Run("Slice exceeds MaxCollectionSize", func(t *testing.T) { + config := WithMaxCollectionSize(2) + f := NewFory(config) + + slice := []string{"a", "b", "c"} + fBase := NewFory() + bytes, _ := fBase.Serialize(slice) + + var decoded []string + err := f.Deserialize(bytes, &decoded) + require.Error(t, err) + require.Contains(t, err.Error(), "max collection size exceeded: size=3, limit=2") + }) + + // 2. Test map exceeding limit + t.Run("Map exceeds MaxCollectionSize", func(t *testing.T) { + config := WithMaxCollectionSize(2) + f := NewFory(config) + + m := map[int32]int32{1: 1, 2: 2, 3: 3} + fBase := NewFory() + bytes, _ := fBase.Serialize(m) + + var decoded map[int32]int32 + err := f.Deserialize(bytes, &decoded) + require.Error(t, err) + require.Contains(t, err.Error(), "max collection size exceeded: size=3, limit=2") + }) + + // 3. Test string is not affected by MaxCollectionSize + t.Run("String unaffected by MaxCollectionSize", func(t *testing.T) { + config := WithMaxCollectionSize(2) + f := NewFory(config) + + str := "hello world" // length 11 + bytes, err := f.Serialize(str) + require.NoError(t, err) + + var decoded string + err = f.Deserialize(bytes, &decoded) + require.NoError(t, err) + require.Equal(t, str, decoded) + }) +} + +func TestMaxBinarySizeGuardrail(t *testing.T) { + // 1. Test binary (byte slice) exceeding limit + t.Run("Byte slice exceeds MaxBinarySize", func(t *testing.T) { + config := WithMaxBinarySize(5) + f := NewFory(config) + + // We can serialize a byte slice using standard serializer, then decode with the f instance + slice := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + fBase := NewFory() + bytes, _ := fBase.Serialize(slice) + + var decoded []byte + err := f.Deserialize(bytes, &decoded) + require.Error(t, err) + require.Contains(t, err.Error(), "max binary size exceeded: size=10, limit=5") + }) + + // 2. Test string is not affected by MaxBinarySize + t.Run("String unaffected by MaxBinarySize", func(t *testing.T) { + config := WithMaxBinarySize(2) + f := NewFory(config) + + str := "hello world" // length 11 + bytes, err := f.Serialize(str) + require.NoError(t, err) + + var decoded string + err = f.Deserialize(bytes, &decoded) + require.NoError(t, err) + require.Equal(t, str, decoded) + }) +} diff --git a/go/fory/map.go b/go/fory/map.go index f2489601f3..ace2c55973 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -305,7 +305,7 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } refResolver.Reference(value) - size := int(buf.ReadVarUint32(ctxErr)) + size := ctx.ReadCollectionLength() if size == 0 || ctx.HasError() { return } diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index 21a4bd7b5d..16e40f5aef 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -69,8 +69,10 @@ func writeMapStringString(buf *ByteBuffer, m map[string]string, hasGenerics bool } // readMapStringString reads map[string]string using chunk protocol -func readMapStringString(buf *ByteBuffer, err *Error) map[string]string { - size := int(buf.ReadVarUint32(err)) +func readMapStringString(ctx *ReadContext) map[string]string { + err := ctx.Err() + buf := ctx.Buffer() + size := ctx.ReadCollectionLength() result := make(map[string]string, size) if size == 0 { return result @@ -172,8 +174,10 @@ func writeMapStringInt64(buf *ByteBuffer, m map[string]int64, hasGenerics bool) } // readMapStringInt64 reads map[string]int64 using chunk protocol -func readMapStringInt64(buf *ByteBuffer, err *Error) map[string]int64 { - size := int(buf.ReadVarUint32(err)) +func readMapStringInt64(ctx *ReadContext) map[string]int64 { + err := ctx.Err() + buf := ctx.Buffer() + size := ctx.ReadCollectionLength() result := make(map[string]int64, size) if size == 0 { return result @@ -246,8 +250,10 @@ func writeMapStringInt32(buf *ByteBuffer, m map[string]int32, hasGenerics bool) } // readMapStringInt32 reads map[string]int32 using chunk protocol -func readMapStringInt32(buf *ByteBuffer, err *Error) map[string]int32 { - size := int(buf.ReadVarUint32(err)) +func readMapStringInt32(ctx *ReadContext) map[string]int32 { + err := ctx.Err() + buf := ctx.Buffer() + size := ctx.ReadCollectionLength() result := make(map[string]int32, size) if size == 0 { return result @@ -320,8 +326,10 @@ func writeMapStringInt(buf *ByteBuffer, m map[string]int, hasGenerics bool) { } // readMapStringInt reads map[string]int using chunk protocol -func readMapStringInt(buf *ByteBuffer, err *Error) map[string]int { - size := int(buf.ReadVarUint32(err)) +func readMapStringInt(ctx *ReadContext) map[string]int { + err := ctx.Err() + buf := ctx.Buffer() + size := ctx.ReadCollectionLength() result := make(map[string]int, size) if size == 0 { return result @@ -394,8 +402,10 @@ func writeMapStringFloat64(buf *ByteBuffer, m map[string]float64, hasGenerics bo } // readMapStringFloat64 reads map[string]float64 using chunk protocol -func readMapStringFloat64(buf *ByteBuffer, err *Error) map[string]float64 { - size := int(buf.ReadVarUint32(err)) +func readMapStringFloat64(ctx *ReadContext) map[string]float64 { + err := ctx.Err() + buf := ctx.Buffer() + size := ctx.ReadCollectionLength() result := make(map[string]float64, size) if size == 0 { return result @@ -468,8 +478,10 @@ func writeMapStringBool(buf *ByteBuffer, m map[string]bool, hasGenerics bool) { } // readMapStringBool reads map[string]bool using chunk protocol -func readMapStringBool(buf *ByteBuffer, err *Error) map[string]bool { - size := int(buf.ReadVarUint32(err)) +func readMapStringBool(ctx *ReadContext) map[string]bool { + err := ctx.Err() + buf := ctx.Buffer() + size := ctx.ReadCollectionLength() result := make(map[string]bool, size) if size == 0 { return result @@ -547,8 +559,10 @@ func writeMapInt32Int32(buf *ByteBuffer, m map[int32]int32, hasGenerics bool) { } // readMapInt32Int32 reads map[int32]int32 using chunk protocol -func readMapInt32Int32(buf *ByteBuffer, err *Error) map[int32]int32 { - size := int(buf.ReadVarUint32(err)) +func readMapInt32Int32(ctx *ReadContext) map[int32]int32 { + err := ctx.Err() + buf := ctx.Buffer() + size := ctx.ReadCollectionLength() result := make(map[int32]int32, size) if size == 0 { return result @@ -621,8 +635,10 @@ func writeMapInt64Int64(buf *ByteBuffer, m map[int64]int64, hasGenerics bool) { } // readMapInt64Int64 reads map[int64]int64 using chunk protocol -func readMapInt64Int64(buf *ByteBuffer, err *Error) map[int64]int64 { - size := int(buf.ReadVarUint32(err)) +func readMapInt64Int64(ctx *ReadContext) map[int64]int64 { + err := ctx.Err() + buf := ctx.Buffer() + size := ctx.ReadCollectionLength() result := make(map[int64]int64, size) if size == 0 { return result @@ -695,8 +711,10 @@ func writeMapIntInt(buf *ByteBuffer, m map[int]int, hasGenerics bool) { } // readMapIntInt reads map[int]int using chunk protocol -func readMapIntInt(buf *ByteBuffer, err *Error) map[int]int { - size := int(buf.ReadVarUint32(err)) +func readMapIntInt(ctx *ReadContext) map[int]int { + err := ctx.Err() + buf := ctx.Buffer() + size := ctx.ReadCollectionLength() result := make(map[int]int, size) if size == 0 { return result @@ -752,7 +770,7 @@ func (s stringStringMapSerializer) ReadData(ctx *ReadContext, value reflect.Valu value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapStringString(ctx.buffer, ctx.Err()) + result := readMapStringString(ctx) value.Set(reflect.ValueOf(result)) } @@ -787,7 +805,7 @@ func (s stringInt64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapStringInt64(ctx.buffer, ctx.Err()) + result := readMapStringInt64(ctx) value.Set(reflect.ValueOf(result)) } @@ -822,7 +840,7 @@ func (s stringIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapStringInt(ctx.buffer, ctx.Err()) + result := readMapStringInt(ctx) value.Set(reflect.ValueOf(result)) } @@ -857,7 +875,7 @@ func (s stringFloat64MapSerializer) ReadData(ctx *ReadContext, value reflect.Val value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapStringFloat64(ctx.buffer, ctx.Err()) + result := readMapStringFloat64(ctx) value.Set(reflect.ValueOf(result)) } @@ -892,7 +910,7 @@ func (s stringBoolMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapStringBool(ctx.buffer, ctx.Err()) + result := readMapStringBool(ctx) value.Set(reflect.ValueOf(result)) } @@ -927,7 +945,7 @@ func (s int32Int32MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapInt32Int32(ctx.buffer, ctx.Err()) + result := readMapInt32Int32(ctx) value.Set(reflect.ValueOf(result)) } @@ -962,7 +980,7 @@ func (s int64Int64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapInt64Int64(ctx.buffer, ctx.Err()) + result := readMapInt64Int64(ctx) value.Set(reflect.ValueOf(result)) } @@ -997,7 +1015,7 @@ func (s intIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapIntInt(ctx.buffer, ctx.Err()) + result := readMapIntInt(ctx) value.Set(reflect.ValueOf(result)) } diff --git a/go/fory/reader.go b/go/fory/reader.go index a0a37d92fb..9c8b049ad2 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -29,20 +29,22 @@ import ( // ReadContext holds all state needed during deserialization. type ReadContext struct { - buffer *ByteBuffer - refReader *RefReader - trackRef bool // Cached flag to avoid indirection - xlang bool // Cross-language serialization mode - compatible bool // Schema evolution compatibility mode - typeResolver *TypeResolver // For complex type deserialization - refResolver *RefResolver // For reference tracking (legacy) - outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization - outOfBandIndex int // Current index into out-of-band buffers - depth int // Current nesting depth for cycle detection - maxDepth int // Maximum allowed nesting depth - err Error // Accumulated error state for deferred checking - lastTypePtr uintptr - lastTypeInfo *TypeInfo + buffer *ByteBuffer + refReader *RefReader + trackRef bool // Cached flag to avoid indirection + xlang bool // Cross-language serialization mode + compatible bool // Schema evolution compatibility mode + typeResolver *TypeResolver // For complex type deserialization + refResolver *RefResolver // For reference tracking (legacy) + outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization + outOfBandIndex int // Current index into out-of-band buffers + depth int // Current nesting depth for cycle detection + maxDepth int // Maximum allowed nesting depth + err Error // Accumulated error state for deferred checking + lastTypePtr uintptr + lastTypeInfo *TypeInfo + maxCollectionSize int // Size guardrail for collection reads + maxBinarySize int // Size guardrail for binary reads } // IsXlang returns whether cross-language serialization mode is enabled @@ -237,10 +239,32 @@ func (c *ReadContext) ReadAndValidateTypeId(expected TypeId) { } } -// ReadLength reads a length value as varint (non-negative values) -func (c *ReadContext) ReadLength() int { +// ReadCollectionLength reads a length value for collections with size guardrails +func (c *ReadContext) ReadCollectionLength() int { err := c.Err() - return int(c.buffer.ReadVarUint32(err)) + length := c.buffer.ReadLength(err) + if c.err.HasError() { + return 0 + } + if length > c.maxCollectionSize { + c.SetError(MaxCollectionSizeExceededError(length, c.maxCollectionSize)) + return 0 + } + return length +} + +// ReadBinaryLength reads a length value for binary data with size guardrails +func (c *ReadContext) ReadBinaryLength() int { + err := c.Err() + length := c.buffer.ReadLength(err) + if c.err.HasError() { + return 0 + } + if length > c.maxBinarySize { + c.SetError(MaxBinarySizeExceededError(length, c.maxBinarySize)) + return 0 + } + return length } // ============================================================================ @@ -434,7 +458,7 @@ func (c *ReadContext) ReadByteSlice(refMode RefMode, readType bool) []byte { if readType { _ = c.buffer.ReadUint8(err) } - size := c.buffer.ReadLength(err) + size := c.ReadBinaryLength() return c.buffer.ReadBinary(size, err) } @@ -463,7 +487,7 @@ func (c *ReadContext) ReadStringStringMap(refMode RefMode, readType bool) map[st if readType { _ = c.buffer.ReadUint8(err) } - return readMapStringString(c.buffer, err) + return readMapStringString(c) } // ReadStringInt64Map reads map[string]int64 with optional ref/type info @@ -477,7 +501,7 @@ func (c *ReadContext) ReadStringInt64Map(refMode RefMode, readType bool) map[str if readType { _ = c.buffer.ReadUint8(err) } - return readMapStringInt64(c.buffer, err) + return readMapStringInt64(c) } // ReadStringInt32Map reads map[string]int32 with optional ref/type info @@ -491,7 +515,7 @@ func (c *ReadContext) ReadStringInt32Map(refMode RefMode, readType bool) map[str if readType { _ = c.buffer.ReadUint8(err) } - return readMapStringInt32(c.buffer, err) + return readMapStringInt32(c) } // ReadStringIntMap reads map[string]int with optional ref/type info @@ -505,7 +529,7 @@ func (c *ReadContext) ReadStringIntMap(refMode RefMode, readType bool) map[strin if readType { _ = c.buffer.ReadUint8(err) } - return readMapStringInt(c.buffer, err) + return readMapStringInt(c) } // ReadStringFloat64Map reads map[string]float64 with optional ref/type info @@ -519,7 +543,7 @@ func (c *ReadContext) ReadStringFloat64Map(refMode RefMode, readType bool) map[s if readType { _ = c.buffer.ReadUint8(err) } - return readMapStringFloat64(c.buffer, err) + return readMapStringFloat64(c) } // ReadStringBoolMap reads map[string]bool with optional ref/type info @@ -533,7 +557,7 @@ func (c *ReadContext) ReadStringBoolMap(refMode RefMode, readType bool) map[stri if readType { _ = c.buffer.ReadUint8(err) } - return readMapStringBool(c.buffer, err) + return readMapStringBool(c) } // ReadInt32Int32Map reads map[int32]int32 with optional ref/type info @@ -547,7 +571,7 @@ func (c *ReadContext) ReadInt32Int32Map(refMode RefMode, readType bool) map[int3 if readType { _ = c.buffer.ReadUint8(err) } - return readMapInt32Int32(c.buffer, err) + return readMapInt32Int32(c) } // ReadInt64Int64Map reads map[int64]int64 with optional ref/type info @@ -561,7 +585,7 @@ func (c *ReadContext) ReadInt64Int64Map(refMode RefMode, readType bool) map[int6 if readType { _ = c.buffer.ReadUint8(err) } - return readMapInt64Int64(c.buffer, err) + return readMapInt64Int64(c) } // ReadIntIntMap reads map[int]int with optional ref/type info @@ -575,7 +599,7 @@ func (c *ReadContext) ReadIntIntMap(refMode RefMode, readType bool) map[int]int if readType { _ = c.buffer.ReadUint8(err) } - return readMapIntInt(c.buffer, err) + return readMapIntInt(c) } // ReadBufferObject reads a buffer object @@ -583,7 +607,7 @@ func (c *ReadContext) ReadBufferObject() *ByteBuffer { err := c.Err() isInBand := c.buffer.ReadBool(err) if isInBand { - size := c.buffer.ReadLength(err) + size := c.ReadBinaryLength() buf := c.buffer.Slice(c.buffer.readerIndex, size) c.buffer.readerIndex += size return buf diff --git a/go/fory/set.go b/go/fory/set.go index 2105b3e9df..83a7b31719 100644 --- a/go/fory/set.go +++ b/go/fory/set.go @@ -295,7 +295,7 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { err := ctx.Err() type_ := value.Type() // ReadData collection length from buffer - length := int(buf.ReadVarUint32(err)) + length := ctx.ReadCollectionLength() if length == 0 { // Initialize empty set if length is 0 value.Set(reflect.MakeMap(type_)) diff --git a/go/fory/skip.go b/go/fory/skip.go index 34005ad74d..64660dfc36 100644 --- a/go/fory/skip.go +++ b/go/fory/skip.go @@ -213,7 +213,7 @@ func readTypeInfoForSkip(ctx *ReadContext, fieldTypeId TypeId) *TypeInfo { // Uses context error state for deferred error checking. func skipCollection(ctx *ReadContext, fieldDef FieldDef) { err := ctx.Err() - length := ctx.buffer.ReadVarUint32(err) + length := uint32(ctx.ReadCollectionLength()) if ctx.HasError() || length == 0 { return } @@ -283,7 +283,7 @@ func skipCollection(ctx *ReadContext, fieldDef FieldDef) { // Uses context error state for deferred error checking. func skipMap(ctx *ReadContext, fieldDef FieldDef) { bufErr := ctx.Err() - length := ctx.buffer.ReadVarUint32(bufErr) + length := uint32(ctx.ReadCollectionLength()) if ctx.HasError() || length == 0 { return } @@ -601,31 +601,31 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo _ = ctx.buffer.ReadBinary(int(size), err) } case BINARY: - length := ctx.buffer.ReadVarUint32(err) + length := uint32(ctx.ReadBinaryLength()) if ctx.HasError() { return } _ = ctx.buffer.ReadBinary(int(length), err) case BOOL_ARRAY, INT8_ARRAY, UINT8_ARRAY: - length := ctx.buffer.ReadLength(err) + length := ctx.ReadBinaryLength() if ctx.HasError() { return } _ = ctx.buffer.ReadBinary(length, err) case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY, BFLOAT16_ARRAY: - length := ctx.buffer.ReadLength(err) + length := ctx.ReadBinaryLength() if ctx.HasError() { return } _ = ctx.buffer.ReadBinary(length*2, err) case INT32_ARRAY, UINT32_ARRAY, FLOAT32_ARRAY: - length := ctx.buffer.ReadLength(err) + length := ctx.ReadBinaryLength() if ctx.HasError() { return } _ = ctx.buffer.ReadBinary(length*4, err) case INT64_ARRAY, UINT64_ARRAY, FLOAT64_ARRAY: - length := ctx.buffer.ReadLength(err) + length := ctx.ReadBinaryLength() if ctx.HasError() { return } diff --git a/go/fory/slice.go b/go/fory/slice.go index bd3a9aa7ee..90b6ff14e9 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -264,7 +264,7 @@ func (s *sliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, ty func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() - length := int(buf.ReadVarUint32(ctxErr)) + length := ctx.ReadCollectionLength() isArrayType := value.Type().Kind() == reflect.Array if length == 0 { diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go index 3393d4b22b..90a7b8cad1 100644 --- a/go/fory/slice_dyn.go +++ b/go/fory/slice_dyn.go @@ -261,7 +261,7 @@ func (s sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType boo func (s sliceDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() - length := int(buf.ReadVarUint32(ctxErr)) + length := ctx.ReadCollectionLength() sliceType := value.Type() value.Set(reflect.MakeSlice(sliceType, length, length)) if length == 0 { diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index c893908987..dddc800697 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -74,7 +74,7 @@ func (s byteSliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, func (s byteSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() - length := buf.ReadLength(ctxErr) + length := ctx.ReadBinaryLength() ptr := (*[]byte)(value.Addr().UnsafePointer()) if length == 0 { *ptr = make([]byte, 0) @@ -642,7 +642,7 @@ func (s stringSliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMod func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() - length := int(buf.ReadVarUint32(ctxErr)) + length := ctx.ReadCollectionLength() ptr := (*[]string)(value.Addr().UnsafePointer()) if length == 0 { *ptr = make([]string, 0) @@ -1071,7 +1071,7 @@ func (s float16SliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMo func (s float16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() - size := buf.ReadLength(ctxErr) + size := ctx.ReadBinaryLength() length := size / 2 if ctx.HasError() { return @@ -1253,7 +1253,7 @@ func WriteStringSlice(buf *ByteBuffer, value []string, hasGenerics bool) { // ReadStringSlice reads []string from buffer using LIST protocol func ReadStringSlice(buf *ByteBuffer, err *Error) []string { - length := int(buf.ReadVarUint32(err)) + length := buf.ReadLength(err) if length == 0 { return make([]string, 0) } @@ -1328,7 +1328,7 @@ func (s bfloat16SliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefM func (s bfloat16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() - size := buf.ReadLength(ctxErr) + size := ctx.ReadBinaryLength() length := size / 2 if ctx.HasError() { return