Check key and value validity in PbMap (#1076)

Currently `PbMap` does not range-check the keys and values added to it.

As a result, serializing and deserializing a message with a map field can
produce a different map: `PbMap` accepts out-of-range entries (e.g. an integer
outside the 32-bit range as an `sfixed32` value), but serialization truncates
them.

Update `PbMap` to use the same validation functions as `PbList` when adding
elements. Both keys and values are checked.

Fixes #1065.
diff --git a/protobuf/CHANGELOG.md b/protobuf/CHANGELOG.md
index 64dc4c3..219a13c 100644
--- a/protobuf/CHANGELOG.md
+++ b/protobuf/CHANGELOG.md
@@ -13,6 +13,9 @@
   be replaced by `Map`.
 
   For immutable lists and maps, you can use `built_value`. ([#1072])
+
+* Map fields now check key and value validity when adding elements. ([#1065],
+  [#1076])
 
 [text format]: https://protobuf.dev/reference/protobuf/textformat-spec/
 [#1080]: https://github.com/google/protobuf.dart/pull/1080
@@ -20,6 +23,8 @@
 [wkts]: https://protobuf.dev/reference/protobuf/google.protobuf
 [#1081]: https://github.com/google/protobuf.dart/pull/1081
 [#1072]: https://github.com/google/protobuf.dart/pull/1072
+[#1065]: https://github.com/google/protobuf.dart/issues/1065
+[#1076]: https://github.com/google/protobuf.dart/pull/1076
 
 ## 5.1.0
 
diff --git a/protobuf/lib/src/protobuf/pb_list.dart b/protobuf/lib/src/protobuf/pb_list.dart
index 137ad55..032b487 100644
--- a/protobuf/lib/src/protobuf/pb_list.dart
+++ b/protobuf/lib/src/protobuf/pb_list.dart
@@ -8,11 +8,6 @@
 import 'internal.dart';
 import 'utils.dart';
 
-/// Type of a function that checks items added to a `PbList`.
-///
-/// Throws [ArgumentError] or [RangeError] when the item is not valid.
-typedef CheckFunc<E> = void Function(E? x);
-
 @pragma('dart2js:tryInline')
 @pragma('vm:prefer-inline')
 @pragma('wasm:prefer-inline')
@@ -50,7 +45,7 @@
 
   PbList._unmodifiable()
     : _wrappedList = _emptyList,
-      _check = checkNotNull,
+      _check = null, // No validator: this list is read-only (_isReadOnly).
       _isReadOnly = true;
 
   @override
diff --git a/protobuf/lib/src/protobuf/pb_map.dart b/protobuf/lib/src/protobuf/pb_map.dart
index 0200f4a..eb55a79 100644
--- a/protobuf/lib/src/protobuf/pb_map.dart
+++ b/protobuf/lib/src/protobuf/pb_map.dart
@@ -14,7 +14,12 @@
 @pragma('vm:prefer-inline')
 @pragma('wasm:prefer-inline')
 PbMap<K, V> newPbMap<K, V>(int keyFieldType, int valueFieldType) =>
-    PbMap<K, V>._(keyFieldType, valueFieldType);
+    PbMap<K, V>._(
+      keyFieldType,
+      valueFieldType,
+      getCheckFunction(keyFieldType),
+      getCheckFunction(valueFieldType),
+    );
 
 @pragma('dart2js:tryInline')
 @pragma('vm:prefer-inline')
@@ -46,11 +51,24 @@
 
   bool _isReadOnly = false;
 
-  PbMap._(this.keyFieldType, this.valueFieldType) : _wrappedMap = <K, V>{};
+  /// Checks keys passed to `[]=`; `null` for `_unmodifiable` maps.
+  final CheckFunc<K>? _checkKey;
+
+  /// Checks values passed to `[]=`; `null` for `_unmodifiable` maps.
+  final CheckFunc<V>? _checkValue;
+
+  PbMap._(
+    this.keyFieldType,
+    this.valueFieldType,
+    this._checkKey,
+    this._checkValue,
+  ) : _wrappedMap = <K, V>{};

   PbMap._unmodifiable(this.keyFieldType, this.valueFieldType)
     : _wrappedMap = <K, V>{},
-      _isReadOnly = true;
+      _isReadOnly = true,
+      _checkKey = null,
+      _checkValue = null;
 
   @override
   V? operator [](Object? key) => _wrappedMap[key];
@@ -60,13 +78,15 @@
     if (_isReadOnly) {
       throw UnsupportedError('Attempted to change a read-only map field');
     }
-    ArgumentError.checkNotNull(key, 'key');
-    ArgumentError.checkNotNull(value, 'value');
+    // Out-of-range keys/values (e.g. a >32-bit int for sfixed32) must be
+    // rejected here; serialization would otherwise silently truncate them.
+    _checkKey?.call(key);
+    _checkValue?.call(value);
     _wrappedMap[key] = value;
   }

-  /// A [PbMap] is equal to another [PbMap] with equal key/value
-  /// pairs in any order.
+  /// A [PbMap] is equal to another [PbMap] with equal key/value pairs in any
+  /// order.
   @override
   bool operator ==(Object other) {
     if (identical(other, this)) {
@@ -86,8 +106,8 @@
     return true;
   }
 
-  /// A [PbMap] is equal to another [PbMap] with equal key/value
-  /// pairs in any order. Then, the `hashCode` is guaranteed to be the same.
+  /// A [PbMap] is equal to another [PbMap] with equal key/value pairs in any
+  /// order. Then, the `hashCode` is guaranteed to be the same.
   @override
   int get hashCode {
     return _wrappedMap.entries.fold(
@@ -126,7 +146,12 @@
   }

   PbMap<K, V> _deepCopy() {
-    final newMap = PbMap<K, V>._(keyFieldType, valueFieldType);
+    final newMap = PbMap<K, V>._(
+      keyFieldType,
+      valueFieldType,
+      _checkKey,
+      _checkValue,
+    );
     final wrappedMap = _wrappedMap;
     final newWrappedMap = newMap._wrappedMap;
     if (PbFieldType.isGroupOrMessage(valueFieldType)) {
diff --git a/protobuf/lib/src/protobuf/utils.dart b/protobuf/lib/src/protobuf/utils.dart
index e52440d..6247e86 100644
--- a/protobuf/lib/src/protobuf/utils.dart
+++ b/protobuf/lib/src/protobuf/utils.dart
@@ -7,6 +7,11 @@
 import 'internal.dart';
 import 'json_parsing_context.dart';
 
+/// Type of the validators for `PbList` elements and `PbMap` keys and values.
+///
+/// Throws [ArgumentError] or [RangeError] when the item is not valid.
+typedef CheckFunc<E> = void Function(E? x);
+
 // TODO(antonm): reconsider later if PbList should take care of equality.
 bool deepEquals(Object? lhs, Object? rhs) {
   // Some GeneratedMessages implement Map, so test this first.
diff --git a/protoc_plugin/test/map_field_test.dart b/protoc_plugin/test/map_field_test.dart
index 7667bca..3514098 100644
--- a/protoc_plugin/test/map_field_test.dart
+++ b/protoc_plugin/test/map_field_test.dart
@@ -10,6 +10,11 @@
 import 'gen/map_field.pb.dart';
 
 void main() {
+  // Look up tag numbers by name instead of hard-coding them.
+  final fieldsByName = TestMap().info_.byName;
+  final int32ToEnumFieldTag = fieldsByName['int32ToEnumField']!.tagNumber;
+  final int32ToMessageFieldTag = fieldsByName['int32ToMessageField']!.tagNumber;
+
   void setValues(TestMap testMap) {
     testMap
       ..int32ToInt32Field[1] = 11
@@ -408,7 +413,9 @@
     // that we handle 0 length fields. (#719)
     {
       final messageBytes = <int>[
-        (5 << 3) | 2, // tag = 5, wire type = 2 (length delimited)
+        ...varint32Bytes(
+          (int32ToMessageFieldTag << 3) | 2, // wire type = 2 (length delimited)
+        ),
         0, // length = 0
       ];
       final message = TestMap.fromBuffer(messageBytes);
@@ -420,7 +427,9 @@
 
     {
       final messageBytes = <int>[
-        (4 << 3) | 2, // tag = 4, wire type = 2 (length delimited)
+        ...varint32Bytes(
+          (int32ToEnumFieldTag << 3) | 2, // wire type = 2 (length delimited)
+        ),
         0, // length = 0
       ];
       final message = TestMap.fromBuffer(messageBytes);
@@ -435,7 +444,9 @@
     // Similar to the case above, but the field just has key (no value)
     {
       final messageBytes = <int>[
-        (5 << 3) | 2, // tag = 5, wire type = 2 (length delimited)
+        ...varint32Bytes(
+          (int32ToMessageFieldTag << 3) | 2, // wire type = 2 (length delimited)
+        ),
         2, // length = 2
         (1 << 3) | 0, // tag = 1 (map key), wire type = 0 (varint)
         1, // key = 1
@@ -449,7 +460,9 @@
 
     {
       final messageBytes = <int>[
-        (4 << 3) | 2, // tag = 4, wire type = 2 (length delimited)
+        ...varint32Bytes(
+          (int32ToEnumFieldTag << 3) | 2, // wire type = 2 (length delimited)
+        ),
         2, // length = 2
         (1 << 3) | 0, // tag = 1 (map key), wire type = 0 (varint)
         1, // key = 1
@@ -466,7 +479,9 @@
     // Similar to the case above, but the field just has value (no key)
     {
       final messageBytes = <int>[
-        (5 << 3) | 2, // tag = 5, wire type = 2 (length delimited)
+        ...varint32Bytes(
+          (int32ToMessageFieldTag << 3) | 2, // wire type = 2 (length delimited)
+        ),
         2, // length = 2
         (2 << 3) | 2, // tag = 2 (map value), wire type = 2 (length delimited)
         0, // length = 0 (empty message)
@@ -480,7 +495,9 @@
 
     {
       final messageBytes = <int>[
-        (4 << 3) | 2, // tag = 4, wire type = 2 (length delimited)
+        ...varint32Bytes(
+          (int32ToEnumFieldTag << 3) | 2, // wire type = 2 (length delimited)
+        ),
         2, // length = 2
         (2 << 3) | 2, // tag = 2 (map value), wire type = 2 (length delimited)
         1, // enum value = 1
@@ -497,3 +514,18 @@
     }, throwsA(const TypeMatcher<UnsupportedError>()));
   });
 }
+
+/// Encodes [value] as a protobuf base-128 varint: 7 bits per byte, least
+/// significant group first, high bit set on every byte except the last.
+///
+/// Only values in the range 0 to 2^32 - 1 are supported.
+List<int> varint32Bytes(int value) {
+  RangeError.checkValueInInterval(value, 0, 0xFFFFFFFF, 'value');
+  final output = <int>[];
+  while (value >= 0x80) {
+    output.add(0x80 | (value & 0x7f));
+    value >>= 7;
+  }
+  output.add(value);
+  return output;
+}
diff --git a/protoc_plugin/test/protos/map_field.proto b/protoc_plugin/test/protos/map_field.proto
index 8236376..eb921eb 100644
--- a/protoc_plugin/test/protos/map_field.proto
+++ b/protoc_plugin/test/protos/map_field.proto
@@ -20,14 +20,34 @@
     }
 
     map<int32, int32>        int32_to_int32_field = 1;
-    map<int32, string>       int32_to_string_field = 2;
-    map<int32, bytes>        int32_to_bytes_field = 3;
-    map<int32, EnumValue>    int32_to_enum_field = 4;
-    map<int32, MessageValue> int32_to_message_field = 5;
-    map<string, int32>       string_to_int32_field = 6;
-    map<uint32, int32>       uint32_to_int32_field = 7;
-    map<int64, int32>        int64_to_int32_field = 8;
-    map<uint64, int32>       uint64_to_int32_field = 9;
+    map<int32, int64>        int32_to_int64_field = 2;
+    map<int32, uint32>       int32_to_uint32_field = 3;
+    map<int32, uint64>       int32_to_uint64_field = 4;
+    map<int32, sint32>       int32_to_sint32_field = 5;
+    map<int32, sint64>       int32_to_sint64_field = 6;
+    map<int32, fixed32>      int32_to_fixed32_field = 7;
+    map<int32, fixed64>      int32_to_fixed64_field = 8;
+    map<int32, sfixed32>     int32_to_sfixed32_field = 9;
+    map<int32, sfixed64>     int32_to_sfixed64_field = 10;
+    map<int32, float>        int32_to_float_field = 11;
+    map<int32, double>       int32_to_double_field = 12;
+    map<int32, bool>         int32_to_bool_field = 13;
+    map<int32, string>       int32_to_string_field = 14;
+    map<int32, bytes>        int32_to_bytes_field = 15;
+    map<int32, EnumValue>    int32_to_enum_field = 16;
+    map<int32, MessageValue> int32_to_message_field = 17;
+
+    map<int64, int32>        int64_to_int32_field = 18;
+    map<uint32, int32>       uint32_to_int32_field = 19;
+    map<uint64, int32>       uint64_to_int32_field = 20;
+    map<sint32, int32>       sint32_to_int32_field = 21;
+    map<sint64, int32>       sint64_to_int32_field = 22;
+    map<fixed32, int32>      fixed32_to_int32_field = 23;
+    map<fixed64, int32>      fixed64_to_int32_field = 24;
+    map<sfixed32, int32>     sfixed32_to_int32_field = 25;
+    map<sfixed64, int32>     sfixed64_to_int32_field = 26;
+    map<bool, int32>         bool_to_int32_field = 27;
+    map<string, int32>       string_to_int32_field = 28;
 }
 
 message Inner  {
@@ -47,6 +67,6 @@
         optional string key = 1;
         optional int32 value = 2;
     }
-    repeated Int32ToString int32_to_string_field = 2;
-    repeated StringToInt32 string_to_int32_field = 6;
+    repeated Int32ToString int32_to_string_field = 14; // Same as TestMap.int32_to_string_field.
+    repeated StringToInt32 string_to_int32_field = 28; // Same as TestMap.string_to_int32_field.
 }
diff --git a/protoc_plugin/test/validate_fail_test.dart b/protoc_plugin/test/validate_fail_test.dart
index ed90e7f..a09d679 100644
--- a/protoc_plugin/test/validate_fail_test.dart
+++ b/protoc_plugin/test/validate_fail_test.dart
@@ -5,6 +5,7 @@
 import 'package:test/test.dart';
 
 import 'gen/google/protobuf/unittest.pb.dart';
+import 'gen/map_field.pb.dart';
 
 const int minI32 = -2147483648;
 
@@ -465,4 +466,158 @@
       }, throwsArgumentError);
     }
   });
+
+  test('Maps validate keys', () {
+    // Null keys are rejected with a TypeError (non-nullable key types).
+    expect(() {
+      (TestMap() as dynamic).int32ToInt32Field[null] = 0;
+    }, throwsTypeError);
+
+    expect(() {
+      (TestMap() as dynamic).boolToInt32Field[null] = 0;
+    }, throwsTypeError);
+
+    // int32 keys: minI32..maxI32 accepted; one past either bound rejected.
+    TestMap().int32ToInt32Field[minI32] = 0;
+    TestMap().int32ToInt32Field[-123] = 0;
+    TestMap().int32ToInt32Field[maxI32] = 0;
+    TestMap().int32ToInt32Field[123] = 0;
+
+    expect(() {
+      TestMap().int32ToInt32Field[minI32 - 1] = 0;
+    }, throwsArgumentError);
+
+    expect(() {
+      TestMap().int32ToInt32Field[maxI32 + 1] = 0;
+    }, throwsArgumentError);
+
+    // sint32 keys: same range as int32.
+    TestMap().sint32ToInt32Field[minI32] = 0;
+    TestMap().sint32ToInt32Field[-123] = 0;
+    TestMap().sint32ToInt32Field[maxI32] = 0;
+    TestMap().sint32ToInt32Field[123] = 0;
+
+    expect(() {
+      TestMap().sint32ToInt32Field[minI32 - 1] = 0;
+    }, throwsArgumentError);
+
+    expect(() {
+      TestMap().sint32ToInt32Field[maxI32 + 1] = 0;
+    }, throwsArgumentError);
+
+    // sfixed32 keys: same range as int32.
+    TestMap().sfixed32ToInt32Field[minI32] = 0;
+    TestMap().sfixed32ToInt32Field[-123] = 0;
+    TestMap().sfixed32ToInt32Field[maxI32] = 0;
+    TestMap().sfixed32ToInt32Field[123] = 0;
+
+    expect(() {
+      TestMap().sfixed32ToInt32Field[minI32 - 1] = 0;
+    }, throwsArgumentError);
+
+    expect(() {
+      TestMap().sfixed32ToInt32Field[maxI32 + 1] = 0;
+    }, throwsArgumentError);
+
+    // uint32 keys: 0..maxU32 accepted; -1 and maxU32 + 1 rejected.
+    TestMap().uint32ToInt32Field[maxU32] = 0;
+    TestMap().uint32ToInt32Field[123] = 0;
+
+    expect(() {
+      TestMap().uint32ToInt32Field[-1] = 0;
+    }, throwsArgumentError);
+
+    expect(() {
+      TestMap().uint32ToInt32Field[maxU32 + 1] = 0;
+    }, throwsArgumentError);
+
+    // fixed32 keys: same range as uint32.
+    TestMap().fixed32ToInt32Field[maxU32] = 0;
+    TestMap().fixed32ToInt32Field[123] = 0;
+
+    expect(() {
+      TestMap().fixed32ToInt32Field[-1] = 0;
+    }, throwsArgumentError);
+
+    expect(() {
+      TestMap().fixed32ToInt32Field[maxU32 + 1] = 0;
+    }, throwsArgumentError);
+  });
+
+  test('Maps validate values', () {
+    // Null values are rejected with a TypeError (non-nullable value types).
+    expect(() {
+      (TestMap() as dynamic).int32ToInt32Field[0] = null;
+    }, throwsTypeError);
+
+    expect(() {
+      (TestMap() as dynamic).int32ToBoolField[0] = null;
+    }, throwsTypeError);
+
+    // int32 values: minI32..maxI32 accepted; one past either bound rejected.
+    TestMap().int32ToInt32Field[0] = minI32;
+    TestMap().int32ToInt32Field[0] = -123;
+    TestMap().int32ToInt32Field[0] = maxI32;
+    TestMap().int32ToInt32Field[0] = 123;
+
+    expect(() {
+      TestMap().int32ToInt32Field[0] = minI32 - 1;
+    }, throwsArgumentError);
+
+    expect(() {
+      TestMap().int32ToInt32Field[0] = maxI32 + 1;
+    }, throwsArgumentError);
+
+    // sint32 values: same range as int32.
+    TestMap().int32ToSint32Field[0] = minI32;
+    TestMap().int32ToSint32Field[0] = -123;
+    TestMap().int32ToSint32Field[0] = maxI32;
+    TestMap().int32ToSint32Field[0] = 123;
+
+    expect(() {
+      TestMap().int32ToSint32Field[0] = minI32 - 1;
+    }, throwsArgumentError);
+
+    expect(() {
+      TestMap().int32ToSint32Field[0] = maxI32 + 1;
+    }, throwsArgumentError);
+
+    // sfixed32 values: same range as int32.
+    TestMap().int32ToSfixed32Field[0] = minI32;
+    TestMap().int32ToSfixed32Field[0] = -123;
+    TestMap().int32ToSfixed32Field[0] = maxI32;
+    TestMap().int32ToSfixed32Field[0] = 123;
+
+    expect(() {
+      TestMap().int32ToSfixed32Field[0] = minI32 - 1;
+    }, throwsArgumentError);
+
+    expect(() {
+      TestMap().int32ToSfixed32Field[0] = maxI32 + 1;
+    }, throwsArgumentError);
+
+    // uint32 values: 0..maxU32 accepted; -1 and maxU32 + 1 rejected.
+    TestMap().int32ToUint32Field[0] = maxU32;
+    TestMap().int32ToUint32Field[0] = 123;
+
+    expect(() {
+      TestMap().int32ToUint32Field[0] = -1;
+    }, throwsArgumentError);
+
+    expect(() {
+      TestMap().int32ToUint32Field[0] = maxU32 + 1;
+    }, throwsArgumentError);
+
+    // fixed32 values: same range as uint32.
+    TestMap().int32ToFixed32Field[0] = maxU32;
+    TestMap().int32ToFixed32Field[0] = 123;
+
+    expect(() {
+      TestMap().int32ToFixed32Field[0] = -1;
+    }, throwsArgumentError);
+
+    expect(() {
+      TestMap().int32ToFixed32Field[0] = maxU32 + 1;
+    }, throwsArgumentError);
+  });
 }