HiveBrain v1.2.0
Get Started
← Back to all entries
patterncsharpMinor

Producing the intersection of several sequences

Submitted by: @import:stackexchange-codereview··
0
Viewed 0 times
theproducingseveralintersectionsequences

Problem

Based on this SO answer, I have created a method that produces the set intersection of several sequences:

/// <summary>
/// Produces the set intersection of all sequences in <paramref name="source"/>.
/// </summary>
/// <typeparam name="TSource">Element type of the inner sequences.</typeparam>
/// <param name="source">The sequences to intersect.</param>
/// <returns>
/// The distinct elements present in every inner sequence; an empty sequence when
/// <paramref name="source"/> contains no sequences.
/// </returns>
private static IEnumerable<TSource> Intersect<TSource>(this IEnumerable<IEnumerable<TSource>> source)
{
    using (IEnumerator<IEnumerable<TSource>> sourceIterator = source.GetEnumerator())
    {
        // no sequences at all: the intersection is empty
        if (!sourceIterator.MoveNext())
        {
            yield break;
        }

        IEnumerable<TSource> firstSequence = sourceIterator.Current;
        if (sourceIterator.MoveNext())
        {
            // seed a hash set with the first sequence (this also drops its duplicates)
            HashSet<TSource> hashSet = new HashSet<TSource>(firstSequence);

            // remove everything not present in each remaining sequence
            do
            {
                hashSet.IntersectWith(sourceIterator.Current);
            }
            while (sourceIterator.MoveNext());
            foreach (TSource element in hashSet)
            {
                yield return element;
            }
        }
        else
        {
            // A single sequence is its own intersection, but set semantics still
            // require dropping duplicates. Elements are yielded lazily, in
            // first-occurrence order, as they are first seen.
            HashSet<TSource> seen = new HashSet<TSource>();
            foreach (TSource element in firstSequence)
            {
                if (seen.Add(element))
                {
                    yield return element;
                }
            }
        }
    }
}


I'm particularly annoyed that the method is many times slower than the following implementation when there is only one sequence:

/// <summary>
/// Produces the set intersection of all sequences in <paramref name="source"/> by
/// folding <see cref="Enumerable.Intersect{TSource}(IEnumerable{TSource}, IEnumerable{TSource})"/>
/// over them.
/// </summary>
/// <typeparam name="TSource">Element type of the inner sequences.</typeparam>
/// <param name="source">The sequences to intersect; must contain at least one sequence.</param>
/// <returns>The distinct elements present in every inner sequence.</returns>
/// <exception cref="InvalidOperationException"><paramref name="source"/> is empty (thrown by Aggregate).</exception>
public static IEnumerable<TSource>
          Intersect<TSource>(this IEnumerable<IEnumerable<TSource>> source)
{
    // Aggregate hands back a lone sequence unchanged, so Distinct() is what keeps
    // set semantics in the single-sequence case. With two or more sequences,
    // Intersect already yields distinct elements and Distinct() merely re-checks them.
    return source.Aggregate((x, y) => x.Intersect(y)).Distinct();
}


I already added a special test and code path for the case where there is only one sequence, but it does not seem to be enough. Besides this problem, is there any obvious way to optimize the method?

Solution

First, let me point out that you shouldn't name this extension method Intersect(). The name Intersect() already exists and returns the intersection of two collections, whereas your method intersects a collection of collections with each other, so it should be named differently. I would suggest IntersectAll() or possibly IntersectMany().

I feel your implementation for the case of a single collection is wrong. You are performing a set operation here, so you need to maintain set semantics in all cases. You assume that for a single collection you may simply return that collection. But what if that collection contains duplicates? You would need to filter those duplicates out. The aggregated version should do so as well.

With that in mind, I would change your implementation like so:

/// <summary>
/// Produces the set intersection of all sequences in <paramref name="source"/>:
/// the distinct elements that occur in every inner sequence.
/// </summary>
/// <typeparam name="TSource">Element type of the inner sequences.</typeparam>
/// <param name="source">The sequences to intersect.</param>
/// <returns>
/// The distinct common elements. A single inner sequence yields its own distinct
/// elements; an empty <paramref name="source"/> yields nothing.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
public static IEnumerable<TSource> IntersectAll<TSource>(
        this IEnumerable<IEnumerable<TSource>> source)
{
    // Validate outside the iterator so a null argument fails at the call site
    // rather than on the first MoveNext() of the returned sequence.
    if (source == null)
        throw new ArgumentNullException("source");

    return IntersectAllIterator(source);
}

// Iterator body of IntersectAll; source has already been null-checked.
private static IEnumerable<TSource> IntersectAllIterator<TSource>(
        IEnumerable<IEnumerable<TSource>> source)
{
    using (var enumerator = source.GetEnumerator())
    {
        if (!enumerator.MoveNext())
            yield break;

        // Seeding the set from the first sequence also removes its duplicates,
        // so the single-sequence case keeps set semantics.
        var set = new HashSet<TSource>(enumerator.Current);
        while (enumerator.MoveNext())
            set.IntersectWith(enumerator.Current);
        foreach (var item in set)
            yield return item;
    }
}


I'm not sure how you did your testing, but with multiple collections both your implementation and mine are much faster than the aggregation; with a single collection, the aggregation is faster (see the timings below). Tested in LINQPad with optimizations enabled:

Your implementation:

/// <summary>
/// Benchmark copy of the question's implementation: intersects all inner sequences
/// using a HashSet seeded from the first one.
/// </summary>
/// <remarks>
/// Deliberately faithful to the question's code: when <paramref name="source"/>
/// holds exactly one sequence, that sequence is streamed back unchanged,
/// duplicates included.
/// </remarks>
/// <typeparam name="TSource">Element type of the inner sequences.</typeparam>
/// <param name="source">The sequences to intersect.</param>
/// <returns>The common elements, or nothing when <paramref name="source"/> is empty.</returns>
public static IEnumerable<TSource> IntersectAll1<TSource>(this IEnumerable<IEnumerable<TSource>> source)
{
    using (IEnumerator<IEnumerable<TSource>> sourceIterator = source.GetEnumerator())
    {
        //if there is no first element return an empty list
        if (!sourceIterator.MoveNext())
        {
            yield break;
        }

        IEnumerable<TSource> firstSequence = sourceIterator.Current;
        if (sourceIterator.MoveNext())
        {
            //create a hashset of the first sequence
            HashSet<TSource> hashSet = new HashSet<TSource>(firstSequence);

            //intersect the other sequences
            do
            {
                hashSet.IntersectWith(sourceIterator.Current);
            }
            while (sourceIterator.MoveNext());
            foreach (TSource element in hashSet)
            {
                yield return element;
            }
        }
        else
        {
            //if there is only one sequence, return it without any intersection
            foreach (TSource element in firstSequence)
            {
                yield return element;
            }
        }
    }
}


/// <summary>
/// Produces the set intersection of all sequences in <paramref name="source"/>
/// using a HashSet seeded from the first sequence.
/// </summary>
/// <typeparam name="TSource">Element type of the inner sequences.</typeparam>
/// <param name="source">The sequences to intersect.</param>
/// <returns>
/// The distinct common elements (a single inner sequence yields its distinct
/// elements); nothing when <paramref name="source"/> is empty.
/// </returns>
public static IEnumerable<TSource> IntersectAll2<TSource>(
        this IEnumerable<IEnumerable<TSource>> source)
{
    using (var enumerator = source.GetEnumerator())
    {
        if (!enumerator.MoveNext())
            yield break;

        // the HashSet constructor also removes duplicates of the first sequence
        var set = new HashSet<TSource>(enumerator.Current);
        while (enumerator.MoveNext())
            set.IntersectWith(enumerator.Current);
        foreach (var item in set)
            yield return item;
    }
}


/// <summary>
/// Benchmark control: folds <see cref="Enumerable.Intersect{TSource}(IEnumerable{TSource}, IEnumerable{TSource})"/>
/// over all sequences in <paramref name="source"/>.
/// </summary>
/// <remarks>
/// A single inner sequence is returned as-is (duplicates included), because
/// Aggregate never invokes the folding function in that case.
/// </remarks>
/// <typeparam name="TSource">Element type of the inner sequences.</typeparam>
/// <param name="source">The sequences to intersect; must contain at least one sequence.</param>
/// <returns>The (deferred) intersection of all inner sequences.</returns>
/// <exception cref="InvalidOperationException"><paramref name="source"/> is empty (thrown by Aggregate).</exception>
public static IEnumerable<TSource> IntersectAllControl<TSource>(
        this IEnumerable<IEnumerable<TSource>> source)
{
    return source.Aggregate((x, y) => x.Intersect(y));
}


// LINQPad benchmark comparing the three IntersectAll* implementations.
void Main()
{
    var rng = new Random();
    // 500 arrays of 100 consecutive ints, each starting at a random offset in [0, 50)
    var collections = Enumerable.Range(0, 500)
        .Select(i => Enumerable.Range(rng.Next(0, 50), 100).ToArray())
        .Cast<IEnumerable<int>>()
        //.Take(1)
        .ToArray();

    const int Iterations = 1000; // 1000000 for single collection

    // Times one implementation over Iterations runs and dumps its last result.
    // ToList() forces the deferred query so the intersection is actually computed.
    Action<string, Func<IEnumerable<IEnumerable<int>>, IEnumerable<int>>> measure = (label, intersect) =>
    {
        // Warm-up run: keeps JIT compilation of the method under test out of the
        // timed loop, which would otherwise penalise whichever method runs first.
        List<int> result = intersect(collections).ToList();

        var timer = Stopwatch.StartNew();
        for (var i = 0; i < Iterations; i++)
            result = intersect(collections).ToList();
        timer.Stop();

        result.Dump(String.Format("{0}: {1}", label, timer.Elapsed));
    };

    measure("result1", c => c.IntersectAll1());
    measure("result2", c => c.IntersectAll2());
    measure("controlResult", c => c.IntersectAllControl());
}


Note: The ToList() calls were needed on all invocations to ensure that the intersections are actually generated.

Run on my machine with multiple collections:

result1: 00:00:01.0210291
result2: 00:00:01.0285069
controlResult: 00:00:03.1512838

And with a single collection:

result1: 00:00:02.9254441
result2: 00:00:04.8505489
controlResult: 00:00:00.2102433

I have to apologize, I have changed my review as I wrote this due to some details I overlooked as I was writing. Thank you mjolka for pointing out my mistake.

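For multiple collections, the new implementations will always be a win over the aggregation. In the aggregated approach, each sub-intersection needs to be determined for every collection. In our new implementations, a single hash set is built from the first collection and is only pruned by each subsequent one.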

Code Snippets

/// <summary>
/// Produces the set intersection of all sequences in <paramref name="source"/>:
/// the distinct elements that occur in every inner sequence.
/// </summary>
/// <typeparam name="TSource">Element type of the inner sequences.</typeparam>
/// <param name="source">The sequences to intersect.</param>
/// <returns>
/// The distinct common elements. A single inner sequence yields its own distinct
/// elements; an empty <paramref name="source"/> yields nothing.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
public static IEnumerable<TSource> IntersectAll<TSource>(
        this IEnumerable<IEnumerable<TSource>> source)
{
    // Validate outside the iterator so a null argument fails at the call site
    // rather than on the first MoveNext() of the returned sequence.
    if (source == null)
        throw new ArgumentNullException("source");

    return IntersectAllIterator(source);
}

// Iterator body of IntersectAll; source has already been null-checked.
private static IEnumerable<TSource> IntersectAllIterator<TSource>(
        IEnumerable<IEnumerable<TSource>> source)
{
    using (var enumerator = source.GetEnumerator())
    {
        if (!enumerator.MoveNext())
            yield break;

        // Seeding the set from the first sequence also removes its duplicates,
        // so the single-sequence case keeps set semantics.
        var set = new HashSet<TSource>(enumerator.Current);
        while (enumerator.MoveNext())
            set.IntersectWith(enumerator.Current);
        foreach (var item in set)
            yield return item;
    }
}
/// <summary>
/// Benchmark copy of the question's implementation: intersects every inner
/// sequence with a HashSet seeded from the first one.
/// </summary>
/// <remarks>
/// When <paramref name="source"/> holds exactly one sequence, that sequence is
/// streamed back untouched, duplicates included (unlike the multi-sequence path).
/// </remarks>
/// <typeparam name="TSource">Element type of the inner sequences.</typeparam>
/// <param name="source">The sequences to intersect.</param>
/// <returns>The common elements, or nothing when <paramref name="source"/> is empty.</returns>
public static IEnumerable<TSource> IntersectAll1<TSource>(this IEnumerable<IEnumerable<TSource>> source)
{
    using (IEnumerator<IEnumerable<TSource>> cursor = source.GetEnumerator())
    {
        // no sequences: nothing to yield
        if (!cursor.MoveNext())
            yield break;

        IEnumerable<TSource> head = cursor.Current;

        // exactly one sequence: pass it through as-is
        if (!cursor.MoveNext())
        {
            foreach (TSource item in head)
                yield return item;
            yield break;
        }

        // two or more sequences: whittle down a set built from the first one
        var common = new HashSet<TSource>(head);
        do
        {
            common.IntersectWith(cursor.Current);
        } while (cursor.MoveNext());

        foreach (TSource item in common)
            yield return item;
    }
}
/// <summary>
/// Produces the set intersection of all sequences in <paramref name="source"/>
/// by pruning a HashSet built from the first sequence.
/// </summary>
/// <typeparam name="TSource">Element type of the inner sequences.</typeparam>
/// <param name="source">The sequences to intersect.</param>
/// <returns>
/// The distinct common elements; nothing when <paramref name="source"/> is empty.
/// </returns>
public static IEnumerable<TSource> IntersectAll2<TSource>(
        this IEnumerable<IEnumerable<TSource>> source)
{
    using (IEnumerator<IEnumerable<TSource>> sequences = source.GetEnumerator())
    {
        if (sequences.MoveNext())
        {
            HashSet<TSource> survivors = new HashSet<TSource>(sequences.Current);
            while (sequences.MoveNext())
            {
                survivors.IntersectWith(sequences.Current);
            }

            foreach (TSource value in survivors)
            {
                yield return value;
            }
        }
    }
}
/// <summary>
/// Benchmark control: intersects the sequences pairwise, left to right, with
/// Enumerable.Intersect.
/// </summary>
/// <typeparam name="TSource">Element type of the inner sequences.</typeparam>
/// <param name="source">The sequences to intersect; must not be empty.</param>
/// <returns>The (deferred) intersection; a lone sequence is returned as-is.</returns>
/// <exception cref="InvalidOperationException"><paramref name="source"/> is empty.</exception>
public static IEnumerable<TSource> IntersectAllControl<TSource>(
        this IEnumerable<IEnumerable<TSource>> source)
{
    Func<IEnumerable<TSource>, IEnumerable<TSource>, IEnumerable<TSource>> intersectPair =
        (accumulated, next) => accumulated.Intersect(next);
    return source.Aggregate(intersectPair);
}
// LINQPad benchmark comparing the three IntersectAll* implementations.
void Main()
{
    var rng = new Random();
    // 500 arrays of 100 consecutive ints, each starting at a random offset in [0, 50)
    var collections = Enumerable.Range(0, 500)
        .Select(i => Enumerable.Range(rng.Next(0, 50), 100).ToArray())
        .Cast<IEnumerable<int>>()
        //.Take(1)
        .ToArray();

    const int Iterations = 1000; // 1000000 for single collection

    // Times one implementation over Iterations runs and dumps its last result.
    // ToList() forces the deferred query so the intersection is actually computed.
    Action<string, Func<IEnumerable<IEnumerable<int>>, IEnumerable<int>>> measure = (label, intersect) =>
    {
        // Warm-up run: keeps JIT compilation of the method under test out of the
        // timed loop, which would otherwise penalise whichever method runs first.
        List<int> result = intersect(collections).ToList();

        var timer = Stopwatch.StartNew();
        for (var i = 0; i < Iterations; i++)
            result = intersect(collections).ToList();
        timer.Stop();

        result.Dump(String.Format("{0}: {1}", label, timer.Elapsed));
    };

    measure("result1", c => c.IntersectAll1());
    measure("result2", c => c.IntersectAll2());
    measure("controlResult", c => c.IntersectAllControl());
}

Context

StackExchange Code Review Q#61627, answer score: 6

Revisions (0)

No revisions yet.