patterncsharpMinor
Implementing a fast DBScan in C#
Viewed 0 times
fastimplementingdbscan
Problem
I tried to implement a DBScan in C# using kd-trees. I followed the implementation from here.
```
/// <summary>
/// Clusters <c>ScanPoint</c> values with DBSCAN, using a KD-tree to look up each point's neighbours.
/// </summary>
/// <remarks>
/// NOTE(review): this listing appears to have lost every "&lt;...&gt;" span during extraction
/// (generic type arguments, and the loop text that sat between a '&lt;' and a later '&gt;'),
/// so it does not compile as shown. Comments below describe only the text that survived.
/// </remarks>
public class DbscanAlgorithm
{
// Presumably a distance metric supplied by the caller; no visible code reads it after the constructor stores it.
private readonly Func _metricFunc;
/// <summary>Stores <paramref name="metricFunc"/> in a read-only field.</summary>
public DbscanAlgorithm(Func metricFunc)
{
_metricFunc = metricFunc;
}
/// <summary>
/// Runs the clustering over <paramref name="allPoints"/> and hands the result back through
/// <paramref name="clusters"/>.
/// </summary>
/// <remarks>
/// <paramref name="epsilon"/> and <paramref name="minPts"/> are not used in the surviving text of this
/// method; ExpandCluster and RegionQuery take parameters of the same names.
/// </remarks>
public void ComputeClusterDbscan(ScanPoint[] allPoints, double epsilon, int minPts, out HashSet clusters)
{
// The out parameter is initialised to null up front.
clusters = null;
// Wrap every input point in a DbscanPoint.
var allPointsDbscan = allPoints.Select(x => new DbscanPoint(x)).ToArray();
// KD-tree constructed with argument 2 — presumably the number of dimensions (X, Y); confirm against the KDTree API.
var tree = new KDTree.KDTree(2);
// NOTE(review): the loop condition, the loop body and everything up to the opening '(' are missing from this listing.
for (var i = 0; i (
// Keep points whose ClusterId is positive, group them by ClusterId, and turn each group
// into an array of the underlying ClusterPoint values.
allPointsDbscan
.Where(x => x.ClusterId > 0)
.GroupBy(x => x.ClusterId)
.Select(x => x.Select(y => y.ClusterPoint).ToArray())
);
return;
}
/// <summary>
/// Assigns <paramref name="p"/> to cluster <paramref name="c"/> and processes
/// <paramref name="neighborPts"/>; newly found neighbours are merged into that array with Union.
/// </summary>
private void ExpandCluster(KDTree.KDTree tree, DbscanPoint p, DbscanPoint[] neighborPts, int c, double epsilon, int minPts)
{
p.ClusterId = c;
// NOTE(review): the loop header and the start of the loop body are truncated in this listing;
// the declarations of pn and neighborPts2 are not visible.
for (int i = 0; i = minPts)
{
// Union(...).ToArray() produces a brand-new array every time this line runs.
neighborPts = neighborPts.Union(neighborPts2).ToArray();
}
}
// A point that is still UNCLASSIFIED is assigned to cluster c.
if (pn.ClusterId == (int)ClusterIds.UNCLASSIFIED)
pn.ClusterId = c;
}
}
/// <summary>
/// Queries the KD-tree around <paramref name="p"/> and returns the results through
/// <paramref name="neighborPts"/>.
/// </summary>
private void RegionQuery(KDTree.KDTree tree, PointD p, double epsilon, out DbscanPoint[] neighborPts)
{
int totalCount = 0;
// Presumably 10 is the maximum number of neighbours returned and epsilon the search radius — TODO confirm
// against the KDTree API; if 10 is a cap, neighbourhoods larger than 10 points would be truncated.
var pIter = tree.NearestNeighbors(new double[] { p.X, p.Y }, 10, epsilon);
// First pass: count the results so the array can be sized exactly.
while (pIter.MoveNext())
{
totalCount++;
}
neighborPts = new DbscanPoint[totalCount];
int currCount = 0;
// Rewind the iterator, then make a second pass that copies each result into the array.
pIter.Reset();
while (pIter.MoveNext())
{
neighborPts[currCount] = pIter.Current;
currCount++;
}
return;
}
}
/// <summary>
/// Reserved cluster-id values used by the DBSCAN code. Neither value is positive, so points
/// carrying them are excluded by a "ClusterId &gt; 0" filter.
/// </summary>
public enum ClusterIds : int
{
    /// <summary>Marker value -1; presumably for points that belong to no cluster.</summary>
    NOISE = -1,

    /// <summary>Marker value 0; a point that has not been assigned to a cluster yet.</summary>
    UNCLASSIFIED = 0
}
//Point container for Dbscan clustering
public class DbscanPoint
{
public bool IsVisited;
public ScanPoint ClusterPoint;
```
/// <summary>
/// Clusters <c>ScanPoint</c> values with DBSCAN, using a KD-tree to look up each point's neighbours.
/// </summary>
/// <remarks>
/// NOTE(review): this listing appears to have lost every "&lt;...&gt;" span during extraction
/// (generic type arguments, and the loop text that sat between a '&lt;' and a later '&gt;'),
/// so it does not compile as shown. Comments below describe only the text that survived.
/// </remarks>
public class DbscanAlgorithm
{
// Presumably a distance metric supplied by the caller; no visible code reads it after the constructor stores it.
private readonly Func _metricFunc;
/// <summary>Stores <paramref name="metricFunc"/> in a read-only field.</summary>
public DbscanAlgorithm(Func metricFunc)
{
_metricFunc = metricFunc;
}
/// <summary>
/// Runs the clustering over <paramref name="allPoints"/> and hands the result back through
/// <paramref name="clusters"/>.
/// </summary>
/// <remarks>
/// <paramref name="epsilon"/> and <paramref name="minPts"/> are not used in the surviving text of this
/// method; ExpandCluster and RegionQuery take parameters of the same names.
/// </remarks>
public void ComputeClusterDbscan(ScanPoint[] allPoints, double epsilon, int minPts, out HashSet clusters)
{
// The out parameter is initialised to null up front.
clusters = null;
// Wrap every input point in a DbscanPoint.
var allPointsDbscan = allPoints.Select(x => new DbscanPoint(x)).ToArray();
// KD-tree constructed with argument 2 — presumably the number of dimensions (X, Y); confirm against the KDTree API.
var tree = new KDTree.KDTree(2);
// NOTE(review): the loop condition, the loop body and everything up to the opening '(' are missing from this listing.
for (var i = 0; i (
// Keep points whose ClusterId is positive, group them by ClusterId, and turn each group
// into an array of the underlying ClusterPoint values.
allPointsDbscan
.Where(x => x.ClusterId > 0)
.GroupBy(x => x.ClusterId)
.Select(x => x.Select(y => y.ClusterPoint).ToArray())
);
return;
}
/// <summary>
/// Assigns <paramref name="p"/> to cluster <paramref name="c"/> and processes
/// <paramref name="neighborPts"/>; newly found neighbours are merged into that array with Union.
/// </summary>
private void ExpandCluster(KDTree.KDTree tree, DbscanPoint p, DbscanPoint[] neighborPts, int c, double epsilon, int minPts)
{
p.ClusterId = c;
// NOTE(review): the loop header and the start of the loop body are truncated in this listing;
// the declarations of pn and neighborPts2 are not visible.
for (int i = 0; i = minPts)
{
// Union(...).ToArray() produces a brand-new array every time this line runs.
neighborPts = neighborPts.Union(neighborPts2).ToArray();
}
}
// A point that is still UNCLASSIFIED is assigned to cluster c.
if (pn.ClusterId == (int)ClusterIds.UNCLASSIFIED)
pn.ClusterId = c;
}
}
/// <summary>
/// Queries the KD-tree around <paramref name="p"/> and returns the results through
/// <paramref name="neighborPts"/>.
/// </summary>
private void RegionQuery(KDTree.KDTree tree, PointD p, double epsilon, out DbscanPoint[] neighborPts)
{
int totalCount = 0;
// Presumably 10 is the maximum number of neighbours returned and epsilon the search radius — TODO confirm
// against the KDTree API; if 10 is a cap, neighbourhoods larger than 10 points would be truncated.
var pIter = tree.NearestNeighbors(new double[] { p.X, p.Y }, 10, epsilon);
// First pass: count the results so the array can be sized exactly.
while (pIter.MoveNext())
{
totalCount++;
}
neighborPts = new DbscanPoint[totalCount];
int currCount = 0;
// Rewind the iterator, then make a second pass that copies each result into the array.
pIter.Reset();
while (pIter.MoveNext())
{
neighborPts[currCount] = pIter.Current;
currCount++;
}
return;
}
}
/// <summary>
/// Reserved cluster-id values used by the DBSCAN code. Neither value is positive, so points
/// carrying them are excluded by a "ClusterId &gt; 0" filter.
/// </summary>
public enum ClusterIds : int
{
    /// <summary>Marker value -1; presumably for points that belong to no cluster.</summary>
    NOISE = -1,

