[dart2wasm] Convert UTF-8 chunks to U8List before decoding
Currently `convertSingle` converts the input to `U8List`, but
`convertChunked` works on `Uint8List`. This makes the functions shared by
both (`decode8`, `decode16`) polymorphic in their input type.
Update `convertChunked` to also convert the input to `U8List`. With this
`decode8` and `decode16` become monomorphic in the input type. Also
update array accesses in these methods to avoid bounds checks.
Check for a few fast cases in `List<int>` to `U8List` copying. If the
list is a `WasmI8ArrayBase` (used in typed data) or `WasmListBase` (used
in lists), we avoid polymorphism, indirections, and bounds checks during
copying.
Golem reports up to 600% improvement in some chunked parsing micro-
benchmarks.
Change-Id: Iddf6dae1a5d77cf574be77313dff779b4715e283
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/395980
Commit-Queue: Ömer Ağacan <omersa@google.com>
Reviewed-by: Slava Egorov <vegorov@google.com>
diff --git a/sdk/lib/_internal/wasm/lib/convert_patch.dart b/sdk/lib/_internal/wasm/lib/convert_patch.dart
index dff3eae..a94e682 100644
--- a/sdk/lib/_internal/wasm/lib/convert_patch.dart
+++ b/sdk/lib/_internal/wasm/lib/convert_patch.dart
@@ -8,7 +8,8 @@
 import "dart:_js_string_convert";
 import "dart:_js_types";
 import "dart:_js_helper" show jsStringToDartString;
-import "dart:_list" show GrowableList, WasmListBaseUnsafeExtensions;
+import "dart:_list"
+    show GrowableList, WasmListBaseUnsafeExtensions, WasmListBase;
 import "dart:_string";
 import "dart:_typed_data";
 import "dart:_wasm";
@@ -1987,11 +1988,11 @@
     _bomIndex = -1;
   }
 
