diff --git a/rust/README.md b/rust/README.md index c951bd68c7..841635d854 100644 --- a/rust/README.md +++ b/rust/README.md @@ -670,7 +670,7 @@ cargo build ```bash # Run all tests -cargo test --features tests +cargo test --workspace # Run specific test cargo test -p tests --test test_complex_struct diff --git a/rust/fory-core/src/meta/type_meta.rs b/rust/fory-core/src/meta/type_meta.rs index d4ab3a956a..1bb5fd1354 100644 --- a/rust/fory-core/src/meta/type_meta.rs +++ b/rust/fory-core/src/meta/type_meta.rs @@ -60,6 +60,7 @@ use std::collections::HashMap; use std::rc::Rc; const SMALL_NUM_FIELDS_THRESHOLD: usize = 0b11111; +const MAX_TYPE_META_FIELDS: usize = i16::MAX as usize; const REGISTER_BY_NAME_FLAG: u8 = 0b100000; const FIELD_NAME_SIZE_THRESHOLD: usize = 0b1111; /// Marker value in encoding bits to indicate field ID mode (instead of field name) @@ -642,6 +643,9 @@ impl TypeMeta { length as usize }; let bytes = reader.read_bytes(length)?; + if encoding_idx as usize >= encodings.len() { + return Err(Error::invalid_data("encoding_index out of bounds")); + } let encoding = encodings[encoding_idx as usize]; decoder.decode(bytes, encoding) } @@ -817,6 +821,13 @@ impl TypeMeta { if num_fields == SMALL_NUM_FIELDS_THRESHOLD { num_fields += reader.read_varuint32()? as usize; } + // limit the number of fields to prevent potential OOM when creating Vec + if num_fields > MAX_TYPE_META_FIELDS { + return Err(Error::invalid_data(format!( + "too many fields in type meta: {}, max: {}", + num_fields, MAX_TYPE_META_FIELDS + ))); + } let mut type_id; let mut user_type_id = NO_USER_TYPE_ID; let namespace; diff --git a/rust/fory-core/src/row/bit_util.rs b/rust/fory-core/src/row/bit_util.rs index bbb94e5f79..faa41a73b6 100644 --- a/rust/fory-core/src/row/bit_util.rs +++ b/rust/fory-core/src/row/bit_util.rs @@ -18,5 +18,5 @@ const WORD_SIZE: usize = 8; pub fn calculate_bitmap_width_in_bytes(num_fields: usize) -> usize { - ((num_fields + 63) / 64) * WORD_SIZE + (num_fields.saturating_add(63) / 64).saturating_mul(WORD_SIZE) } diff --git a/rust/fory-core/src/row/reader.rs b/rust/fory-core/src/row/reader.rs index 1549a55ea7..4d1a5815a6 100644 --- a/rust/fory-core/src/row/reader.rs +++ b/rust/fory-core/src/row/reader.rs @@ -24,12 +24,20 @@ struct FieldAccessorHelper<'a> { } impl<'a> FieldAccessorHelper<'a> { - fn get_offset_size(&self, idx: usize) -> (u32, u32) { + fn read_u32(row: &[u8], offset: usize) -> Option { + // in case `row` does not have enough bytes + let end = offset.checked_add(4)?; + let bytes = row.get(offset..end)?; + Some(LittleEndian::read_u32(bytes)) + } + + fn get_offset_size(&self, idx: usize) -> Option<(u32, u32)> { let row = self.row; let field_offset = (self.get_field_offset)(idx); - let offset = LittleEndian::read_u32(&row[field_offset..field_offset + 4]); - let size = LittleEndian::read_u32(&row[field_offset + 4..field_offset + 8]); - (offset, size) + let offset = Self::read_u32(row, field_offset)?; + let size_offset = field_offset.checked_add(4)?; + let size = Self::read_u32(row, size_offset)?; + Some((offset, size)) } pub fn new( @@ -44,8 +52,15 @@ impl<'a> FieldAccessorHelper<'a> { pub fn get_field_bytes(&self, idx: usize) -> &'a [u8] { let row = self.row; - let (offset, size) = self.get_offset_size(idx); - &row[(offset as usize)..(offset + size) as usize] + let Some((offset, size)) = self.get_offset_size(idx) else { + return &[]; + }; + let offset = offset as usize; + let size = size as usize; + let Some(end) = offset.checked_add(size) else { + return &[]; + }; + row.get(offset..end).unwrap_or(&[]) } } @@ -76,7 +91,10 @@ pub struct ArrayViewer<'r> { impl<'r> ArrayViewer<'r> { pub fn new(row: &'r [u8]) -> ArrayViewer<'r> { - let num_elements = LittleEndian::read_u64(&row[0..8]) as usize; + let num_elements = row + .get(0..8) + .map(|bytes| LittleEndian::read_u64(bytes) as usize) + .unwrap_or(0); let bit_map_width_in_bytes = calculate_bitmap_width_in_bytes(num_elements); ArrayViewer { num_elements, @@ -103,10 +121,17 @@ pub struct MapViewer<'r> { impl<'r> MapViewer<'r> { pub fn new(row: &'r [u8]) -> MapViewer<'r> { - let key_byte_size = LittleEndian::read_u64(&row[0..8]) as usize; + let Some(header) = row.get(0..8) else { + return MapViewer { + key_row: &[], + value_row: &[], + }; + }; + let key_byte_size = LittleEndian::read_u64(header) as usize; + let key_end = (8usize).saturating_add(key_byte_size).min(row.len()); MapViewer { - value_row: &row[key_byte_size + 8..row.len()], - key_row: &row[8..key_byte_size + 8], + value_row: &row[key_end..], + key_row: &row[8..key_end], } } diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index 68a6dc6a4d..ae16620c7e 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -32,6 +32,19 @@ pub const DECL_ELEMENT_TYPE: u8 = 0b100; // Whether collection elements type same. pub const IS_SAME_TYPE: u8 = 0b1000; +fn check_collection_len(context: &ReadContext, len: u32) -> Result<(), Error> { + if std::mem::size_of::() == 0 { + return Ok(()); + } + let len = len as usize; + let remaining = context.reader.slice_after_cursor().len(); + if len > remaining { + let cursor = context.reader.get_cursor(); + return Err(Error::buffer_out_of_bound(cursor, len, cursor + remaining)); + } + Ok(()) +} + pub fn write_collection_type_info( context: &mut WriteContext, collection_type_id: u32, @@ -245,6 +258,7 @@ where (header & IS_SAME_TYPE) != 0, Error::type_error("Type inconsistent, target type is not polymorphic") ); + check_collection_len::(context, len)?; if !has_null { (0..len) .map(|_| T::fory_read_data(context)) @@ -284,6 +298,7 @@ where (header & IS_SAME_TYPE) != 0, Error::type_error("Type inconsistent, target type is not polymorphic") ); + check_collection_len::(context, len)?; let mut vec = Vec::with_capacity(len as usize); if !has_null { for _ in 0..len { @@ -320,14 +335,14 @@ where } else { RefMode::None }; - - let mut vec = Vec::with_capacity(len as usize); if is_same_type { let type_info = if !is_declared { context.read_any_type_info()? } else { T::fory_get_type_info(context.get_type_resolver())? }; + check_collection_len::(context, len)?; + let mut vec = Vec::with_capacity(len as usize); if elem_ref_mode == RefMode::None { for _ in 0..len { vec.push(T::fory_read_with_type_info( @@ -345,12 +360,15 @@ where )?); } } + Ok(vec) } else { + check_collection_len::(context, len)?; + let mut vec = Vec::with_capacity(len as usize); for _ in 0..len { vec.push(T::fory_read(context, elem_ref_mode, true)?); } + Ok(vec) } - Ok(vec) } /// Slow but versatile collection deserialization for dynamic trait object and shared/circular reference. @@ -382,6 +400,7 @@ where } else { T::fory_get_type_info(context.get_type_resolver())? }; + check_collection_len::(context, len)?; // All elements are same type if elem_ref_mode == RefMode::None { // No null elements, no tracking @@ -395,6 +414,7 @@ where .collect::>() } } else { + check_collection_len::(context, len)?; (0..len) .map(|_| T::fory_read(context, elem_ref_mode, true)) .collect::>() diff --git a/rust/fory-core/src/serializer/map.rs b/rust/fory-core/src/serializer/map.rs index 4e90a1c3d9..f8c7997c07 100644 --- a/rust/fory-core/src/serializer/map.rs +++ b/rust/fory-core/src/serializer/map.rs @@ -34,6 +34,16 @@ const TRACKING_VALUE_REF: u8 = 0b1000; pub const VALUE_NULL: u8 = 0b10000; pub const DECL_VALUE_TYPE: u8 = 0b100000; +fn check_map_len(context: &ReadContext, len: u32) -> Result<(), Error> { + let len = len as usize; + let remaining = context.reader.slice_after_cursor().len(); + if len > remaining { + let cursor = context.reader.get_cursor(); + return Err(Error::buffer_out_of_bound(cursor, len, cursor + remaining)); + } + Ok(()) +} + fn write_chunk_size(context: &mut WriteContext, header_offset: usize, size: u8) { context.writer.set_bytes(header_offset + 1, &[size]); } @@ -547,10 +557,10 @@ impl Result { let len = context.reader.read_varuint32()?; - let mut map = HashMap::::with_capacity(len as usize); if len == 0 { - return Ok(map); + return Ok(HashMap::new()); } + check_map_len(context, len)?; if K::fory_is_polymorphic() || K::fory_is_shared_ref() || V::fory_is_polymorphic() @@ -559,6 +569,7 @@ impl = HashMap::with_capacity(len as usize); return read_hashmap_data_dyn_ref(context, map, len); } + let mut map = HashMap::::with_capacity(len as usize); let mut len_counter = 0; loop { if len_counter == len { @@ -698,10 +709,11 @@ impl Result { let len = context.reader.read_varuint32()?; - let mut map = BTreeMap::::new(); if len == 0 { - return Ok(map); + return Ok(BTreeMap::new()); } + check_map_len(context, len)?; + let mut map = BTreeMap::::new(); if K::fory_is_polymorphic() || K::fory_is_shared_ref() || V::fory_is_polymorphic() diff --git a/rust/fory-core/src/serializer/primitive_list.rs b/rust/fory-core/src/serializer/primitive_list.rs index 2fc8029e32..cd381a7dbf 100644 --- a/rust/fory-core/src/serializer/primitive_list.rs +++ b/rust/fory-core/src/serializer/primitive_list.rs @@ -83,6 +83,15 @@ pub fn fory_read_data(context: &mut ReadContext) -> Result if size_bytes % std::mem::size_of::() != 0 { return Err(Error::invalid_data("Invalid data length")); } + let remaining = context.reader.slice_after_cursor().len(); + if size_bytes > remaining { + let cursor = context.reader.get_cursor(); + return Err(Error::buffer_out_of_bound( + cursor, + size_bytes, + cursor + remaining, + )); + } let len = size_bytes / std::mem::size_of::(); let mut vec: Vec = Vec::with_capacity(len); diff --git a/rust/fory-core/src/serializer/skip.rs b/rust/fory-core/src/serializer/skip.rs index b96d26e82c..9bfa9e8590 100644 --- a/rust/fory-core/src/serializer/skip.rs +++ b/rust/fory-core/src/serializer/skip.rs @@ -182,6 +182,9 @@ fn skip_collection(context: &mut ReadContext, field_type: &FieldType) -> Result< let is_same_type = (header & IS_SAME_TYPE) != 0; let skip_ref_flag = is_same_type && !has_null; let is_declared = (header & DECL_ELEMENT_TYPE) != 0; + if field_type.generics.is_empty() { + return Err(Error::invalid_data("empty generics")); + } let default_elem_type = field_type.generics.first().unwrap(); let (type_info, elem_field_type); let elem_type = if is_same_type && !is_declared { @@ -213,6 +216,9 @@ fn skip_map(context: &mut ReadContext, field_type: &FieldType) -> Result<(), Err return Ok(()); } let mut len_counter = 0; + if field_type.generics.len() < 2 { + return Err(Error::invalid_data("map must have at least 2 generics")); + } let default_key_type = field_type.generics.first().unwrap(); let default_value_type = field_type.generics.get(1).unwrap(); loop {