    /// <summary>Marker value 0; a point that has not been assigned to a cluster yet.</summary>
    UNCLASSIFIED = 0
}
//Point container for Dbscan clustering
public class DbscanPoint
{
public bool IsVisited;
public ScanPoint ClusterPoint;
Solution
- `_metricFunc` is unused, which means that either it can be removed or there is a bug in the program.
- The first line in `ComputeClusterDbscan`, `clusters = null;`, is superfluous and can be removed.
- The use of `out` parameters can be avoided by just returning a value.
- Methods that can be marked `static` should be marked `static`.
- In `RegionQuery`, it is probably faster to iterate over the nearest neighbours just once, like so:

private static DbscanPoint[] RegionQuery(KDTree<DbscanPoint> tree, PointD p, double epsilon)
{
var neighbors = new List<DbscanPoint>();
var e = tree.NearestNeighbors(new[] { p.X, p.Y }, 10, epsilon);
while (e.MoveNext())
{
neighbors.Add(e.Current);
}
return neighbors.ToArray();
}

I believe the bottleneck in your program is this line in `ExpandCluster`, because every time it runs it rebuilds the entire neighbour array:

neighborPts = neighborPts.Union(neighborPts2).ToArray();

Try something like this instead:

private static void ExpandCluster(KDTree<DbscanPoint> tree, DbscanPoint p, DbscanPoint[] neighborPts, int c, double epsilon, int minPts)
{
p.ClusterId = c;
var queue = new Queue<DbscanPoint>(neighborPts);
while (queue.Count > 0)
{
var point = queue.Dequeue();
if (point.ClusterId == (int)ClusterIds.UNCLASSIFIED)
{
point.ClusterId = c;
}
if (point.IsVisited)
{
continue;
}
point.IsVisited = true;
var neighbors = RegionQuery(tree, point.ClusterPoint.point, epsilon);
if (neighbors.Length >= minPts)
{
foreach (var neighbor in neighbors.Where(neighbor => !neighbor.IsVisited))
{
queue.Enqueue(neighbor);
}
}
}
}

Code Snippets
private static DbscanPoint[] RegionQuery(KDTree<DbscanPoint> tree, PointD p, double epsilon)
{
var neighbors = new List<DbscanPoint>();
var e = tree.NearestNeighbors(new[] { p.X, p.Y }, 10, epsilon);
while (e.MoveNext())
{
neighbors.Add(e.Current);
}
return neighbors.ToArray();
}

neighborPts = neighborPts.Union(neighborPts2).ToArray();

private static void ExpandCluster(KDTree<DbscanPoint> tree, DbscanPoint p, DbscanPoint[] neighborPts, int c, double epsilon, int minPts)
{
p.ClusterId = c;
var queue = new Queue<DbscanPoint>(neighborPts);
while (queue.Count > 0)
{
var point = queue.Dequeue();
if (point.ClusterId == (int)ClusterIds.UNCLASSIFIED)
{
point.ClusterId = c;
}
if (point.IsVisited)
{
continue;
}
point.IsVisited = true;
var neighbors = RegionQuery(tree, point.ClusterPoint.point, epsilon);
if (neighbors.Length >= minPts)
{
foreach (var neighbor in neighbors.Where(neighbor => !neighbor.IsVisited))
{
queue.Enqueue(neighbor);
}
}
}
}

Context
StackExchange Code Review Q#108965, answer score: 8
Revisions (0)
No revisions yet.