HiveBrain v1.2.0
Get Started
← Back to all entries
patterncsharpMinor

Implementing a fast DBScan in C#

Submitted by: @import:stackexchange-codereview··
0
Viewed 0 times
fastimplementingdbscan

Problem

I tried to implement a DBScan in C# using kd-trees. I followed the implementation from here.

```
// DBSCAN clustering of ScanPoint inputs; neighbourhood lookups go through a KDTree.KDTree.
// NOTE(review): this listing looks damaged by extraction - generic type arguments are
// missing (Func, HashSet, KDTree.KDTree) and two for-loop headers break off mid-condition,
// so it will not compile as shown. Restore it from the original post before reusing.
public class DbscanAlgorithm
{
// Assigned in the constructor; no visible code reads it afterwards.
private readonly Func _metricFunc;

// Stores the supplied metric delegate in _metricFunc.
public DbscanAlgorithm(Func metricFunc)
{
_metricFunc = metricFunc;
}

// Wraps every input point in a DbscanPoint, creates the kd-tree, and hands the result back
// through the out parameter: points with ClusterId > 0, grouped by ClusterId, each group
// projected to an array of its ClusterPoint values.
public void ComputeClusterDbscan(ScanPoint[] allPoints, double epsilon, int minPts, out HashSet clusters)
{
clusters = null;
var allPointsDbscan = allPoints.Select(x => new DbscanPoint(x)).ToArray();

// The argument 2 is presumably the number of dimensions (RegionQuery searches with X and Y) - TODO confirm.
var tree = new KDTree.KDTree(2);
// NOTE(review): the loop header below is cut off. Whatever populated the tree, scanned the
// points and opened the `clusters = ...(` expression appears to be missing from this listing.
for (var i = 0; i (
allPointsDbscan
.Where(x => x.ClusterId > 0)
.GroupBy(x => x.ClusterId)
.Select(x => x.Select(y => y.ClusterPoint).ToArray())
);

return;
}

// Assigns cluster id c to p and then processes neighborPts.
// NOTE(review): the loop header below is cut off, so the declarations of pn and neighborPts2
// are not visible. What remains shows neighborPts being regrown with Union (apparently when a
// nested query yields at least minPts points) and unclassified neighbours receiving c.
private void ExpandCluster(KDTree.KDTree tree, DbscanPoint p, DbscanPoint[] neighborPts, int c, double epsilon, int minPts)
{
p.ClusterId = c;
for (int i = 0; i = minPts)
{
// Union builds a fresh array on every expansion step.
neighborPts = neighborPts.Union(neighborPts2).ToArray();
}
}
if (pn.ClusterId == (int)ClusterIds.UNCLASSIFIED)
pn.ClusterId = c;
}
}

// Queries the tree around p and returns the enumerated results through neighborPts.
// The enumerator is walked twice: once to count, then again after Reset() to fill an
// array of exactly that size.
private void RegionQuery(KDTree.KDTree tree, PointD p, double epsilon, out DbscanPoint[] neighborPts)
{
int totalCount = 0;
// NOTE(review): the literal 10 looks like a cap on the number of neighbours returned - confirm against the KDTree API.
var pIter = tree.NearestNeighbors(new double[] { p.X, p.Y }, 10, epsilon);
// First pass: count the results.
while (pIter.MoveNext())
{
totalCount++;
}
neighborPts = new DbscanPoint[totalCount];
int currCount = 0;
// NOTE(review): relies on this enumerator supporting Reset(); not every enumerator does - confirm.
pIter.Reset();
// Second pass: copy each result into the array.
while (pIter.MoveNext())
{
neighborPts[currCount] = pIter.Current;
currCount++;
}

return;
}
}

/// <summary>
/// Reserved cluster-id values used by the DBSCAN code. Both are non-positive, so the
/// <c>ClusterId &gt; 0</c> filter in ComputeClusterDbscan leaves them out of the result.
/// </summary>
public enum ClusterIds
{
    /// <summary>Reserved value -1.</summary>
    NOISE = -1,

    /// <summary>Reserved value 0; ExpandCluster overwrites it with the current cluster id.</summary>
    UNCLASSIFIED = 0
}

//Point container for Dbscan clustering
public class DbscanPoint
{
public bool IsVisited;
public ScanPoint ClusterPoint;

Solution

_metricFunc is unused, which means it can either be removed, or there's a bug in the program.

The first line in `ComputeClusterDbscan` (`clusters = null;`) is superfluous and can be removed.

The use of out parameters can be avoided by just returning a value.

Methods that can be marked static should be marked static.

In RegionQuery, it is probably faster to iterate over the nearest neighbours just once, like so:

/// <summary>
/// Gathers, in a single pass, every point the tree's nearest-neighbour search yields around <paramref name="p"/>.
/// </summary>
/// <param name="tree">kd-tree holding the DBSCAN points.</param>
/// <param name="p">Query location; its X and Y form the search coordinates.</param>
/// <param name="epsilon">Forwarded as the third argument of NearestNeighbors.</param>
/// <returns>The enumerated neighbours, in enumeration order.</returns>
private static DbscanPoint[] RegionQuery(KDTree<DbscanPoint> tree, PointD p, double epsilon)
{
    // A growable list removes the need for a separate counting pass over the enumerator.
    var neighbors = new List<DbscanPoint>();
    // NOTE(review): the literal 10 looks like a cap on the number of neighbours returned - confirm against the KDTree API.
    var e = tree.NearestNeighbors(new[] { p.X, p.Y }, 10, epsilon);
    while (e.MoveNext())
    {
        neighbors.Add(e.Current);
    }

    return neighbors.ToArray();
}


I believe the bottleneck in your program is this line in ExpandCluster:

neighborPts = neighborPts.Union(neighborPts2).ToArray();


Try something like this instead:

/// <summary>
/// Grows cluster <paramref name="c"/> outward from <paramref name="p"/> using a work queue
/// instead of repeatedly rebuilding the neighbour array.
/// </summary>
/// <param name="tree">kd-tree passed on to RegionQuery.</param>
/// <param name="p">Starting point; it receives the cluster id immediately.</param>
/// <param name="neighborPts">Initial neighbours of <paramref name="p"/>; they seed the queue.</param>
/// <param name="c">Cluster id to assign.</param>
/// <param name="epsilon">Search threshold passed on to RegionQuery.</param>
/// <param name="minPts">Minimum neighbour count for a point's neighbours to be enqueued.</param>
private static void ExpandCluster(KDTree<DbscanPoint> tree, DbscanPoint p, DbscanPoint[] neighborPts, int c, double epsilon, int minPts)
{
    p.ClusterId = c;

    var queue = new Queue<DbscanPoint>(neighborPts);
    while (queue.Count > 0)
    {
        var point = queue.Dequeue();
        // Only unclassified points are claimed.
        // NOTE(review): a point already carrying the NOISE id keeps it here - confirm that is intended.
        if (point.ClusterId == (int)ClusterIds.UNCLASSIFIED)
        {
            point.ClusterId = c;
        }

        // A visited point has already had its neighbourhood queried; do not query it again.
        if (point.IsVisited)
        {
            continue;
        }

        point.IsVisited = true;
        var neighbors = RegionQuery(tree, point.ClusterPoint.point, epsilon);
        if (neighbors.Length >= minPts)
        {
            // Only unvisited neighbours are enqueued; a point may still be queued more than once before it is visited.
            foreach (var neighbor in neighbors.Where(neighbor => !neighbor.IsVisited))
            {
                queue.Enqueue(neighbor);
            }
        }
    }
}

Code Snippets

/// <summary>
/// Runs one nearest-neighbour search around <paramref name="p"/> and returns everything it yields.
/// </summary>
/// <param name="tree">kd-tree to search.</param>
/// <param name="p">Point whose X and Y are used as the search coordinates.</param>
/// <param name="epsilon">Third argument handed to NearestNeighbors.</param>
/// <returns>Array of the points produced by the search, in the order they were produced.</returns>
private static DbscanPoint[] RegionQuery(KDTree<DbscanPoint> tree, PointD p, double epsilon)
{
    var found = new List<DbscanPoint>();

    // Drain the search enumerator exactly once.
    for (var cursor = tree.NearestNeighbors(new[] { p.X, p.Y }, 10, epsilon); cursor.MoveNext();)
    {
        found.Add(cursor.Current);
    }

    return found.ToArray();
}
// Line from the question's ExpandCluster that the answer singles out as the likely bottleneck:
neighborPts = neighborPts.Union(neighborPts2).ToArray();
/// <summary>
/// Labels <paramref name="p"/> with cluster <paramref name="c"/> and spreads that label
/// through a queue seeded with <paramref name="neighborPts"/>.
/// </summary>
/// <param name="tree">kd-tree forwarded to RegionQuery.</param>
/// <param name="p">Point that starts the cluster.</param>
/// <param name="neighborPts">Points initially placed on the queue.</param>
/// <param name="c">Cluster id written to claimed points.</param>
/// <param name="epsilon">Search threshold forwarded to RegionQuery.</param>
/// <param name="minPts">Neighbour count a point needs before its neighbours are queued.</param>
private static void ExpandCluster(KDTree<DbscanPoint> tree, DbscanPoint p, DbscanPoint[] neighborPts, int c, double epsilon, int minPts)
{
    p.ClusterId = c;

    var pending = new Queue<DbscanPoint>(neighborPts);
    while (pending.Count != 0)
    {
        var current = pending.Dequeue();

        // Claim the point only if nothing has labelled it yet.
        if (current.ClusterId == (int)ClusterIds.UNCLASSIFIED)
        {
            current.ClusterId = c;
        }

        if (!current.IsVisited)
        {
            current.IsVisited = true;

            var reachable = RegionQuery(tree, current.ClusterPoint.point, epsilon);
            if (reachable.Length < minPts)
            {
                // Too few neighbours: nothing further is queued from this point.
                continue;
            }

            foreach (var candidate in reachable)
            {
                if (!candidate.IsVisited)
                {
                    pending.Enqueue(candidate);
                }
            }
        }
    }
}

Context

StackExchange Code Review Q#108965, answer score: 8

Revisions (0)

No revisions yet.