Make SplayTreeMap entry values mutable.

No longer used as `MapEntry` directly, so no reason to
not update values in-place and avoid breaking key-based iteration.

Change-Id: I59a89811c783e33b6e4dccd553bc717e03968a9d
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/396801
Reviewed-by: Stephen Adams <sra@google.com>
Reviewed-by: Martin Kustermann <kustermann@google.com>
diff --git a/sdk/lib/_internal/js_runtime/lib/collection_patch.dart b/sdk/lib/_internal/js_runtime/lib/collection_patch.dart
index 0b8c8ea..b7008ac 100644
--- a/sdk/lib/_internal/js_runtime/lib/collection_patch.dart
+++ b/sdk/lib/_internal/js_runtime/lib/collection_patch.dart
@@ -402,7 +402,7 @@
 base class _CustomHashMap<K, V> extends _HashMap<K, V> {
   final _Equality<K> _equals;
   final _Hasher<K> _hashCode;
-  final _Predicate _validKey;
+  final bool Function(Object?) _validKey;
 
   _CustomHashMap(this._equals, this._hashCode, bool validKey(potentialKey)?)
       : _validKey = (validKey != null) ? validKey : ((v) => v is K);
@@ -569,7 +569,7 @@
 base class _LinkedCustomHashMap<K, V> extends JsLinkedHashMap<K, V> {
   final _Equality<K> _equals;
   final _Hasher<K> _hashCode;
-  final _Predicate _validKey;
+  final bool Function(Object?) _validKey;
 
   _LinkedCustomHashMap(
       this._equals, this._hashCode, bool validKey(potentialKey)?)
@@ -955,7 +955,7 @@
 base class _CustomHashSet<E> extends _HashSet<E> {
   _Equality<E> _equality;
   _Hasher<E> _hasher;
-  _Predicate _validKey;
+  bool Function(Object?) _validKey;
   _CustomHashSet(this._equality, this._hasher, bool validKey(potentialKey)?)
       : _validKey = (validKey != null) ? validKey : ((x) => x is E);
 
@@ -1429,7 +1429,7 @@
 base class _LinkedCustomHashSet<E> extends _LinkedHashSet<E> {
   _Equality<E> _equality;
   _Hasher<E> _hasher;
-  _Predicate _validKey;
+  bool Function(Object?) _validKey;
   _LinkedCustomHashSet(
       this._equality, this._hasher, bool validKey(potentialKey)?)
       : _validKey = (validKey != null) ? validKey : ((x) => x is E);
diff --git a/sdk/lib/_internal/vm_shared/lib/collection_patch.dart b/sdk/lib/_internal/vm_shared/lib/collection_patch.dart
index 700f19d..d14297a 100644
--- a/sdk/lib/_internal/vm_shared/lib/collection_patch.dart
+++ b/sdk/lib/_internal/vm_shared/lib/collection_patch.dart
@@ -262,8 +262,8 @@
 base class _CustomHashMap<K, V> extends _HashMap<K, V> {
   final _Equality<K> _equals;
   final _Hasher<K> _hashCode;
-  final _Predicate _validKey;
-  _CustomHashMap(this._equals, this._hashCode, _Predicate? validKey)
+  final bool Function(Object?) _validKey;
+  _CustomHashMap(this._equals, this._hashCode, bool Function(Object?)? validKey)
     : _validKey = (validKey != null) ? validKey : TypeTest<K>().test;
 
   bool containsKey(Object? key) {
@@ -807,8 +807,8 @@
 base class _CustomHashSet<E> extends _HashSet<E> {
   final _Equality<E> _equality;
   final _Hasher<E> _hasher;
-  final _Predicate _validKey;
-  _CustomHashSet(this._equality, this._hasher, _Predicate? validKey)
+  final bool Function(Object?) _validKey;
+  _CustomHashSet(this._equality, this._hasher, bool Function(Object?)? validKey)
     : _validKey = (validKey != null) ? validKey : TypeTest<E>().test;
 
   bool remove(Object? element) {
diff --git a/sdk/lib/collection/splay_tree.dart b/sdk/lib/collection/splay_tree.dart
index 40378cf..65e3029 100644
--- a/sdk/lib/collection/splay_tree.dart
+++ b/sdk/lib/collection/splay_tree.dart
@@ -4,8 +4,6 @@
 
 part of dart.collection;
 
-typedef _Predicate<T> = bool Function(T value);
-
 /// A node in a splay tree. It holds the sorting key and the left
 /// and right children in the tree.
 class _SplayTreeNode<K, Node extends _SplayTreeNode<K, Node>> {
@@ -27,46 +25,44 @@
 /// A [_SplayTreeNode] that also contains a value.
 class _SplayTreeMapNode<K, V>
     extends _SplayTreeNode<K, _SplayTreeMapNode<K, V>> {
-  final V value;
+  V value;
   _SplayTreeMapNode(K key, this.value) : super(key);
-
-  _SplayTreeMapNode<K, V> _replaceValue(V value) =>
-      _SplayTreeMapNode<K, V>(key, value)
-        .._left = _left
-        .._right = _right;
 }
 
 /// A splay tree is a self-balancing binary search tree.
 ///
 /// It has the additional property that recently accessed elements
-/// are quick to access again.
+/// are expected to be quick to access again.
 /// It performs basic operations such as insertion, look-up and
-/// removal, in O(log(n)) amortized time.
+/// removal, in O(log(n)) expected amortized time.
 abstract class _SplayTree<K, Node extends _SplayTreeNode<K, Node>> {
   // The root node of the splay tree. It will contain either the last
   // element inserted or the last element looked up.
-  Node? get _root;
-  set _root(Node? newValue);
+  abstract Node? _root;
 
   // Number of elements in the splay tree.
   int _count = 0;
 
   /// Counter incremented whenever the keys in the map change.
   ///
-  /// Used to detect concurrent modifications.
+  /// Used to detect concurrent modifications while iterating.
   int _modificationCount = 0;
 
-  /// Counter incremented whenever the tree structure changes.
+  /// Counter incremented whenever the tree structure changes, but keys do not.
   ///
   /// Used to detect that an in-place traversal cannot use
   /// cached information that relies on the tree structure.
   int _splayCount = 0;
 
   /// The comparator that is used for this splay tree.
-  Comparator<K> get _compare;
+  abstract final int Function(K, K) _compare;
 
-  /// The predicate to determine that a given object is a valid key.
-  _Predicate get _validKey;
+  /// The predicate to determine whether a given object is a valid key.
+  ///
+  /// Used by operations which accept [Object?].
+  ///
+  /// If [null], the key must just be a [K].
+  abstract final bool Function(Object?)? _validKey;
 
   /// Perform the splay operation for the given key. Moves the node with
   /// the given key to the top of the tree.  If no node has the given
@@ -74,10 +70,10 @@
   /// tree. This is the simplified top-down splaying algorithm from:
   /// "Self-adjusting Binary Search Trees" by Sleator and Tarjan.
   ///
-  /// Returns the result of comparing the new root of the tree to [key].
+  /// Returns the comparison of the key of the new root of the tree to [key].
   /// Returns -1 if the table is empty.
   int _splay(K key) {
-    var root = _root;
+    final root = _root;
     if (root == null) {
       // Ensure key is compatible with `_compare`.
       _compare(key, key);
@@ -85,12 +81,14 @@
     }
 
     // The right and newTreeRight variables start out null, and are set
-    // after the first move left.  The right node is the destination
+    // after the first move left. The right node is the destination
     // for subsequent left rebalances, and newTreeRight holds the left
-    // child of the final tree.  The newTreeRight variable is set at most
+    // child of the final tree. The newTreeRight variable is set at most
     // once, after the first move left, and is null iff right is null.
     // The left and newTreeLeft variables play the corresponding role for
     // right rebalances.
+    final originalModificationCount = _modificationCount;
+    final originalSplayCount = _splayCount;
     Node? right;
     Node? newTreeRight;
     Node? left;
@@ -98,14 +96,24 @@
     var current = root;
     // Hoist the field read out of the loop.
     var compare = _compare;
-    int comp;
+    int comparison;
     while (true) {
-      comp = compare(current.key, key);
-      if (comp > 0) {
+      comparison = compare(current.key, key);
+      // Extra sanity checks which can only fail if `_compare` accesses the map.
+      assert(
+        originalModificationCount == _modificationCount,
+        throw ConcurrentModificationError(this),
+      );
+      assert(
+        originalSplayCount == _splayCount,
+        throw ConcurrentModificationError(this),
+      );
+
+      if (comparison > 0) {
         var currentLeft = current._left;
         if (currentLeft == null) break;
-        comp = compare(currentLeft.key, key);
-        if (comp > 0) {
+        comparison = compare(currentLeft.key, key);
+        if (comparison > 0) {
           // Rotate right.
           current._left = currentLeft._right;
           currentLeft._right = current;
@@ -115,18 +123,18 @@
         }
         // Link right.
         if (right == null) {
-          // First left rebalance, store the eventual right child
+          // First left rebalance, store the eventual right child.
           newTreeRight = current;
         } else {
           right._left = current;
         }
         right = current;
         current = currentLeft;
-      } else if (comp < 0) {
+      } else if (comparison < 0) {
         var currentRight = current._right;
         if (currentRight == null) break;
-        comp = compare(currentRight.key, key);
-        if (comp < 0) {
+        comparison = compare(currentRight.key, key);
+        if (comparison < 0) {
           // Rotate left.
           current._right = currentRight._left;
           currentRight._left = current;
@@ -136,7 +144,7 @@
         }
         // Link left.
         if (left == null) {
-          // First right rebalance, store the eventual left child
+          // First right rebalance, store the eventual left child.
           newTreeLeft = current;
         } else {
           left._right = current;
@@ -147,6 +155,7 @@
         break;
       }
     }
+
     // Assemble.
     if (left != null) {
       left._right = current._left;
@@ -160,7 +169,7 @@
       _root = current;
       _splayCount++;
     }
-    return comp;
+    return comparison;
   }
 
   // Emulates splaying with a key that is smaller than any in the subtree
@@ -169,14 +178,19 @@
   // in any parent tree or root pointer.
   Node _splayMin(Node node) {
     var current = node;
-    var nextLeft = current._left;
-    while (nextLeft != null) {
-      var left = nextLeft;
-      current._left = left._right;
-      left._right = current;
-      current = left;
-      nextLeft = current._left;
+    var modified = 0;
+    while (true) {
+      var left = current._left;
+      if (left != null) {
+        current._left = left._right;
+        left._right = current;
+        current = left;
+        modified = 1;
+      } else {
+        break;
+      }
     }
+    _splayCount += modified;
     return current;
   }
 
@@ -187,79 +201,74 @@
   // in any parent tree or root pointer.
   Node _splayMax(Node node) {
     var current = node;
-    var nextRight = current._right;
-    while (nextRight != null) {
-      var right = nextRight;
-      current._right = right._left;
-      right._left = current;
-      current = right;
-      nextRight = current._right;
+    var modified = 0;
+    while (true) {
+      var right = current._right;
+      if (right != null) {
+        current._right = right._left;
+        right._left = current;
+        current = right;
+        modified = 1;
+      } else {
+        break;
+      }
     }
+    _splayCount += modified;
     return current;
   }
 
-  Node? _remove(K key) {
-    if (_root == null) return null;
-    int comp = _splay(key);
-    if (comp != 0) return null;
-    var root = _root!;
-    var result = root;
-    var left = root._left;
-    _count--;
-    // assert(_count >= 0);
+  /// Removes the root node.
+  void _removeRoot() {
+    assert(_count > 0);
+    final root = _root!;
+    final left = root._left;
+    final right = root._right;
     if (left == null) {
-      _root = root._right;
+      _root = right;
+    } else if (right == null) {
+      _root = left;
     } else {
-      var right = root._right;
-      // Splay to make sure that the new root has an empty right child.
-      root = _splayMax(left);
-
-      // Insert the original right child as the right child of the new
-      // root.
-      root._right = right;
-      _root = root;
+      // Splay to make sure that the new root has an empty right subtree.
+      // Insert the original right subtree as the right subtree of the new root.
+      _root = _splayMax(left).._right = right;
     }
+    _count--;
     _modificationCount++;
-    return result;
   }
 
-  /// Adds a new root node with the given [key] or [value].
+  /// Adds a new root [node] with a key (and value for a map node).
   ///
-  /// The [comp] value is the result of comparing the existing root's key
-  /// with key.
-  void _addNewRoot(Node node, int comp) {
-    _count++;
+  /// The [comparison] value is the result of comparing the existing root's key
+  /// with the new root node's key.
+  void _addNewRoot(Node node, int comparison) {
+    final root = _root;
+    if (root != null) {
+      assert(_count > 0);
+      if (comparison < 0) {
+        node._left = root;
+        node._right = root._right;
+        root._right = null;
+      } else {
+        node._right = root;
+        node._left = root._left;
+        root._left = null;
+      }
+    }
     _modificationCount++;
-    var root = _root;
-    if (root == null) {
-      _root = node;
-      return;
-    }
-    // assert(_count >= 0);
-    if (comp < 0) {
-      node._left = root;
-      node._right = root._right;
-      root._right = null;
-    } else {
-      node._right = root;
-      node._left = root._left;
-      root._left = null;
-    }
+    _count++;
     _root = node;
   }
 
   Node? get _first {
     var root = _root;
-    if (root == null) return null;
-    _root = _splayMin(root);
-    return _root;
+    if (root != null) _root = root = _splayMin(root);
+    return root;
   }
 
   Node? get _last {
-    var root = _root;
+    final root = _root;
     if (root == null) return null;
-    _root = _splayMax(root);
-    return _root;
+    return _root = _splayMax(root);
   }
 
   void _clear() {
@@ -268,18 +277,30 @@
     _modificationCount++;
   }
 
-  bool _containsKey(Object? key) {
-    return _validKey(key) && _splay(key as dynamic) == 0;
+  /// Checks if key is a [_validKey] and then splays with it.
+  ///
+  /// Returns the new root node only if its key is equal to the [key].
+  Node? _untypedLookup(Object? key) {
+    final isValidKey = _validKey;
+    if (isValidKey == null) {
+      if (key is! K) return null;
+    } else {
+      if (!isValidKey(key)) return null;
+      key as K;
+    }
+    if (_splay(key) == 0) return _root;
+    return null;
   }
 }
 
 int _dynamicCompare(dynamic a, dynamic b) => Comparable.compare(a, b);
 
-Comparator<K> _defaultCompare<K>() {
+int Function(K, K) _defaultCompare<K>() {
   // If K <: Comparable, then we can just use Comparable.compare
-  // with no casts.
+  // with no extra casts. (There are will be internal generic downcasts.)
   Object compare = Comparable.compare;
-  if (compare is Comparator<K>) {
+  if (compare is int Function(K, K)) {
+    // Ensures K <: Comparable<Object?>.
     return compare;
   }
   // Otherwise wrap and cast the arguments on each call.
@@ -391,14 +412,14 @@
     with MapMixin<K, V> {
   _SplayTreeMapNode<K, V>? _root;
 
-  Comparator<K> _compare;
-  _Predicate _validKey;
+  int Function(K, K) _compare;
+  bool Function(Object?)? _validKey;
 
   SplayTreeMap([
     int Function(K key1, K key2)? compare,
     bool Function(dynamic potentialKey)? isValidKey,
   ]) : _compare = compare ?? _defaultCompare<K>(),
-       _validKey = isValidKey ?? ((dynamic a) => a is K);
+       _validKey = isValidKey;
 
   /// Creates a [SplayTreeMap] that contains all key/value pairs of [other].
   ///
@@ -411,7 +432,7 @@
   /// print(fromBaseMap); // {1: A, 2: B, 3: C}
   /// ```
   factory SplayTreeMap.from(
-    Map<dynamic, dynamic> other, [
+    Map<Object?, Object?> other, [
     int Function(K key1, K key2)? compare,
     bool Function(dynamic potentialKey)? isValidKey,
   ]) {
@@ -496,86 +517,82 @@
     return map;
   }
 
-  V? operator [](Object? key) {
-    if (!_validKey(key)) return null;
-    if (_root != null) {
-      int comp = _splay(key as dynamic);
-      if (comp == 0) {
-        return _root!.value;
-      }
-    }
-    return null;
-  }
+  V? operator [](Object? key) => _untypedLookup(key)?.value;
 
   V? remove(Object? key) {
-    if (!_validKey(key)) return null;
-    _SplayTreeMapNode<K, V>? mapRoot = _remove(key as dynamic);
-    if (mapRoot != null) return mapRoot.value;
-    return null;
+    final root = _untypedLookup(key);
+    if (root == null) return null;
+    _removeRoot();
+    return root.value;
   }
 
   void operator []=(K key, V value) {
     // Splay on the key to move the last node on the search path for
     // the key to the root of the tree.
-    int comp = _splay(key);
-    if (comp == 0) {
-      _root = _root!._replaceValue(value);
-      // To represent structure change, in case someone caches the old node.
-      _splayCount += 1;
+    int comparison = _splay(key);
+    if (comparison == 0) {
+      _root!.value = value;
       return;
     }
-    _addNewRoot(_SplayTreeMapNode(key, value), comp);
+    _addNewRoot(_SplayTreeMapNode(key, value), comparison);
   }
 
   V putIfAbsent(K key, V ifAbsent()) {
-    int comp = _splay(key);
-    if (comp == 0) {
+    int comparison = _splay(key);
+    if (comparison == 0) {
       return _root!.value;
     }
-    int modificationCount = _modificationCount;
-    int splayCount = _splayCount;
+    int originalModificationCount = _modificationCount;
+    int originalSplayCount = _splayCount;
     V value = ifAbsent();
-    if (modificationCount != _modificationCount || splayCount != _splayCount) {
-      comp = _splay(key);
-      if (comp == 0) {
-        // Key was added.
-        _root = _root!._replaceValue(value);
-        _splayCount += 1; // Tree restructured.
+    if (originalModificationCount != _modificationCount ||
+        originalSplayCount != _splayCount) {
+      comparison = _splay(key);
+      if (comparison == 0) {
+        // Key was added by `ifAbsent`, change value.
+        _root!.value = value;
         return value;
       }
       // Key is still not there.
     }
-    _addNewRoot(_SplayTreeMapNode(key, value), comp);
+    _addNewRoot(_SplayTreeMapNode(key, value), comparison);
     return value;
   }
 
   V update(K key, V update(V value), {V Function()? ifAbsent}) {
-    var comp = _splay(key);
-    if (comp == 0) {
-      var modificationCount = _modificationCount;
-      var splayCount = _splayCount;
+    var comparison = _splay(key);
+    if (comparison == 0) {
+      final originalModificationCount = _modificationCount;
+      final originalSplayCount = _splayCount;
       var newValue = update(_root!.value);
-      if (modificationCount != _modificationCount) {
+      if (originalModificationCount != _modificationCount) {
         throw ConcurrentModificationError(this);
       }
-      if (splayCount != _splayCount) {
-        _splay(key);
+      if (originalSplayCount != _splayCount) {
+        comparison = _splay(key);
+        // Can only fail to find the same key in a tree with the same
+        // modification count if a key has changed its comparison since
+        // it was added to the tree (which means the tree might no be
+        // well-ordered, so much can go wrong).
+        if (comparison != 0) throw ConcurrentModificationError(this);
       }
-      _root = _root!._replaceValue(newValue);
-      _splayCount += 1;
+      _root!.value = newValue;
       return newValue;
     }
     if (ifAbsent != null) {
-      var modificationCount = _modificationCount;
-      var splayCount = _splayCount;
+      final originalModificationCount = _modificationCount;
+      final originalSplayCount = _splayCount;
       var newValue = ifAbsent();
-      if (modificationCount != _modificationCount) {
+      if (originalModificationCount != _modificationCount) {
         throw ConcurrentModificationError(this);
       }
-      if (splayCount != _splayCount) {
-        comp = _splay(key);
+      if (originalSplayCount != _splayCount) {
+        comparison = _splay(key);
+        // Can only happen if a key changed its comparison since being
+        // added to the tree.
+        if (comparison == 0) throw ConcurrentModificationError(this);
       }
-      _addNewRoot(_SplayTreeMapNode(key, newValue), comp);
+      _addNewRoot(_SplayTreeMapNode(key, newValue), comparison);
       return newValue;
     }
     throw ArgumentError.value(key, "key", "Key not in map.");
@@ -620,7 +637,7 @@
     _clear();
   }
 
-  bool containsKey(Object? key) => _containsKey(key);
+  bool containsKey(Object? key) => _untypedLookup(key) != null;
 
   bool containsValue(Object? value) {
     int initialSplayCount = _splayCount;
@@ -653,16 +670,18 @@
   ///
   /// Returns `null` if the map is empty.
   K? firstKey() {
-    if (_root == null) return null;
-    return _first!.key;
+    final root = _root;
+    if (root == null) return null;
+    return (_root = _splayMin(root)).key;
   }
 
   /// The last key in the map.
   ///
   /// Returns `null` if the map is empty.
   K? lastKey() {
-    if (_root == null) return null;
-    return _last!.key;
+    final root = _root;
+    if (root == null) return null;
+    return (_root = _splayMax(root)).key;
   }
 
   /// The last key in the map that is strictly smaller than [key].
@@ -671,8 +690,8 @@
   K? lastKeyBefore(K key) {
     if (key == null) throw ArgumentError(key);
     if (_root == null) return null;
-    int comp = _splay(key);
-    if (comp < 0) return _root!.key;
+    int comparison = _splay(key);
+    if (comparison < 0) return _root!.key;
     _SplayTreeMapNode<K, V>? node = _root!._left;
     if (node == null) return null;
     var nodeRight = node._right;
@@ -688,8 +707,8 @@
   K? firstKeyAfter(K key) {
     if (key == null) throw ArgumentError(key);
     if (_root == null) return null;
-    int comp = _splay(key);
-    if (comp > 0) return _root!.key;
+    int comparison = _splay(key);
+    if (comparison > 0) return _root!.key;
     _SplayTreeMapNode<K, V>? node = _root!._right;
     if (node == null) return null;
     var nodeLeft = node._left;
@@ -741,12 +760,18 @@
   /// This can be caused by a splay operation.
   /// If the key-set changes, iteration is aborted before getting
   /// here, so we know that the keys are the same as before, it's
-  /// only the tree that has been reordered.
+  /// only the tree nodes that has been reordered.
   void _rebuildPath(K key) {
     _path.clear();
-    _tree._splay(key);
-    _path.add(_tree._root!);
-    _splayCount = _tree._splayCount;
+    var comparison = _tree._splay(key);
+    if (comparison == 0) {
+      _path.add(_tree._root!);
+      _splayCount = _tree._splayCount;
+      return;
+    }
+    // Should not be able to happen unless an element changes
+    // its comparison order while in the tree.
+    throw ConcurrentModificationError(this);
   }
 
   void _findLeftMostDescendent(Node? node) {
@@ -801,12 +826,15 @@
   bool get isEmpty => _tree._count == 0;
   Iterator<K> get iterator => _SplayTreeKeyIterator<K, Node>(_tree);
 
-  bool contains(Object? o) => _tree._containsKey(o);
+  bool contains(Object? element) => _tree._untypedLookup(element) != null;
 
   Set<K> toSet() {
     SplayTreeSet<K> set = SplayTreeSet<K>(_tree._compare, _tree._validKey);
-    set._count = _tree._count;
-    set._root = set._copyNode<Node>(_tree._root);
+    var root = _tree._root;
+    if (root != null) {
+      set._root = set._copyNode<Node>(root);
+      set._count = _tree._count;
+    }
     return set;
   }
 }
@@ -840,16 +868,39 @@
 class _SplayTreeValueIterator<K, V>
     extends _SplayTreeIterator<K, _SplayTreeMapNode<K, V>, V> {
   _SplayTreeValueIterator(SplayTreeMap<K, V> map) : super(map);
-  V _getValue(_SplayTreeMapNode<K, V> node) => node.value;
+  // SplayTreeMapNode.value is mutable, so cache it when moveNext returns true.
+  // Cache it eagerly since the type `V` may be nullable, so we can't tell
+  // if `_current` has been assigned yet or not.
+  V? _current;
+  bool moveNext() {
+    var result = super.moveNext();
+    _current = result ? _path.last.value : null;
+    return result;
+  }
+
+  V _getValue(_SplayTreeMapNode<K, V> node) => _current as V;
 }
 
 class _SplayTreeMapEntryIterator<K, V>
     extends _SplayTreeIterator<K, _SplayTreeMapNode<K, V>, MapEntry<K, V>> {
   _SplayTreeMapEntryIterator(SplayTreeMap<K, V> map) : super(map);
+  // `SplayTreeMapNode.value` is mutable, so cache the value the first time
+  // `current` is read. (Avoids doing an allocation if `current` is not read.
+  // Unlike `SplayTreeValueIterator`, the type of [current] is known to be
+  // non-nullable.)
+  MapEntry<K, V>? _current;
+
   MapEntry<K, V> _getValue(_SplayTreeMapNode<K, V> node) =>
-      MapEntry<K, V>(node.key, node.value);
+      _current ??= MapEntry<K, V>(node.key, node.value);
+
+  bool moveNext() {
+    _current = null;
+    return super.moveNext();
+  }
 
   // Replaces the value of the current node.
+  //
+  // Used by [SplayTreeMap.updateAll].
   void _replaceValue(V value) {
     assert(_path.isNotEmpty);
     if (_modificationCount != _tree._modificationCount) {
@@ -858,21 +909,7 @@
     if (_splayCount != _tree._splayCount) {
       _rebuildPath(_path.last.key);
     }
-    var last = _path.removeLast();
-    var newLast = last._replaceValue(value);
-    if (_path.isEmpty) {
-      _tree._root = newLast;
-    } else {
-      var parent = _path.last;
-      if (identical(last, parent._left)) {
-        parent._left = newLast;
-      } else {
-        assert(identical(last, parent._right));
-        parent._right = newLast;
-      }
-    }
-    _path.add(newLast);
-    _splayCount = ++_tree._splayCount;
+    _path.last.value = value;
   }
 }
 
@@ -962,8 +999,8 @@
     with Iterable<E>, SetMixin<E> {
   _SplayTreeSetNode<E>? _root;
 
-  Comparator<E> _compare;
-  _Predicate _validKey;
+  int Function(E, E) _compare;
+  bool Function(Object?)? _validKey;
 
   /// Create a new [SplayTreeSet] with the given compare function.
   ///
@@ -991,7 +1028,7 @@
     int Function(E key1, E key2)? compare,
     bool Function(dynamic potentialKey)? isValidKey,
   ]) : _compare = compare ?? _defaultCompare<E>(),
-       _validKey = isValidKey ?? ((dynamic v) => v is E);
+       _validKey = isValidKey;
 
   /// Creates a [SplayTreeSet] that contains all [elements].
   ///
@@ -1059,25 +1096,26 @@
   bool get isNotEmpty => _root != null;
 
   E get first {
-    if (_count == 0) throw IterableElementError.noElement();
-    return _first!.key;
+    final root = _root;
+    if (root == null) throw IterableElementError.noElement();
+    return (_root = _splayMin(root)).key;
   }
 
   E get last {
-    if (_count == 0) throw IterableElementError.noElement();
-    return _last!.key;
+    final root = _root;
+    if (root == null) throw IterableElementError.noElement();
+    return (_root = _splayMax(root)).key;
   }
 
   E get single {
-    if (_count == 0) throw IterableElementError.noElement();
-    if (_count > 1) throw IterableElementError.tooMany();
-    return _root!.key;
+    if (_count == 1) return _root!.key;
+    throw _count == 0
+        ? IterableElementError.noElement()
+        : IterableElementError.tooMany();
   }
 
   // From Set.
-  bool contains(Object? element) {
-    return _validKey(element) && _splay(element as E) == 0;
-  }
+  bool contains(Object? element) => _untypedLookup(element) != null;
 
   bool add(E element) => _add(element);
 
@@ -1089,8 +1127,9 @@
   }
 
   bool remove(Object? object) {
-    if (!_validKey(object)) return false;
-    return _remove(object as E) != null;
+    if (_untypedLookup(object) == null) return false;
+    _removeRoot();
+    return true;
   }
 
   void addAll(Iterable<E> elements) {
@@ -1101,23 +1140,23 @@
 
   void removeAll(Iterable<Object?> elements) {
     for (Object? element in elements) {
-      if (_validKey(element)) _remove(element as E);
+      if (_untypedLookup(element) != null) {
+        _removeRoot();
+      }
     }
   }
 
   void retainAll(Iterable<Object?> elements) {
     // Build a set with the same sense of equality as this set.
     SplayTreeSet<E> retainSet = SplayTreeSet<E>(_compare, _validKey);
-    int modificationCount = _modificationCount;
+    final int originalModificationCount = _modificationCount;
     for (Object? object in elements) {
-      if (modificationCount != _modificationCount) {
+      if (originalModificationCount != _modificationCount) {
         // The iterator should not have side effects.
         throw ConcurrentModificationError(this);
       }
-      // Equivalent to this.contains(object).
-      if (_validKey(object) && _splay(object as E) == 0) {
-        retainSet.add(_root!.key);
-      }
+      final root = _untypedLookup(object);
+      if (root != null) retainSet.add(root.key);
     }
     // Take over the elements from the retained set, if it differs.
     if (retainSet._count != _count) {
@@ -1127,27 +1166,28 @@
     }
   }
 
-  E? lookup(Object? object) {
-    if (!_validKey(object)) return null;
-    int comp = _splay(object as E);
-    if (comp != 0) return null;
-    return _root!.key;
-  }
+  E? lookup(Object? object) => _untypedLookup(object)?.key;
 
-  Set<E> intersection(Set<Object?> other) {
-    Set<E> result = SplayTreeSet<E>(_compare, _validKey);
-    for (E element in this) {
-      if (other.contains(element)) result.add(element);
-    }
-    return result;
-  }
+  Set<E> intersection(Set<Object?> other) => _filter(other, true);
 
-  Set<E> difference(Set<Object?> other) {
-    Set<E> result = SplayTreeSet<E>(_compare, _validKey);
+  Set<E> difference(Set<Object?> other) => _filter(other, false);
+
+  SplayTreeSet<E> _filter(Set<Object?> other, bool include) {
+    // Copy nodes selectively.
+    // Simulates repeated `add(element)` with elements that are
+    // known to be in increasing order, which creates a left-spine structure.
+    _SplayTreeSetNode<E>? root = null;
+    var count = 0;
     for (E element in this) {
-      if (!other.contains(element)) result.add(element);
+      if (other.contains(element) == include) {
+        assert(root == null || _compare(root.key, element) <= 0);
+        root = _SplayTreeSetNode<E>(element).._left = root;
+        count++;
+      }
     }
-    return result;
+    return SplayTreeSet<E>(_compare, _validKey)
+      .._root = root
+      .._count = count;
   }
 
   Set<E> union(Set<E> other) {
@@ -1156,45 +1196,46 @@
 
   SplayTreeSet<E> _clone() {
     var set = SplayTreeSet<E>(_compare, _validKey);
-    set._count = _count;
-    set._root = _copyNode<_SplayTreeSetNode<E>>(_root);
+    var root = _root;
+    if (root != null) {
+      set._root = _copyNode<_SplayTreeSetNode<E>>(root);
+      set._count = _count;
+    }
     return set;
   }
 
-  // Copies the structure of a SplayTree into a new similar structure.
+  // Copies the structure of a SplayTree into a new similar SplayTreeSet
+  // structure.
   // Works on _SplayTreeMapNode as well, but only copies the keys,
+  // which is used for `.keys.toSet()`.
   _SplayTreeSetNode<E>? _copyNode<Node extends _SplayTreeNode<E, Node>>(
-    Node? node,
+    Node source,
   ) {
-    if (node == null) return null;
-    // Given a source node and a destination node, copy the left
-    // and right subtrees of the source node into the destination node.
-    // The left subtree is copied recursively, but the right spine
-    // of every subtree is copied iteratively.
-    void copyChildren(Node node, _SplayTreeSetNode<E> dest) {
-      Node? left;
-      Node? right;
-      do {
-        left = node._left;
-        right = node._right;
-        if (left != null) {
-          var newLeft = _SplayTreeSetNode<E>(left.key);
-          dest._left = newLeft;
+    // The left subtree is copied recursively if there are two children,
+    // and the right spine of every subtree, and any left-only child,
+    // is copied iteratively.
+    _SplayTreeSetNode<E> result = _SplayTreeSetNode<E>(source.key);
+    // Copy of `source` that hasn't had children added yet.
+    var target = result;
+    while (true) {
+      var sourceLeft = source._left;
+      var sourceRight = source._right;
+      if (sourceLeft != null) {
+        if (sourceRight != null) {
           // Recursively copy the left tree.
-          copyChildren(left, newLeft);
+          target._left = _copyNode<Node>(sourceLeft);
+        } else {
+          // Iteratively copy the left and only child.
+          source = sourceLeft;
+          target = target._left = _SplayTreeSetNode<E>(source.key);
+          continue;
         }
-        if (right != null) {
-          var newRight = _SplayTreeSetNode<E>(right.key);
-          dest._right = newRight;
-          // Set node and dest to copy the right tree iteratively.
-          node = right;
-          dest = newRight;
-        }
-      } while (right != null);
+      } else if (sourceRight == null) {
+        break; // Done when reaching a leaf node.
+      }
+      source = sourceRight;
+      target = target._right = _SplayTreeSetNode<E>(sourceRight.key);
     }
-
-    var result = _SplayTreeSetNode<E>(node.key);
-    copyChildren(node, result);
     return result;
   }
 
diff --git a/tests/corelib/splay_tree_test.dart b/tests/corelib/splay_tree_test.dart
index d155d25..a10c810 100644
--- a/tests/corelib/splay_tree_test.dart
+++ b/tests/corelib/splay_tree_test.dart
@@ -2,13 +2,11 @@
 // for details. All rights reserved. Use of this source code is governed by a
 // BSD-style license that can be found in the LICENSE file.
 
-// Formatting can break multitests, so don't format them.
-// dart format off
-
-// Dart test for Splaytrees.
+// Dart test for Splay-tree data structures.
 library splay_tree_test;
 
 import "package:expect/expect.dart";
+import "package:expect/variations.dart";
 import 'dart:collection';
 
 main() {
@@ -58,6 +56,79 @@
   regressRemoveWhere2();
   regressFromCompare();
   regressIncomparable();
+
+  // Setting values do not break iteration.
+  // Setting values during iteration may show either old or new value,
+  // but must be consistent when read.
+  var map =
+      SplayTreeMap<String, int>()
+        ..["a"] = 1
+        ..["b"] = 2;
+  var index = 0;
+  for (var v in map.values) {
+    if (index == 0) {
+      Expect.equals(1, v);
+      map["b"] = 42;
+    } else {
+      Expect.equals(1, index);
+      if (v != 42 && v != 2) {
+        Expect.fail('map["b"] not 2 or 42');
+      }
+      map["b"] = 2;
+    }
+    index++;
+  }
+
+  index = 0;
+  // Same using explicit iterator.
+  for (var iterator = map.values.iterator; iterator.moveNext(); index++) {
+    if (index == 0) {
+      Expect.equals(1, iterator.current);
+      map["b"] = 42;
+    } else {
+      Expect.equals(1, index);
+      var v = iterator.current;
+      if (v != 42 && v != 2) {
+        Expect.fail('map["b"] not 2 or 42');
+      }
+      map["b"] = 2;
+      var v2 = iterator.current;
+      Expect.equals(v, v2, "current getter not consistent: $v -> $v2");
+    }
+  }
+
+  // Same for values accessed through `.entries`.
+  for (var entry in map.entries) {
+    if (entry.key == "a") {
+      Expect.equals(1, entry.value);
+      map["b"] = 42;
+    } else {
+      Expect.equals("b", entry.key);
+      var v = entry.value;
+      if (v != 42 && v != 2) {
+        Expect.fail('map["b"] not 2 or 42');
+      }
+      map["b"] = 2;
+      var v2 = entry.value;
+      Expect.equals(v, v2, "current getter not consistent: $v -> $v2");
+    }
+  }
+
+  for (var iterator = map.entries.iterator; iterator.moveNext();) {
+    if (iterator.current.key == "a") {
+      Expect.equals(1, iterator.current.value);
+      map["b"] = 42;
+    } else {
+      Expect.equals("b", iterator.current.key);
+      var v = iterator.current.value;
+      if (v != 42 && v != 2) {
+        Expect.fail('map["b"] not 2 or 42');
+      }
+      map["b"] = 2;
+      var v2 = iterator.current.value;
+      Expect.equals(v, v2, "current getter not consistent: $v -> $v2");
+    }
+  }
 }
 
 void regressRemoveWhere() {
@@ -131,8 +202,6 @@
   Expect.equals(null, map[key(5)]);
   Expect.equals(null, map[1]);
   Expect.equals(null, map["string"]);
-  map[1] = 42; //# 01: compile-time error
-  map["string"] = 42; //# 02: compile-time error
   map[key(5)] = 42;
   Expect.equals(4, map.length);
   Expect.equals(42, map[key(5)]);
@@ -140,19 +209,27 @@
 
 // Incomparable keys throw when added, even on an empty collection.
 void regressIncomparable() {
-  var set = SplayTreeSet();
+  // With no `compare` function given, it defaults to one that does
+  // dynamic downcast of both arguments to `Comparable<Object?>`,
+  // then invoking its `compareTo` with the latter.
+  // Since `IncomparableKey` can't be downcast to `Comparable`, it should throw.
+  var set = SplayTreeSet<Object?>();
   Expect.throws(() => set.add(IncomparableKey(0)));
   Expect.throws(() => set.lookup(IncomparableKey(0)));
   set.add(1);
-  Expect.throws(() => set.add(IncomparableKey(0)));
-  Expect.throws(() => set.lookup(IncomparableKey(0)));
+  if (checkedImplicitDowncasts) {
+    Expect.throws(() => set.add(IncomparableKey(0)));
+    Expect.throws(() => set.lookup(IncomparableKey(0)));
+  }
 
   var map = SplayTreeMap();
   Expect.throws(() => map[IncomparableKey(0)] = 0);
   Expect.throws(() => map.putIfAbsent(IncomparableKey(0), () => 0));
   map[1] = 1;
-  Expect.throws(() => map[IncomparableKey(0)] = 0);
-  Expect.throws(() => map.putIfAbsent(IncomparableKey(0), () => 0));
+  if (checkedImplicitDowncasts) {
+    Expect.throws(() => map[IncomparableKey(0)] = 0);
+    Expect.throws(() => map.putIfAbsent(IncomparableKey(0), () => 0));
+  }
 
   // But not if the compare function allows them.
   // This now includes `null`.