Add `choose` extension method to iterable, choosing a number of elements at random. (#158)
* Add `sample` extension method to iterable, choosing a number of elements at random.
Inspired by: https://stackoverflow.com/questions/64117939/picking-n-unique-random-enums-in-dart
diff --git a/lib/src/iterable_extensions.dart b/lib/src/iterable_extensions.dart
index 728c1ec..82217e9 100644
--- a/lib/src/iterable_extensions.dart
+++ b/lib/src/iterable_extensions.dart
@@ -2,6 +2,8 @@
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
+import 'dart:math' show Random;
+
import 'package:collection/src/utils.dart';
import 'algorithms.dart';
@@ -16,6 +18,37 @@
/// iterables with specific element types include those of
/// [IterableComparableExtension] and [IterableNullableExtension].
extension IterableExtension<T> on Iterable<T> {
+ /// Selects [count] elements at random from this iterable.
+ ///
+ /// The returned list contains [count] different elements of the iterable.
+ /// If the iterable contains fewer that [count] elements,
+ /// the result will contain all of them, but will be shorter than [count].
+ /// If the same value occurs more than once in the iterable,
+ /// it can also occur more than once in the chosen elements.
+ ///
+ /// Each element of the iterable has the same chance of being chosen.
+ /// The chosen elements are not in any specific order.
+ List<T> sample(int count, [Random? random]) {
+ RangeError.checkNotNegative(count, 'count');
+ var iterator = this.iterator;
+ var chosen = <T>[];
+ for (var i = 0; i < count; i++) {
+ if (iterator.moveNext()) {
+ chosen.add(iterator.current);
+ } else {
+ return chosen;
+ }
+ }
+ var index = count;
+ random ??= Random();
+ while (iterator.moveNext()) {
+ index++;
+ var position = random.nextInt(index);
+ if (position < count) chosen[position] = iterator.current;
+ }
+ return chosen;
+ }
+
/// The elements that do not satisfy [test].
Iterable<T> whereNot(bool Function(T element) test) =>
where((element) => !test(element));
@@ -740,7 +773,7 @@
/// Extensions on comparator functions.
extension ComparatorExtension<T> on Comparator<T> {
/// The inverse ordering of this comparator.
- int Function(T, T) get inverse => (T a, T b) => this(b, a);
+ Comparator<T> get inverse => (T a, T b) => this(b, a);
/// Makes a comparator on [R] values using this comparator.
///
diff --git a/test/extensions_test.dart b/test/extensions_test.dart
index 3289757..5d36020 100644
--- a/test/extensions_test.dart
+++ b/test/extensions_test.dart
@@ -2,7 +2,7 @@
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
-import 'dart:math' show pow;
+import 'dart:math' show pow, Random;
import 'package:test/test.dart';
@@ -989,6 +989,53 @@
expect(iterable(['4', '3', '2', '1']).isSorted(), false);
});
});
+ group('.sample', () {
+ test('errors', () {
+ expect(() => iterable([1]).sample(-1), throwsRangeError);
+ });
+ test('empty', () {
+ var empty = iterable(<int>[]);
+ expect(empty.sample(0), []);
+ expect(empty.sample(5), []);
+ });
+ test('single', () {
+ var single = iterable([1]);
+ expect(single.sample(0), []);
+ expect(single.sample(1), [1]);
+ expect(single.sample(5), [1]);
+ });
+ test('multiple', () {
+ var multiple = iterable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+ expect(multiple.sample(0), []);
+ var one = multiple.sample(1);
+ expect(one, hasLength(1));
+ expect(one.first, inInclusiveRange(1, 10));
+ var some = multiple.sample(3);
+ expect(some, hasLength(3));
+ expect(some[0], inInclusiveRange(1, 10));
+ expect(some[1], inInclusiveRange(1, 10));
+ expect(some[2], inInclusiveRange(1, 10));
+ expect(some[0], isNot(some[1]));
+ expect(some[0], isNot(some[2]));
+ expect(some[1], isNot(some[2]));
+
+ var seen = <int>{};
+ do {
+ seen.addAll(multiple.sample(3));
+ } while (seen.length < 10);
+ // Should eventually terminate.
+ });
+ test('random', () {
+ // Passing in a `Random` makes result deterministic.
+ var multiple = iterable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+ var seed = 12345;
+ var some = multiple.sample(5, Random(seed));
+ for (var i = 0; i < 10; i++) {
+ var other = multiple.sample(5, Random(seed));
+ expect(other, some);
+ }
+ });
+ });
});
group('Comparator', () {