
Intersection of a Stream of Collections

Submitted by: @import:stackexchange-codereview··
stream, collections, intersection

Problem

An API (which I have no control over) gives me a bunch of Collections (Lists, to be precise), and I need the elements that are common to all of them. In this concrete case, I'm trying to find Selenium WebElements that match several selectors. But I went off on a yak-shaving spree and tried to figure out how to intersect a stream of Collections. This is what I came up with (and I think it works; at least my tests say so):

/**
 * Given a stream of collections, intersect it.
 *
 * @param stream A stream of collections you want to intersect
 * @param <T> Any item that has a decent .equals()/.hashCode().
 *            Seriously, don't attempt this without a working .equals()/.hashCode().
 * @return The unique(!) elements present in *all* collections in the stream
 */
public static <T> Collection<T> intersect(Stream<Collection<T>> stream) {
  // Optimization: sorting by size so that the biggest constrainer
  // (smallest collection) comes first
  final Iterator<Collection<T>> allLists = stream.sorted(
        (l1, l2) -> l1.size() - l2.size()
  ).iterator();

  final Set<T> result = new HashSet<>(allLists.next());
  while (allLists.hasNext()) {
      result.retainAll(allLists.next());
  }
  return result;
}


I fiddled around with Collectors in the beginning, but that seemed to be too convoluted, since it'd require the unit element to be some special class.

Any better way of doing it? Is this woefully inefficient? I think its complexity is somewhere around O(n) where n should be the total number of elements in all the collections.

Solution

Using built-ins

You can use built-in comparators instead of creating your own. For example

final Iterator<Collection<T>> allLists = stream.sorted(
      (l1, l2) -> l1.size() - l2.size()
).iterator();


can be written more simply using the comparingInt comparator, which compares elements according to the result of applying the given function (returning an int) to each element. In this case, you could have:

final Iterator<Collection<T>> allLists =
    stream.sorted(Comparator.comparingInt(Collection::size)).iterator();


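The function returning the integer to compare is written as a method reference to Collection.size(). It also gets rid of the comparison by subtraction, an idiom that can overflow and return the wrong sign for arbitrary int values (sizes are never negative, so it cannot overflow here, but the idiom is still best avoided).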
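To make the corner case concrete, here is a minimal sketch. The class name SizeOrdering and its helper methods are made up for illustration; only Comparator.comparingInt and Integer.compare come from the JDK.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;

// Hypothetical demo class, not part of the original code.
class SizeOrdering {

    // Returns a copy of the input ordered from the smallest to the largest collection.
    static <C extends Collection<?>> List<C> sortBySize(List<C> input) {
        List<C> copy = new ArrayList<>(input);
        copy.sort(Comparator.comparingInt(Collection::size));
        return copy;
    }

    // The subtraction idiom: fine for non-negative sizes, unsafe for arbitrary ints.
    static int compareBySubtraction(int a, int b) {
        return a - b;
    }

    public static void main(String[] args) {
        List<List<Integer>> lists = Arrays.asList(
                Arrays.asList(1, 2, 3), Arrays.asList(4), Arrays.asList(5, 6));
        System.out.println(sortBySize(lists));

        // Integer.MIN_VALUE - 1 wraps around to a positive number, so the
        // subtraction claims MIN_VALUE is greater than 1; Integer.compare does not.
        System.out.println(compareBySubtraction(Integer.MIN_VALUE, 1) > 0);
        System.out.println(Integer.compare(Integer.MIN_VALUE, 1) < 0);
    }
}
```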

Bug if the stream is empty

There is a bug in the current method if the given stream is empty. This is because of

final Set<T> result = new HashSet<>(allLists.next());


which unconditionally invokes next() on the iterator of the stream. You can call intersect(Stream.empty()) to verify it; it will throw a NoSuchElementException. In the case of an empty stream, the method should return an empty collection instead.
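One way to guard against this is a hasNext() check before the first next(). The sketch below keeps the signature from the question and leaves out the sorting step for brevity; the class name EmptyStreamGuard is made up for illustration.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.stream.Stream;

// Hypothetical wrapper class, not part of the original code.
class EmptyStreamGuard {

    // Same loop as in the question, plus a guard for the empty stream.
    static <T> Collection<T> intersect(Stream<Collection<T>> stream) {
        final Iterator<Collection<T>> allLists = stream.iterator();
        if (!allLists.hasNext()) {
            return Collections.emptySet(); // nothing to intersect
        }
        final Set<T> result = new HashSet<>(allLists.next());
        while (allLists.hasNext()) {
            result.retainAll(allLists.next());
        }
        return result;
    }

    public static void main(String[] args) {
        // No NoSuchElementException any more: the result is simply empty.
        System.out.println(intersect(Stream.empty()).isEmpty());

        Stream<Collection<Integer>> lists =
                Stream.of(Arrays.asList(1, 2, 3), Arrays.asList(2, 3, 4));
        System.out.println(intersect(lists));
    }
}
```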

Better generics

With the current signature of

<T> Collection<T> intersect(Stream<Collection<T>> stream)


the issue is that passing, for example, a Stream<List<Integer>> will not compile. To verify this, try:

intersect(Arrays.asList(Arrays.asList(1)).stream())


Instead, we can introduce a second generic type C for the collection with the following signature:

<T, C extends Collection<T>> Collection<T> intersect(Stream<C> stream)


This makes sure that you can pass any collection for the elements of the input stream. With such a signature, you can verify that the above compilation error is not there anymore.
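As a rough illustration, the sketch below uses the two-parameter signature and calls it with a stream of Lists and a stream of Sets. The class name GenericIntersect and the sample data are made up; the method body is only a minimal stand-in.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.stream.Stream;

// Hypothetical wrapper class, not part of the original code.
class GenericIntersect {

    // C is whatever collection type the stream carries; T is its element type.
    static <T, C extends Collection<T>> Collection<T> intersect(Stream<C> stream) {
        final Iterator<C> allLists = stream.iterator();
        if (!allLists.hasNext()) {
            return Collections.emptySet();
        }
        final Set<T> result = new HashSet<>(allLists.next());
        while (allLists.hasNext()) {
            result.retainAll(allLists.next());
        }
        return result;
    }

    public static void main(String[] args) {
        // A Stream<List<Integer>> is accepted: C is inferred as List<Integer>, T as Integer.
        Collection<Integer> fromLists =
                intersect(Arrays.asList(Arrays.asList(1, 2), Arrays.asList(2, 3)).stream());
        System.out.println(fromLists);

        // A Stream<Set<String>> works with the same method.
        Stream<Set<String>> sets = Stream.of(
                new HashSet<String>(Arrays.asList("a", "b")),
                new HashSet<String>(Arrays.asList("b", "c")));
        System.out.println(intersect(sets));
    }
}
```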

Huge performance improvement

If you're not dealing with sets, the retainAll process is very inefficient:

result.retainAll(allLists.next());


retainAll removes from this collection every element that the collection passed as the argument does not contain, so it calls contains on the argument once per element of result. On a List, contains is O(n), which makes the whole operation O(n²). Instead, pass a new HashSet:

result.retainAll(new HashSet<>(allLists.next()));


Since contains takes constant time on average for a HashSet, the whole operation stays O(n) and is thus a lot faster (at the expense of more memory).
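The following sketch illustrates where the cost comes from. It assumes that HashSet inherits retainAll from AbstractCollection, which is the case in the JDK versions I am aware of but is an implementation detail. The class RetainAllCost and the CountingList helper are made up for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

// Hypothetical demo class, not part of the original code.
class RetainAllCost {

    // Hypothetical helper: an ArrayList that counts how often contains() is called.
    static class CountingList<E> extends ArrayList<E> {
        int containsCalls = 0;

        CountingList(Collection<? extends E> source) {
            super(source);
        }

        @Override
        public boolean contains(Object o) {
            containsCalls++;
            return super.contains(o); // linear scan over the list
        }
    }

    // Counts the contains() calls that retainAll makes on a List argument.
    static int containsCallsDuringRetainAll(Collection<Integer> first, Collection<Integer> second) {
        Set<Integer> result = new HashSet<>(first);
        CountingList<Integer> counted = new CountingList<>(second);
        result.retainAll(counted);
        return counted.containsCalls;
    }

    // The suggested fix: copy the argument into a HashSet so each lookup is a hash probe.
    static Set<Integer> retainViaHashSet(Collection<Integer> first, Collection<Integer> second) {
        Set<Integer> result = new HashSet<>(first);
        result.retainAll(new HashSet<>(second));
        return result;
    }

    public static void main(String[] args) {
        // One contains() call per element of the result set, each one scanning the list.
        System.out.println(containsCallsDuringRetainAll(Arrays.asList(1, 2, 3), Arrays.asList(2, 3, 4, 5)));
        System.out.println(retainViaHashSet(Arrays.asList(1, 2, 3), Arrays.asList(2, 3, 4, 5)));
    }
}
```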

Why sort?

The comment in your code says that the sorting step is an optimization that ensures shorter collections come first. Intrigued, I wrote a benchmark comparing the code with and without sorting. It applied the two methods to a stream of 1000 and 10000 elements, where each inner collection had 100 and 1000 elements. The elements were random integers. Here are the results (Windows 10 x64, JDK 1.8.0_102, i5, 2.90 GHz):

Benchmark                 (lengthOfEach)  (totalLength)  Mode  Cnt    Score   Error  Units
StreamTest.intersect                 100           1000  avgt   30    1,757 ± 0,069  ms/op
StreamTest.intersect                 100          10000  avgt   30   18,876 ± 0,954  ms/op
StreamTest.intersect                1000           1000  avgt   30   17,287 ± 0,378  ms/op
StreamTest.intersect                1000          10000  avgt   30  177,633 ± 7,043  ms/op
StreamTest.intersectSort             100           1000  avgt   30    1,805 ± 0,080  ms/op
StreamTest.intersectSort             100          10000  avgt   30   18,434 ± 0,621  ms/op
StreamTest.intersectSort            1000           1000  avgt   30   19,472 ± 0,981  ms/op
StreamTest.intersectSort            1000          10000  avgt   30  184,440 ± 5,380  ms/op


For the values tested, this shows that there is really no measurable difference between the two, so I'd just get rid of this sorting.

Code of benchmark for completeness:

```
@Warmup(iterations = 10, time = 700, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 700, timeUnit = TimeUnit.MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(3)
public class StreamTest {

    public static <T, C extends Collection<T>> Collection<T> intersectSort(Stream<C> stream) {
        final Iterator<C> allLists = stream.sorted(Comparator.comparingInt(Collection::size)).iterator();

        if (!allLists.hasNext()) return Collections.emptySet();

        final Set<T> result = new HashSet<>(allLists.next());
        while (allLists.hasNext()) {
            result.retainAll(new HashSet<>(allLists.next()));
        }
        return result;
    }

    public static <T, C extends Collection<T>> Collection<T> intersect(Stream<C> stream) {
        final Iterator<C> allLists = stream.iterator();

        if (!allLists.hasNext()) return Collections.emptySet();

        final Set<T> result = new HashSet<>(allLists.next());
        while (allLists.hasNext()) {
            result.retainAll(new HashSet<>(allLists.next()));
        }
        return result;
    }

    @State(Scope.Benchmark)
    public static class Container {

        @Param({ "100", "1000" })
        private int totalLength;

        // @Pa... (the remainder of the benchmark is truncated in the source)
```

Code Snippets

final Iterator<Collection<T>> allLists = stream.sorted(
      (l1, l2) -> l1.size() - l2.size()
).iterator();

final Iterator<Collection<T>> allLists =
    stream.sorted(Comparator.comparingInt(Collection::size)).iterator();

final Set<T> result = new HashSet<>(allLists.next());

<T> Collection<T> intersect(Stream<Collection<T>> stream)

intersect(Arrays.asList(Arrays.asList(1)).stream())

Context

StackExchange Code Review Q#145594, answer score: 5
