Producing the intersection of several sequences
Problem
Based on this SO answer, I have created a method that produces the set intersection of several sequences:
```csharp
private static IEnumerable<TSource> Intersect<TSource>(this IEnumerable<IEnumerable<TSource>> source)
{
    using (IEnumerator<IEnumerable<TSource>> sourceIterator = source.GetEnumerator())
    {
        //if there is no first element return an empty sequence
        if (!sourceIterator.MoveNext())
        {
            yield break;
        }
        IEnumerable<TSource> firstSequence = sourceIterator.Current;
        if (sourceIterator.MoveNext())
        {
            //create a hashset of the first sequence
            HashSet<TSource> hashSet = new HashSet<TSource>(firstSequence);
            //intersect the other sequences
            do
            {
                hashSet.IntersectWith(sourceIterator.Current);
            }
            while (sourceIterator.MoveNext());
            foreach (TSource element in hashSet)
            {
                yield return element;
            }
        }
        else
        {
            //if there is only one sequence, return it without any intersection
            foreach (TSource element in firstSequence)
            {
                yield return element;
            }
        }
    }
}
```

I'm particularly annoyed that the method is many times slower than the following implementation when there is only one sequence:
```csharp
public static IEnumerable<TSource> Intersect<TSource>(this IEnumerable<IEnumerable<TSource>> source)
{
    return source.Aggregate((x, y) => x.Intersect(y));
}
```

I already added a special test and code path for the case where there is only one sequence, but it does not seem to be enough. Besides this problem, is there any obvious way to optimize the method?
Solution
First, let me point out that you shouldn't name this extension method the way you have named it:
Intersect(). The name Intersect() already exists (Enumerable.Intersect) and returns the intersection of two collections. Your method intersects all the collections in a collection of collections with each other, so it should be named differently. I would suggest IntersectAll() or possibly IntersectMany().

I feel your implementation for the case of a single collection is wrong. You are performing a set operation here, so you need to maintain set semantics in all cases. You assume that for a single collection you may simply return that collection. But what if that collection contains duplicates? You would need to filter them out. The aggregated version should do so as well, but it does not: with a single collection, Aggregate returns that collection unchanged.
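A minimal sketch of the duplicate problem, reusing the Aggregate-based control from the question under the hypothetical name AggregateIntersect. With two or more sequences, Enumerable.Intersect yields each common element once; with a single sequence, the input comes back as-is, duplicates included.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class SetSemanticsDemo
{
    // Same body as the Aggregate-based control in the question; the name
    // AggregateIntersect is hypothetical, used only for this illustration.
    public static IEnumerable<TSource> AggregateIntersect<TSource>(
        this IEnumerable<IEnumerable<TSource>> source)
    {
        // Enumerable.Intersect yields each common element once, so two or
        // more sequences come back without duplicates. A single sequence is
        // returned untouched, duplicates included.
        return source.Aggregate((x, y) => x.Intersect(y));
    }
}
```

For example, the single sequence { 1, 1, 2 } yields 1, 1, 2, while { 1, 1, 2 } intersected with itself yields 1, 2. A HashSet-based implementation gets set semantics in both cases, because building the set from the first sequence already discards duplicates.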
With that in mind, I would change your implementation like so:
```csharp
public static IEnumerable<TSource> IntersectAll<TSource>(
    this IEnumerable<IEnumerable<TSource>> source)
{
    using (var enumerator = source.GetEnumerator())
    {
        // No sequences at all: the result is empty.
        if (!enumerator.MoveNext())
            yield break;

        // Seeding the set from the first sequence also removes its duplicates,
        // so even a single sequence gets set semantics.
        var set = new HashSet<TSource>(enumerator.Current);
        while (enumerator.MoveNext())
            set.IntersectWith(enumerator.Current);

        foreach (var item in set)
            yield return item;
    }
}
```

I'm not sure how you did your testing, but even with your implementation I still get much faster results, even in the single collection case. Tested in LINQPad with optimizations enabled:
Your implementation (IntersectAll1), mine (IntersectAll2) and the Aggregate-based control from the question (IntersectAllControl), set up as a LINQPad "C# Program":

```csharp
void Main()
{
    var rng = new Random();
    // 500 arrays of 100 consecutive ints, each starting at a random value in [0, 50)
    var collections = Enumerable.Range(0, 500)
        .Select(i => Enumerable.Range(rng.Next(0, 50), 100).ToArray())
        .Cast<IEnumerable<int>>()
        //.Take(1) // uncomment for the single collection case
        .ToArray();

    const int Iterations = 1000; // 1000000 for single collection

    object result1 = null;
    var timer1 = Stopwatch.StartNew();
    for (var i = 0; i < Iterations; i++)
        result1 = collections.IntersectAll1().ToList();
    timer1.Stop();
    result1.Dump(String.Format("result1: {0}", timer1.Elapsed));

    object result2 = null;
    var timer2 = Stopwatch.StartNew();
    for (var i = 0; i < Iterations; i++)
        result2 = collections.IntersectAll2().ToList();
    timer2.Stop();
    result2.Dump(String.Format("result2: {0}", timer2.Elapsed));

    object controlResult = null;
    var controlTimer = Stopwatch.StartNew();
    for (var i = 0; i < Iterations; i++)
        controlResult = collections.IntersectAllControl().ToList();
    controlTimer.Stop();
    controlResult.Dump(String.Format("controlResult: {0}", controlTimer.Elapsed));
}

// Extension methods must be declared in a top-level, non-generic static class;
// the class name is illustrative.
public static class IntersectExtensions
{
    // Your implementation
    public static IEnumerable<TSource> IntersectAll1<TSource>(this IEnumerable<IEnumerable<TSource>> source)
    {
        using (IEnumerator<IEnumerable<TSource>> sourceIterator = source.GetEnumerator())
        {
            //if there is no first element return an empty sequence
            if (!sourceIterator.MoveNext())
            {
                yield break;
            }
            IEnumerable<TSource> firstSequence = sourceIterator.Current;
            if (sourceIterator.MoveNext())
            {
                //create a hashset of the first sequence
                HashSet<TSource> hashSet = new HashSet<TSource>(firstSequence);
                //intersect the other sequences
                do
                {
                    hashSet.IntersectWith(sourceIterator.Current);
                }
                while (sourceIterator.MoveNext());
                foreach (TSource element in hashSet)
                {
                    yield return element;
                }
            }
            else
            {
                //if there is only one sequence, return it without any intersection
                foreach (TSource element in firstSequence)
                {
                    yield return element;
                }
            }
        }
    }

    // My implementation
    public static IEnumerable<TSource> IntersectAll2<TSource>(
        this IEnumerable<IEnumerable<TSource>> source)
    {
        using (var enumerator = source.GetEnumerator())
        {
            if (!enumerator.MoveNext())
                yield break;
            var set = new HashSet<TSource>(enumerator.Current);
            while (enumerator.MoveNext())
                set.IntersectWith(enumerator.Current);
            foreach (var item in set)
                yield return item;
        }
    }

    // Control: the Aggregate-based version from the question
    public static IEnumerable<TSource> IntersectAllControl<TSource>(
        this IEnumerable<IEnumerable<TSource>> source)
    {
        return source.Aggregate((x, y) => x.Intersect(y));
    }
}
```

Note: The ToList() calls were needed on all invocations to ensure that the intersections are actually generated.

Run on my machine with multiple collections (500 collections, 1000 iterations):
result1: 00:00:01.0210291
result2: 00:00:01.0285069
controlResult: 00:00:03.1512838
And with a single collection (.Take(1) uncommented, Iterations set to 1000000):
result1: 00:00:02.9254441
result2: 00:00:04.8505489
controlResult: 00:00:00.2102433
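The single-collection numbers are consistent with how Enumerable.Aggregate behaves without a seed: it uses the first element as the accumulator and calls the lambda only for the remaining elements, so a one-element source returns the inner collection itself and ToList() merely copies an array. The sketch below checks that by reference; the helper name is hypothetical, and that this accounts for the whole timing gap is an assumption.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class AggregateSingleDemo
{
    // Hypothetical helper: runs the control's Aggregate call and reports
    // whether the result is the very same object as the first input.
    // Like the control, Aggregate throws InvalidOperationException
    // when the array is empty.
    public static bool AggregateReturnsFirstInput(IEnumerable<int>[] collections)
    {
        IEnumerable<IEnumerable<int>> source = collections;
        IEnumerable<int> result = source.Aggregate((x, y) => x.Intersect(y));
        // True only when the lambda never ran, i.e. for a single collection;
        // otherwise the result is a new Intersect iterator.
        return object.ReferenceEquals(result, collections[0]);
    }
}
```

With one collection the check is true; with two it is false, because the result is then a fresh Intersect iterator.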
I have to apologize: I changed my review while writing it, because of some details I had overlooked. Thank you, mjolka, for pointing out my mistake.
For multiple collections, the new implementations will always be a win over the aggregation. In the aggregated approach, each sub-intersection needs to be determined for every collection. In our new impleme
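To make the comparison concrete: for three collections, the control's Aggregate call is equivalent to the nested chain below. When enumerated, each Enumerable.Intersect call builds its own internal set from its second argument, and the surviving elements of the first collection stream through the whole chain. The HashSet approach instead keeps one set and shrinks it in place with IntersectWith. The class and method names are hypothetical; this is an illustration, not a benchmark.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class AggregateExpansion
{
    // What source.Aggregate((x, y) => x.Intersect(y)) builds for exactly
    // three collections: a chain of deferred Intersect iterators.
    public static IEnumerable<int> Nested(
        IEnumerable<int> a, IEnumerable<int> b, IEnumerable<int> c)
    {
        return a.Intersect(b).Intersect(c);
    }

    // The same Aggregate call, for comparison.
    public static IEnumerable<int> Aggregated(
        IEnumerable<int> a, IEnumerable<int> b, IEnumerable<int> c)
    {
        IEnumerable<IEnumerable<int>> source = new[] { a, b, c };
        return source.Aggregate((x, y) => x.Intersect(y));
    }

    // Single-set alternative in the style of IntersectAll: one HashSet,
    // shrunk in place by IntersectWith.
    public static IEnumerable<int> SingleSet(
        IEnumerable<int> a, IEnumerable<int> b, IEnumerable<int> c)
    {
        var set = new HashSet<int>(a);
        set.IntersectWith(b);
        set.IntersectWith(c);
        return set;
    }
}
```

All three produce the same elements; note that HashSet<T> does not guarantee enumeration order, whereas the Intersect chain preserves the order of the first collection.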
Context
StackExchange Code Review Q#61627, answer score: 6