
K-means clustering in Rust

Submitted by: @import:stackexchange-codereview

Problem

I've implemented K-means clustering in Rust. It's my second Rust project (my first one is here: "Randomly selecting an adjective and noun, combining them into a message").

I would like advice on whether I am doing things idiomatically, and on any sensible optimisations I could make. I would also appreciate advice on code style: I care a lot about having code that is readable and nicely formatted, but I'm still getting the hang of it in Rust.

A brief explanation of the K-means algorithm:

  • You have a set of data you wish to partition into a known number of groups, a.k.a. clusters.
  • The mean of all the data belonging to a cluster is the cluster's centroid. You decide which cluster a datum belongs to by selecting the cluster whose centroid is closest, i.e. the one at the smallest Euclidean distance from the data point.
  • The algorithm works in two iterative steps. First, we initialise the cluster centroids, then:
      • Assign all the data to their nearest centroid.
      • Update the cluster centroids so that they are the mean of all the points assigned to them.
      • Repeat until a local minimum is found, i.e. until the assignments stop changing.
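The loop described above (often called Lloyd's algorithm) can be sketched in plain Rust with no external crates. This is a minimal illustration under stated assumptions, not the asker's implementation: the Point type and the squared_distance, nearest and kmeans functions are hypothetical names, the centroids are naively seeded with the first k points (random or k-means++ seeding are the usual alternatives), and max_iters is an arbitrary safety cap.

```rust
/// Hypothetical 2-D point type for this sketch (not the asker's DataPoint).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Point {
    pub x: f64,
    pub y: f64,
}

/// Squared Euclidean distance; the square root is unnecessary for comparisons.
fn squared_distance(a: Point, b: Point) -> f64 {
    (a.x - b.x).powi(2) + (a.y - b.y).powi(2)
}

/// Index of the centroid nearest to `p` (assumes `centroids` is non-empty).
fn nearest(p: Point, centroids: &[Point]) -> usize {
    let mut best = 0;
    let mut best_dist = f64::INFINITY;
    for (i, &c) in centroids.iter().enumerate() {
        let d = squared_distance(p, c);
        if d < best_dist {
            best = i;
            best_dist = d;
        }
    }
    best
}

/// Hypothetical k-means driver. Assumes 1 <= k <= data.len() and seeds the
/// centroids with the first `k` points (a simplification for the sketch).
pub fn kmeans(data: &[Point], k: usize, max_iters: usize) -> (Vec<Point>, Vec<usize>) {
    let mut centroids = data[..k].to_vec();
    // usize::MAX marks "not yet assigned", so the first pass always counts as a change.
    let mut assignments = vec![usize::MAX; data.len()];

    for _ in 0..max_iters {
        // Step 1: assign every point to its nearest centroid.
        let new_assignments: Vec<usize> =
            data.iter().map(|&p| nearest(p, &centroids)).collect();
        if new_assignments == assignments {
            break; // no point changed cluster: converged to a local minimum
        }
        assignments = new_assignments;

        // Step 2: move each centroid to the mean of the points assigned to it.
        for (j, centroid) in centroids.iter_mut().enumerate() {
            let (sum_x, sum_y, count) = data
                .iter()
                .zip(&assignments)
                .filter(|&(_, &a)| a == j)
                .fold((0.0_f64, 0.0_f64, 0usize), |(sx, sy, n), (p, _)| {
                    (sx + p.x, sy + p.y, n + 1)
                });
            // An empty cluster keeps its previous centroid.
            if count > 0 {
                *centroid = Point { x: sum_x / count as f64, y: sum_y / count as f64 };
            }
        }
    }
    (centroids, assignments)
}
```

With two well-separated pairs of points, for example (0, 0), (0, 1), (10, 10) and (10, 11) with k = 2, the sketch converges to centroids (0, 0.5) and (10, 10.5).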



Here are the contents of lib.rs, now including my unit tests (I realise I could probably have more tests D:):

```
use std::path::Path;

extern crate csv;
extern crate rustc_serialize;

/// Store one data point's (or one cluster centroid's) x and y co-ordinates
#[derive(Clone, Debug, RustcDecodable)]
pub struct DataPoint {
    pub x: f64,
    pub y: f64,
}

/// Structure for holding data point's assignments to clusters
#[derive(Clone, Debug)]
pub struct Assignment<'a> {
    data_point: &'a DataPoint,
    cluster_ind: usize,
}

pub fn read_data<P>(file_path: P) -> Vec<DataPoint> where P: AsRef<Path> {
    let mut data = vec![];
    let mut reader = csv::Reader::from_file(file_path).unwrap();
    for data_point in reader.decode() {
        let data_point: DataPoint = data_point.unwrap();
        data.push(data_point);
    }
    data
}

pub fn squared_euclidean_distance(point_a: &DataPoint,
```

Solution


  • DataPoint is a small enough structure that it might as well be Copy.
  • squared_euclidean_distance can be an instance method on DataPoint.
  • where clauses should be placed on a separate line. This makes it easier to see what constraints a function has.
  • It's rarely necessary to declare a container (like a Vec) and then fill it up manually. Use iterator adapters instead; map and collect are the big two. Think about iterator adapters any time you use a for loop that isn't just for side effects.
  • Prefer expect over unwrap. For a Result, expect prints your message alongside the underlying error; for both Option and Result, the custom message makes it easier for you or the user to track down the error.
  • Reading from a file should have better error handling than panicking. Does tying the library to CSV even make sense? Maybe this really belongs in the executable.
  • Never take &Vec or &String. &[T] or &str is broader and gives you everything you need.
  • There are no spaces inside the closure argument list delimiters (|).
  • Use braces for multi-line closures.
  • It doesn't make sense to return 0 for the minimum value of an empty slice (or zero-length iterator). This is what Option is made for.
  • A few places create a Vec from an iterator just to iterate over it again. It's more efficient to never collect.
  • get is pretty useless in method names. Dropping it loses nothing.
  • When the first thing a function does is convert its argument to an iterator, you might as well accept anything that can be made into an iterator (IntoIterator).
  • Include spaces around a struct literal's { and before its }.
  • There's no reason to include parentheses in Vec.
  • When you have a multi-line function argument list, place the { on a new line.
  • Use iterator adapters instead of cloning a slice or Vec and then retaining values.
  • There's no reason to return a vector when you just convert it into an iterator. I returned a Box out of laziness.
  • The code computes points_in_cluster twice. It's probably more efficient to combine the calculations.
  • There's no reason to make a mutable tuple and then convert it into a struct; just mutate the struct directly.
  • Instead of a for loop, use fold.
  • Implement addition for DataPoint to make the code simpler to understand.
  • Does it make sense to return (0, 0) for the sum of an empty slice?
  • Use iter_mut and enumerate instead of poking at the slice's values via indexing (iterators can avoid a bounds check on every access).
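One way to follow the error-handling advice above is to return a Result from the reading code and let the executable decide whether to panic. The sketch below deliberately avoids the csv crate and parses plain "x,y" lines (no header, no quoting) with only the standard library; the parse_points function and the ParseError type are hypothetical names, and real CSV input would still want a proper CSV parser. It also takes &str rather than &String.

```rust
use std::fmt;
use std::num::ParseFloatError;

/// Same shape as DataPoint above, redeclared so this sketch compiles on its own.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DataPoint {
    pub x: f64,
    pub y: f64,
}

/// Hypothetical error type recording which (1-based) line failed and why.
#[derive(Debug, PartialEq)]
pub enum ParseError {
    WrongFieldCount { line: usize },
    BadNumber { line: usize, source: ParseFloatError },
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            ParseError::WrongFieldCount { line } => write!(f, "line {}: expected two fields", line),
            ParseError::BadNumber { line, source } => write!(f, "line {}: {}", line, source),
        }
    }
}

/// Hypothetical parser for plain "x,y" lines; blank lines are skipped.
/// Reports the first bad line as an Err instead of panicking.
pub fn parse_points(input: &str) -> Result<Vec<DataPoint>, ParseError> {
    input
        .lines()
        .enumerate()
        .filter(|(_, l)| !l.trim().is_empty())
        .map(|(i, l)| -> Result<DataPoint, ParseError> {
            let line = i + 1;
            let fields: Vec<&str> = l.split(',').map(str::trim).collect();
            if fields.len() != 2 {
                return Err(ParseError::WrongFieldCount { line });
            }
            let parse = |s: &str| {
                s.parse::<f64>()
                    .map_err(|source| ParseError::BadNumber { line, source })
            };
            Ok(DataPoint { x: parse(fields[0])?, y: parse(fields[1])? })
        })
        // Collecting an iterator of Results into a Result stops at the first Err.
        .collect()
}
```

A binary can then match on the error, or call .expect("could not read input") where a panic really is acceptable; expect on a Result includes the error's Debug output in the panic message.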

```
use std::path::Path;

extern crate csv;
extern crate rustc_serialize;

/// Store one data point's (or one cluster centroid's) x and y co-ordinates
#[derive(Copy, Clone, Debug, RustcDecodable)]
pub struct DataPoint {
    pub x: f64,
    pub y: f64,
}

impl DataPoint {
    fn zero() -> DataPoint {
        DataPoint {
            x: 0.0,
            y: 0.0,
        }
    }

    pub fn squared_euclidean_distance(&self, other: &DataPoint) -> f64 {
        (other.x - self.x).powi(2) + (other.y - self.y).powi(2)
    }
}

impl std::ops::Add for DataPoint {
    type Output = DataPoint;

    fn add(self, other: DataPoint) -> DataPoint {
        DataPoint {
            x: self.x + other.x,
            y: self.y + other.y,
        }
    }
}

/// Structure for holding data point's assignments to clusters
#[derive(Clone, Debug)]
pub struct Assignment<'a> {
    data_point: &'a DataPoint,
    cluster_ind: usize,
}

pub fn read_data<P>(file_path: P) -> Vec<DataPoint>
    where P: AsRef<Path>
{
    let mut reader = csv::Reader::from_file(file_path).unwrap();
    reader.decode().map(|point| point.unwrap()).collect()
}

pub fn index_of_min_val<I>(floats: I) -> Option<usize>
    where I: IntoIterator<Item = f64>,
{
    let mut iter = floats.into_iter().enumerate();

    iter.next().map(|(i, min)| {
        iter.fold((i, min), |(min_i, min_val), (i, val)| {
            if val < min_val {
                (i, val)
            } else {
                (min_i, min_val)
            }
        }).0
    })
}

/// Assign points to clusters
fn expectation<'a>(data: &'a [DataPoint],
                   cluster_centroids: &[DataPoint]) -> Vec<Assignment<'a>>
{
    data.iter().map(|point| {
        let distances = cluster_centroids.iter().map(|cluster| point.squared_euclidean_distance(cluster));
        let index = index_of_min_val(distances).expect("No minimum value found");
        Assignment { data_point: point, cluster_ind: index }
    }).collect()
}

pub fn count_assignments(assignments: &[Assignment],
                         cluster_ind: usize) -> usize
{
    points_in_cluster(assignments, cluster_ind).count()
}

pub fn points_in_cluster<'a>(assignments: &'a [Assignment],
                             expected_cluster_ind: usize) -> Box<Iterator<Item = Assignment<'a>> + 'a>
{
    let i = assignments.into_iter()
        .cloned()
        .filter(move |&Assignment { cluster_ind, .. }| expected_cluster_ind == cluster_ind);
    Box::new(i)
}

pub fn sum_assigned_values(assignments: &[Assignment],
                           cluster_ind: usize) -> DataPoint
{
    points_in_cluster(assignments, cluster_ind)
        .into_iter()
        .fold(DataPoint::zero(), |acc, point| acc + *point.data_point)
}

/// Update cluster centres
fn maximisation(cluster_centroids: &mut [DataPoint],
```
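The maximisation function is cut off in the code above. To illustrate the "combine the calculations" and "iter_mut and enumerate" points, here is a hypothetical single-pass update step: it builds a (sum, count) pair per cluster in one walk over the assignments, instead of filtering once for count_assignments and again for sum_assigned_values. It redeclares minimal DataPoint and Assignment types so it compiles alone; the update_centroids name is made up, and leaving an empty cluster's centroid unchanged is one design choice among several (re-seeding it is another).

```rust
use std::ops::Add;

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DataPoint {
    pub x: f64,
    pub y: f64,
}

impl Add for DataPoint {
    type Output = DataPoint;

    fn add(self, other: DataPoint) -> DataPoint {
        DataPoint { x: self.x + other.x, y: self.y + other.y }
    }
}

#[derive(Clone, Debug)]
pub struct Assignment<'a> {
    pub data_point: &'a DataPoint,
    pub cluster_ind: usize,
}

/// Hypothetical single-pass update step. Panics if an assignment's
/// cluster_ind is out of range for `cluster_centroids`.
pub fn update_centroids(cluster_centroids: &mut [DataPoint], assignments: &[Assignment]) {
    let zero = DataPoint { x: 0.0, y: 0.0 };
    // One (running sum, point count) pair per cluster.
    let mut totals = vec![(zero, 0usize); cluster_centroids.len()];

    // Single pass over the assignments.
    for a in assignments {
        let (sum, count) = &mut totals[a.cluster_ind];
        *sum = *sum + *a.data_point;
        *count += 1;
    }

    // Each centroid becomes its cluster's mean; empty clusters keep their old centroid.
    for (centroid, &(sum, count)) in cluster_centroids.iter_mut().zip(&totals) {
        if count > 0 {
            let n = count as f64;
            *centroid = DataPoint { x: sum.x / n, y: sum.y / n };
        }
    }
}
```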

Code Snippets

```
use std::path::Path;

extern crate csv;
extern crate rustc_serialize;

/// Store one data point's (or one cluster centroid's) x and y co-ordinates
#[derive(Copy, Clone, Debug, PartialEq, RustcDecodable)]
pub struct DataPoint {
    pub x: f64,
    pub y: f64,
}

impl DataPoint {
    fn zero() -> DataPoint {
        DataPoint {
            x: 0.0,
            y: 0.0,
        }
    }

    pub fn squared_euclidean_distance(&self, other: &DataPoint) -> f64 {
        (other.x - self.x).powi(2) + (other.y - self.y).powi(2)
    }
}

impl std::ops::Add for DataPoint {
    type Output = DataPoint;

    fn add(self, other: DataPoint) -> DataPoint {
        DataPoint {
            x: self.x + other.x,
            y: self.y + other.y,
        }
    }
}

/// Structure for holding data point's assignments to clusters
#[derive(Clone, Debug)]
pub struct Assignment<'a> {
    data_point: &'a DataPoint,
    cluster_ind: usize,
}

pub fn read_data<P>(file_path: P) -> Vec<DataPoint>
    where P: AsRef<Path>
{
    let mut reader = csv::Reader::from_file(file_path).unwrap();
    reader.decode().map(|point| point.unwrap()).collect()
}

pub fn index_of_min_val<I>(floats: I) -> Option<usize>
    where I: IntoIterator<Item = f64>,
{
    let mut iter = floats.into_iter().enumerate();

    iter.next().map(|(i, min)| {
        iter.fold((i, min), |(min_i, min_val), (i, val)| {
            if val < min_val {
                (i, val)
            } else {
                (min_i, min_val)
            }
        }).0
    })
}


/// Assign points to clusters
fn expectation<'a>(data: &'a [DataPoint],
                   cluster_centroids: &[DataPoint]) -> Vec<Assignment<'a>>
{
    data.iter().map(|point| {
        let distances = cluster_centroids.iter().map(|cluster| point.squared_euclidean_distance(cluster));
        let index = index_of_min_val(distances).expect("No minimum value found");
        Assignment { data_point: point, cluster_ind: index }
    }).collect()
}

pub fn count_assignments(assignments: &[Assignment],
                         cluster_ind: usize) -> usize
{
    points_in_cluster(assignments, cluster_ind).count()
}


pub fn points_in_cluster<'a>(assignments: &'a [Assignment],
                             expected_cluster_ind: usize) -> Box<Iterator<Item = Assignment<'a>> + 'a>
{
    let i = assignments.into_iter()
        .cloned()
        .filter(move |&Assignment { cluster_ind, .. }| expected_cluster_ind == cluster_ind);
    Box::new(i)
}

pub fn sum_assigned_values(assignments: &[Assignment],
                           cluster_ind: usize) -> DataPoint
{
    points_in_cluster(assignments, cluster_ind)
        .into_iter()
        .fold(DataPoint::zero(), |acc, point| acc + *point.data_point)
}


/// Update cluster centres
fn maximisation(cluster_centroids: &mut [DataPoint],
                assignments: &[Assignment])
{
    for (i, centroid) in cluster_centroids.iter_mut().enumerate() {
        let num_points = count_assignments(&assignments, i);
        let sum_points =
```

```
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_squared_euclidean_distance_simple_case() {
        let origin = DataPoint { x: 0.0, y: 0.0 };
        let point = DataPoint { x: 1.0, y: 1.0 };
        assert_eq!(2.0, origin.squared_euclidean_distance(&point))
    }

    #[test]
    fn test_squared_euclidean_distance_gives_0_for_same_point() {
        let point_a = DataPoint { x: -999.3, y: 10.5 };
        assert_eq!(0.0, point_a.squared_euclidean_distance(&point_a));
    }

    #[test]
    fn test_index_of_min_val() {
        let floats = vec![0.0_f64, 1.0_f64, 3.0_f64, -5.5_f64];
        assert_eq!(Some(3), index_of_min_val(floats))
    }

    #[test]
    fn test_count_assignments_returns_0_when_no_occurences() {
        let dp = DataPoint { x: 0.0, y: 0.0 };
        let assignments = [Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 1 },
                           Assignment { data_point: &dp, cluster_ind: 5 },
                           Assignment { data_point: &dp, cluster_ind: 0 }];
        assert_eq!(0, count_assignments(&assignments, 4))
    }

    #[test]
    fn test_count_assignments_returns_3_when_3_occurences() {
        let dp = DataPoint { x: 0.0, y: 0.0 };
        let assignments = [Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 1 },
                           Assignment { data_point: &dp, cluster_ind: 5 },
                           Assignment { data_point: &dp, cluster_ind: 0 }];
        assert_eq!(3, count_assignments(&assignments, 0));
    }

    #[test]
    fn test_sum_assigned_values_returns_0_when_none_assigned() {
        let dp = DataPoint { x: 5.0, y: 5.0 };
        let assignments = [Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 1 },
                           Assignment { data_point: &dp, cluster_ind: 5 },
                           Assignment { data_point: &dp, cluster_ind: 0 }];
        assert_eq!(DataPoint { x: 0.0, y: 0.0 }, sum_assigned_values(&assignments, 2))
    }

    #[test]
    fn test_sum_assigned_values_returns_correctly_when_some_assigned() {
        let dp = DataPoint { x: 1.0, y: 1.0 };
        let assignments = [Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 1 },
                           Assignment { data_point: &dp, cluster_ind: 5 },
                           Assignment { data_point: &dp, cluster_ind: 0 }];
        assert_eq!(DataPoint { x: 3.0, y: 3.0 }, sum_assigned_values(&assignments, 0));
    }
}
```
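Since the question mentions wanting more tests, two edge cases the tests above don't cover are an empty input (where returning an Option pays off) and ties. The sketch below copies index_of_min_val from the answer so it compiles on its own; the test module and test names are hypothetical.

```rust
pub fn index_of_min_val<I>(floats: I) -> Option<usize>
    where I: IntoIterator<Item = f64>,
{
    let mut iter = floats.into_iter().enumerate();

    iter.next().map(|(i, min)| {
        iter.fold((i, min), |(min_i, min_val), (i, val)| {
            if val < min_val { (i, val) } else { (min_i, min_val) }
        }).0
    })
}

#[cfg(test)]
mod edge_case_tests {
    use super::*;

    #[test]
    fn empty_input_gives_none() {
        assert_eq!(None, index_of_min_val(Vec::<f64>::new()));
    }

    #[test]
    fn ties_keep_the_first_index() {
        // `val < min_val` is strict, so a later equal value never replaces an earlier one.
        assert_eq!(Some(1), index_of_min_val(vec![3.0, 1.0, 2.0, 1.0]));
    }

    #[test]
    fn single_element_is_its_own_minimum() {
        assert_eq!(Some(0), index_of_min_val(vec![42.0]));
    }
}
```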

Context

StackExchange Code Review Q#126303, answer score: 4