-  int scan(Uint8List bytes, int start, int end) {
+  int scan(U8List bytes, int start, int end) {
     int size = 0;
     int flags = 0;
     for (int i = start; i < end; i++) {
-      int t = scanTable.readUnsigned(bytes[i]);
+      int t = scanTable.readUnsigned(bytes.getUnchecked(i));
       size += t & sizeMask;
       flags |= t;
     }
@@ -2104,14 +2105,13 @@
   String convertChunked(List<int> codeUnits, int start, int? maybeEnd) {
     int end = RangeError.checkValidRange(start, maybeEnd, codeUnits.length);
 
-    // Have bytes as Uint8List.
-    Uint8List bytes;
+    final U8List bytes;
     int errorOffset;
-    if (codeUnits is Uint8List) {
-      bytes = unsafeCast<Uint8List>(codeUnits);
+    if (codeUnits is U8List) {
+      bytes = unsafeCast<U8List>(codeUnits);
       errorOffset = 0;
     } else {
-      bytes = _makeUint8List(codeUnits, start, end);
+      bytes = _makeU8List(codeUnits, start, end);
       errorOffset = start;
       end -= start;
       start = 0;
@@ -2205,17 +2205,17 @@
     return result;
   }
 
-  int skipBomSingle(Uint8List bytes, int start, int end) {
+  int skipBomSingle(U8List bytes, int start, int end) {
     if (end - start >= 3 &&
-        bytes[start] == 0xEF &&
-        bytes[start + 1] == 0xBB &&
-        bytes[start + 2] == 0xBF) {
+        bytes.getUnchecked(start) == 0xEF &&
+        bytes.getUnchecked(start + 1) == 0xBB &&
+        bytes.getUnchecked(start + 2) == 0xBF) {
       return start + 3;
     }
     return start;
   }
 
-  int skipBomChunked(Uint8List bytes, int start, int end) {
+  int skipBomChunked(U8List bytes, int start, int end) {
     assert(start <= end);
     int bomIndex = _bomIndex;
     // Already skipped?
@@ -2229,7 +2229,7 @@
         _bomIndex = bomIndex;
         return start;
       }
-      if (bytes[i++] != bomValues[bomIndex++]) {
+      if (bytes.getUnchecked(i++) != bomValues[bomIndex++]) {
         // No BOM.
         _bomIndex = -1;
         return start;
@@ -2241,7 +2241,7 @@
     return i;
   }
 
-  String decode8(Uint8List bytes, int start, int end, int size) {
+  String decode8(U8List bytes, int start, int end, int size) {
     assert(start < end);
     OneByteString result = OneByteString.withLength(size);
     int i = start;
@@ -2249,7 +2249,7 @@
     if (_state == X1) {
       // Half-way though 2-byte sequence
       assert(_charOrIndex == 2 || _charOrIndex == 3);
-      final int e = bytes[i++] ^ 0x80;
+      final int e = bytes.getUnchecked(i++) ^ 0x80;
       if (e >= 0x40) {
         _state = errorMissingExtension;
         _charOrIndex = i - 1;
@@ -2260,7 +2260,7 @@
     }
     assert(_state == accept);
     while (i < end) {
-      int byte = bytes[i++];
+      int byte = bytes.getUnchecked(i++);
       if (byte >= 0x80) {
         if (byte < 0xC0) {
           _state = errorUnexpectedExtension;
@@ -2273,7 +2273,7 @@
           _charOrIndex = byte & 0x1F;
           break;
         }
-        final int e = bytes[i++] ^ 0x80;
+        final int e = bytes.getUnchecked(i++) ^ 0x80;
         if (e >= 0x40) {
           _state = errorMissingExtension;
           _charOrIndex = i - 1;
@@ -2293,7 +2293,7 @@
     return result;
   }
 
-  String decode16(Uint8List bytes, int start, int end, int size) {
+  String decode16(U8List bytes, int start, int end, int size) {
     assert(start < end);
     final OneByteString transitionTable = unsafeCast<OneByteString>(
       _Utf8Decoder.transitionTable,
@@ -2309,7 +2309,7 @@
 
     // First byte
     assert(!isErrorState(state));
-    final int byte = bytes[i++];
+    final int byte = bytes.getUnchecked(i++);
     final int type = typeTable.codeUnitAtUnchecked(byte) & typeMask;
     if (state == accept) {
       char = byte & (shiftedByteMask >> type);
@@ -2320,7 +2320,7 @@
     }
 
     while (i < end) {
-      final int byte = bytes[i++];
+      final int byte = bytes.getUnchecked(i++);
       final int type = typeTable.codeUnitAtUnchecked(byte) & typeMask;
       if (state == accept) {
         if (char >= 0x10000) {
@@ -2369,3 +2369,74 @@
     return result;
   }
 }
+
+/// Copies `codeUnits[start, end)` into a fresh [U8List] for decoding.
+///
+/// Values outside 0..0xFF are stored as 0xFF, which is also an invalid byte.
+U8List _makeU8List(List<int> codeUnits, int start, int end) {
+  if (codeUnits is WasmListBase) {
+    final list = unsafeCast<WasmListBase<int>>(codeUnits);
+    return _makeU8ListFromWasmListBase(list, start, end);
+  }
+
+  // Only signed byte lists take the raw-array path: that helper reads each
+  // byte as signed, so 0x80..0xFF in an unsigned list (e.g. U8ClampedList)
+  // would come back negative and be replaced by 0xFF.
+  if (codeUnits is I8List) {
+    final list = unsafeCast<WasmI8ArrayBase>(codeUnits);
+    return _makeU8ListFromWasmI8ArrayBase(list, start, end);
+  }
+
+  final int length = end - start;
+  final U8List bytes = U8List(length);
+  for (int i = 0; i < length; i++) {
+    int b = codeUnits[start + i];
+    if ((b & ~0xFF) != 0) {
+      // Replace invalid byte values by FF, which is also invalid.
+      b = 0xFF;
+    }
+    bytes.setUnchecked(i, b);
+  }
+  return bytes;
+}
+
+/// Copies `codeUnits[start, end)` into a new [U8List].
+///
+/// Elements are read from `codeUnits.data` rather than through the `List`
+/// interface. Values outside 0..0xFF are stored as 0xFF.
+U8List _makeU8ListFromWasmListBase(
+  WasmListBase<int> codeUnits,
+  int start,
+  int end,
+) {
+  final result = U8List(end - start);
+  final WasmArray<Object?> source = codeUnits.data;
+  final WasmArray<WasmI8> target = result.data;
+  for (int src = start, dst = 0; src < end; src++, dst++) {
+    final int value = unsafeCast<int>(source[src]);
+    // Out-of-range values become 0xFF, which is itself an invalid byte.
+    target.write(dst, (value & ~0xFF) == 0 ? value : 0xFF);
+  }
+  return result;
+}
+
+/// Copies `codeUnits[start, end)` into a new [U8List].
+///
+/// Bytes are read from `codeUnits.data` as signed values, so a byte with its
+/// high bit set reads as negative and is stored as 0xFF.
+U8List _makeU8ListFromWasmI8ArrayBase(
+  WasmI8ArrayBase codeUnits,
+  int start,
+  int end,
+) {
+  final result = U8List(end - start);
+  final WasmArray<WasmI8> source = codeUnits.data;
+  final int first = codeUnits.offsetInBytes + start;
+  final WasmArray<WasmI8> target = result.data;
+  for (int dst = 0, count = end - start; dst < count; dst++) {
+    final int value = source.readSigned(first + dst);
+    // Out-of-range values become 0xFF, which is itself an invalid byte.
+    target.write(dst, (value & ~0xFF) == 0 ? value : 0xFF);
+  }
+  return result;
+}
diff --git a/sdk/lib/_internal/wasm/lib/typed_data.dart b/sdk/lib/_internal/wasm/lib/typed_data.dart
index ced38d4..deef310 100644
--- a/sdk/lib/_internal/wasm/lib/typed_data.dart
+++ b/sdk/lib/_internal/wasm/lib/typed_data.dart
@@ -1976,8 +1976,8 @@
       final fromTypedData = unsafeCast<JSIntegerArrayBase>(from);
 
       final fromElementSize = fromTypedData.elementSizeInBytes;
-      if (fromElementSize == 1 && this is _WasmI8ArrayBase) {
-        final destTypedData = unsafeCast<_WasmI8ArrayBase>(this);
+      if (fromElementSize == 1 && this is WasmI8ArrayBase) {
+        final destTypedData = unsafeCast<WasmI8ArrayBase>(this);
         copyToWasmI8Array(
           fromTypedData.toJSArrayExternRef()!,
           skipCount,
@@ -2565,16 +2565,16 @@
 // Fast lists
 //
 
-abstract class _WasmI8ArrayBase extends WasmTypedDataBase {
+abstract class WasmI8ArrayBase extends WasmTypedDataBase {
   final WasmArray<WasmI8> _data;
   final int _offsetInElements;
   final int length;
 
-  _WasmI8ArrayBase(this.length)
+  WasmI8ArrayBase(this.length)
     : _data = WasmArray(_newArrayLengthCheck(length)),
       _offsetInElements = 0;
 
-  _WasmI8ArrayBase._(this._data, this._offsetInElements, this.length);
+  WasmI8ArrayBase._(this._data, this._offsetInElements, this.length);
 
   int get elementSizeInBytes => 1;
 
@@ -2679,7 +2679,7 @@
   _F64ByteBuffer get buffer => _F64ByteBuffer(_data);
 }
 
-extension WasmI8ArrayBaseExt on _WasmI8ArrayBase {
+extension WasmI8ArrayBaseExt on WasmI8ArrayBase {
   @pragma('wasm:prefer-inline')
   WasmArray<WasmI8> get data => _data;
 
@@ -2719,7 +2719,7 @@
   int get offsetInElements => _offsetInElements;
 }
 
-class I8List extends _WasmI8ArrayBase
+class I8List extends WasmI8ArrayBase
     with
         _IntListMixin,
         _TypedIntListMixin<I8List>,
@@ -2761,7 +2761,7 @@
   }
 }
 
-class U8List extends _WasmI8ArrayBase
+class U8List extends WasmI8ArrayBase
     with
         _IntListMixin,
         _TypedIntListMixin<U8List>,
@@ -2769,7 +2769,7 @@
     implements Uint8List {
   U8List(int length) : super(length);
 
-  U8List._(WasmArray<WasmI8> data, int offsetInElements, int length)
+  U8List.withData(WasmArray<WasmI8> data, int offsetInElements, int length)
     : super._(data, offsetInElements, length);
 
   factory U8List._withMutability(
@@ -2779,7 +2779,7 @@
     bool mutable,
   ) =>
       mutable
-          ? U8List._(buffer, offsetInBytes, length)
+          ? U8List.withData(buffer, offsetInBytes, length)
           : UnmodifiableU8List._(buffer, offsetInBytes, length);
 
   @override
@@ -2792,18 +2792,28 @@
   @pragma("wasm:prefer-inline")
   int operator [](int index) {
     indexCheck(index, length);
-    return _data.readUnsigned(_offsetInElements + index);
+    return getUnchecked(index);
   }
 
   @override
   @pragma("wasm:prefer-inline")
   void operator []=(int index, int value) {
     indexCheck(index, length);
+    setUnchecked(index, value);
+  }
+}
+
+extension U8ListUncheckedOperations on U8List {
+  @pragma("wasm:prefer-inline")
+  int getUnchecked(int index) => _data.readUnsigned(_offsetInElements + index);
+
+  @pragma("wasm:prefer-inline")
+  void setUnchecked(int index, int value) {
     _data.write(_offsetInElements + index, value);
   }
 }
 
-class U8ClampedList extends _WasmI8ArrayBase
+class U8ClampedList extends WasmI8ArrayBase
     with
         _IntListMixin,
         _TypedIntListMixin<U8ClampedList>,
@@ -3200,10 +3210,10 @@
 
 class UnmodifiableU8List extends U8List with _UnmodifiableIntListMixin {
   UnmodifiableU8List(U8List list)
-    : super._(list._data, list._offsetInElements, list.length);
+    : super.withData(list._data, list._offsetInElements, list.length);
 
   UnmodifiableU8List._(WasmArray<WasmI8> data, int offsetInElements, int length)
-    : super._(data, offsetInElements, length);
+    : super.withData(data, offsetInElements, length);
 
   @override
   @pragma('wasm:prefer-inline')
diff --git a/tests/lib/convert/negative_utf8_codeunit_test.dart b/tests/lib/convert/negative_utf8_codeunit_test.dart
new file mode 100644
index 0000000..bfa3dff
--- /dev/null
+++ b/tests/lib/convert/negative_utf8_codeunit_test.dart
@@ -0,0 +1,62 @@
+// Copyright (c) 2024, the Dart project authors.  Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+import "dart:convert";
+import "dart:typed_data";
+
+import "package:expect/expect.dart";
+
+void main() {
+  // UTF-8 encoding of "é", and the string it must decode to.
+  final bytes = [195, 169];
+  final decoded = "é";
+  // Same as `bytes` when interpreted as unsigned bytes.
+  final negativeBytes = [-61, -87];
+
+  // Lists whose stored elements are all valid byte values (0..255).
+  final shouldSucceed = [
+    bytes,
+    Uint8List.fromList(bytes),
+    Uint8List.fromList(negativeBytes),
+  ];
+
+  // Lists holding negative elements (Int8List stores 195, 169 as -61, -87).
+  final shouldFail = [
+    negativeBytes,
+    Int8List.fromList(bytes),
+    Int8List.fromList(negativeBytes),
+  ];
+
+  for (var input in shouldSucceed) {
+    Expect.equals(decoded, utf8.decoder.convert(input));
+
+    final stringSink = StringSink();
+    utf8.decoder.startChunkedConversion(stringSink)
+      ..add(input)
+      ..close();
+    Expect.equals(decoded, stringSink.buffer.toString());
+  }
+
+  for (var input in shouldFail) {
+    Expect.throwsFormatException(() => utf8.decoder.convert(input));
+
+    final stringSink = StringSink();
+    Expect.throwsFormatException(
+        () => utf8.decoder.startChunkedConversion(stringSink)
+          ..add(input)
+          ..close());
+  }
+}
+
+/// Collects every string chunk it receives into [buffer].
+class StringSink implements Sink<String> {
+  StringBuffer buffer = StringBuffer();
+
+  @override
+  void add(String str) => buffer.write(str);
+
+  /// Does nothing; the collected text remains readable from [buffer].
+  @override
+  void close() {}
+}