K-means clustering in Rust
Problem
I've implemented K-means clustering in Rust. It's my second Rust project (my first one is here: Randomly selecting an adjective and noun, combining them into a message).
I would like advice on whether I am doing stuff idiomatically, and on any sensible optimisations I could make. I also appreciate advice on code style: I care a lot about having code that is readable and formatted nicely, but I'm still getting the hang of it in Rust.
A brief explanation of the K-means algorithm:
- You have a set of data you wish to partition into a known number of groups, a.k.a. clusters.
- The mean of all the data belonging to a cluster is the cluster's centroid. You decide which cluster a datum belongs to by selecting the closest cluster, i.e. the one whose centroid is at the smallest Euclidean distance from the data point.
- The algorithm works in two iterative steps. First, we initialise the cluster centroids, then:
  - Assign all the data to their nearest centroid.
  - Update the cluster centroids so that they are the mean of all the points assigned to them.
  - Repeat until a local minimum is found, i.e. until the assignments stop changing.

Here are the contents of `lib.rs`, now including my unit tests (I realise I could probably have more tests D:):

```
use std::path::Path;
extern crate csv;
extern crate rustc_serialize;

/// Store one data point's (or one cluster centroid's) x and y co-ordinates
#[derive(Clone, Debug, RustcDecodable)]
pub struct DataPoint {
    pub x: f64,
    pub y: f64,
}

/// Structure for holding data point's assignments to clusters
#[derive(Clone, Debug)]
pub struct Assignment<'a> {
    data_point: &'a DataPoint,
    cluster_ind: usize,
}

pub fn read_data<P>(file_path: P) -> Vec<DataPoint> where P: AsRef<Path> {
    let mut data = vec![];
    let mut reader = csv::Reader::from_file(file_path).unwrap();
    for data_point in reader.decode() {
        let data_point: DataPoint = data_point.unwrap();
        data.push(data_point);
    }
    data
}

pub fn squared_euclidean_distance(point_a: &DataPoint,
```
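For readers following along, the assign/update loop described in the question can be sketched as one self-contained function. This is a minimal illustration under stated assumptions, not the poster's code: `k_means`, `nearest` and `max_iters` are hypothetical names, the caller supplies the initial centroids, and "converged" is taken to mean that no point changed cluster. It drops the `csv`/`rustc_serialize` dependencies so it compiles on its own, and compares distances with `f64::total_cmp` (Rust 1.62+).

```rust
/// Minimal 2-D point, modelled on the question's `DataPoint`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct DataPoint {
    pub x: f64,
    pub y: f64,
}

fn squared_distance(a: DataPoint, b: DataPoint) -> f64 {
    (a.x - b.x).powi(2) + (a.y - b.y).powi(2)
}

/// Index of the centroid closest to `p`; `None` if there are no centroids.
fn nearest(p: DataPoint, centroids: &[DataPoint]) -> Option<usize> {
    centroids
        .iter()
        .map(|&c| squared_distance(p, c))
        .enumerate()
        .min_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(i, _)| i)
}

/// Hypothetical driver: runs K-means from the given initial centroids until
/// no assignment changes or `max_iters` passes have run. Returns the final
/// centroids and the cluster index of every point.
pub fn k_means(
    data: &[DataPoint],
    mut centroids: Vec<DataPoint>,
    max_iters: usize,
) -> (Vec<DataPoint>, Vec<usize>) {
    let mut assignments: Vec<usize> = Vec::new();
    for _ in 0..max_iters {
        // Assignment step: every point goes to its nearest centroid.
        let new_assignments: Vec<usize> = data
            .iter()
            .map(|&p| nearest(p, &centroids).expect("need at least one centroid"))
            .collect();
        if new_assignments == assignments {
            break; // nothing moved, so the loop has reached a local minimum
        }
        assignments = new_assignments;

        // Update step: each centroid becomes the mean of its points.
        for (k, centroid) in centroids.iter_mut().enumerate() {
            let (sum_x, sum_y, n) = data
                .iter()
                .zip(&assignments)
                .filter(|&(_, &a)| a == k)
                .fold((0.0_f64, 0.0_f64, 0_usize), |(sx, sy, n), (p, _)| {
                    (sx + p.x, sy + p.y, n + 1)
                });
            if n > 0 {
                // Assumption: an empty cluster keeps its previous centroid.
                *centroid = DataPoint { x: sum_x / n as f64, y: sum_y / n as f64 };
            }
        }
    }
    (centroids, assignments)
}
```

With the points `(0, 0)`, `(0, 1)`, `(10, 10)`, `(10, 11)` and initial centroids `(0, 0)` and `(10, 10)`, the second pass finds no changed assignments, so the function returns centroids `(0, 0.5)` and `(10, 10.5)`. Leaving an empty cluster's centroid where it was is only one possible policy; re-seeding it is another.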
Solution
- `DataPoint` is a small enough structure that it might as well be `Copy`.
- `squared_euclidean_distance` can be an instance method on `DataPoint`.
- `where` clauses should be placed on a separate line. This makes it easier to see what constraints a function has.
- It's rarely necessary to declare a container (like a `Vec`) and then fill it up manually. Use iterator adapters instead; `map` and `collect` are the big two. Think about iterator adapters any time you use a `for` loop that's not just for side effects.
- Prefer `expect` over `unwrap`. It lets you say what you expected, and for `Result`s the underlying error message is printed as well. In either case, it makes it easier for you or the user to track down the error.
- Reading from a file should have better error handling than panicking. Also, does tying the library to CSV even make sense? Maybe this really belongs in the executable?
- Never take `&Vec<T>` or `&String` as an argument. `&[T]` or `&str` is broader and gives you everything you need.
- There are no spaces inside the closure argument list delimiters (`|`).
- Use braces for multi-line closures.
- It doesn't make sense to return 0 for the minimum value of an empty slice (or zero-length iterator). This is what `Option` is made for.
- A few places create a `Vec` from an iterator just to iterate over it again. It's more efficient to never collect.
- `get` is pretty useless in method names. Dropping it loses nothing.
- When the first thing a function does is convert its argument to an iterator, you might as well accept anything that can be made into an iterator (`IntoIterator`).
- Include spaces around the `{` and before the `}` in struct literals (type constructors), e.g. `DataPoint { x: 0.0, y: 0.0 }`.
- There's no reason to include parentheses around the `Vec` element type.
- When you have a multi-line function argument list, place the `{` on a new line.
- Use iterator adapters instead of cloning a slice / `Vec` and then `retain`ing values.
- There's no reason to return a vector when the caller just converts it into an iterator. I used `Box` for laziness.
- The code computes `points_in_cluster` twice. It's probably more efficient to combine the calculations.
- There's no reason to make a mutable tuple and then convert it into a struct; just mutate the struct directly.
- Instead of a `for` loop, use `fold`.
- Implement addition for `DataPoint` to make the summing code simpler to understand.
- Does it make sense to return `(0, 0)` for the sum of an empty slice?
- Use `iter_mut` and `enumerate` instead of poking at the slice's values via indexing (iterators can avoid the bounds checks that indexing incurs).
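Several of the points above (use `map`/`collect` instead of pushing in a loop, report errors rather than panic) can be shown without the `csv` crate. The sketch below assumes a hypothetical plain-text format of one `x,y` pair per line and a hypothetical `ParseError` type; it relies on the standard library's `FromIterator` impl that collects an iterator of `Result`s into a `Result<Vec<_>, _>`, stopping at the first error.

```rust
use std::fmt;
use std::num::ParseFloatError;

/// Same fields as the question's `DataPoint`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct DataPoint {
    pub x: f64,
    pub y: f64,
}

/// Hypothetical error type recording which (1-based) line failed and why.
#[derive(Debug, PartialEq)]
pub enum ParseError {
    MissingComma { line: usize },
    BadNumber { line: usize, source: ParseFloatError },
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ParseError::MissingComma { line } => write!(f, "line {}: expected `x,y`", line),
            ParseError::BadNumber { line, source } => write!(f, "line {}: {}", line, source),
        }
    }
}

impl std::error::Error for ParseError {}

/// Parses one `x,y` pair per line, skipping blank lines. `map` + `collect`
/// replace "create an empty Vec, then push in a loop", and the first bad
/// line comes back as an `Err` instead of a panic.
pub fn parse_points(input: &str) -> Result<Vec<DataPoint>, ParseError> {
    input
        .lines()
        .enumerate()
        .filter(|(_, text)| !text.trim().is_empty())
        .map(|(i, text)| -> Result<DataPoint, ParseError> {
            let line = i + 1;
            let (x, y) = text.split_once(',').ok_or(ParseError::MissingComma { line })?;
            let number = |s: &str| {
                s.trim()
                    .parse::<f64>()
                    .map_err(|source| ParseError::BadNumber { line, source })
            };
            Ok(DataPoint { x: number(x)?, y: number(y)? })
        })
        .collect()
}
```

The caller (for example the executable, if reading is moved out of the library) can then decide whether to print the error, retry, or exit.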
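The advice about taking `&[T]` rather than `&Vec<T>`, and about accepting `IntoIterator` when a function immediately iterates, can be illustrated with two small helpers. `sum_of_squares` and `count_below` are hypothetical names chosen for the sketch, not functions from the question.

```rust
/// Accepting `&[f64]` rather than `&Vec<f64>` lets callers pass a `Vec`,
/// an array or a sub-slice at no extra cost.
pub fn sum_of_squares(values: &[f64]) -> f64 {
    values.iter().map(|v| v * v).sum()
}

/// When the first thing a function does is call `.into_iter()`, taking any
/// `IntoIterator` is broader still: `Vec`s, arrays and lazy iterator chains
/// all work, and nothing has to be collected first.
pub fn count_below<I>(values: I, limit: f64) -> usize
where
    I: IntoIterator<Item = f64>,
{
    values.into_iter().filter(|&v| v < limit).count()
}
```

The same reasoning applies to `&str` versus `&String`: a `&String` argument coerces to `&str`, but not the other way round.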
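On the `Box` point: since Rust 1.26 (released after this answer was written), a function can return `impl Iterator` instead of a boxed trait object. A hedged sketch of `points_in_cluster` in that style follows; it assumes the same `Assignment` shape as the answer and, as a design change, yields `&DataPoint` directly rather than cloned `Assignment`s.

```rust
/// Same fields as the question's `DataPoint`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct DataPoint {
    pub x: f64,
    pub y: f64,
}

/// Same shape as the answer's `Assignment`.
#[derive(Clone, Debug)]
pub struct Assignment<'a> {
    data_point: &'a DataPoint,
    cluster_ind: usize,
}

/// Lazily yields the points assigned to `cluster`. Nothing is cloned,
/// collected or boxed; `impl Iterator` stands in for the concrete iterator
/// type, which contains closures and so cannot be written out by name.
pub fn points_in_cluster<'a>(
    assignments: &'a [Assignment<'a>],
    cluster: usize,
) -> impl Iterator<Item = &'a DataPoint> + 'a {
    assignments
        .iter()
        .filter(move |a| a.cluster_ind == cluster)
        .map(|a| a.data_point)
}
```

Because the result is still an iterator, `count_assignments` could simply call `.count()` on it, as the answer's version does.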
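The last few suggestions (combine the two passes over `points_in_cluster`, use `fold` and `Add`, don't invent `(0, 0)` for an empty cluster, update via `iter_mut().enumerate()`) fit together in one update step. This sketch is an assumption-laden variant, not the answer's code: it uses a hypothetical `cluster_of` slice holding each point's cluster index instead of the `Assignment` type, hypothetical helper names `mean` and `update_centroids`, and keeps the old centroid for an empty cluster.

```rust
use std::ops::Add;

#[derive(Copy, Clone, Debug, PartialEq)]
pub struct DataPoint {
    pub x: f64,
    pub y: f64,
}

impl Add for DataPoint {
    type Output = DataPoint;
    fn add(self, other: DataPoint) -> DataPoint {
        DataPoint { x: self.x + other.x, y: self.y + other.y }
    }
}

/// Mean of the points, or `None` for an empty input (rather than a made-up
/// `(0, 0)`). Sum and count are accumulated in a single `fold`, so the
/// points are only walked once.
pub fn mean<'a, I>(points: I) -> Option<DataPoint>
where
    I: IntoIterator<Item = &'a DataPoint>,
{
    let zero = DataPoint { x: 0.0, y: 0.0 };
    let (sum, n) = points
        .into_iter()
        .fold((zero, 0_usize), |(sum, n), &p| (sum + p, n + 1));
    if n == 0 {
        None
    } else {
        Some(DataPoint { x: sum.x / n as f64, y: sum.y / n as f64 })
    }
}

/// Update step: `cluster_of[i]` is the cluster index of `data[i]`.
/// Centroids are updated through `iter_mut().enumerate()` rather than by
/// indexing; an empty cluster keeps its old centroid.
pub fn update_centroids(centroids: &mut [DataPoint], data: &[DataPoint], cluster_of: &[usize]) {
    for (k, centroid) in centroids.iter_mut().enumerate() {
        let members = data
            .iter()
            .zip(cluster_of)
            .filter(|&(_, &c)| c == k)
            .map(|(p, _)| p);
        if let Some(m) = mean(members) {
            *centroid = m;
        }
    }
}
```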
```
use std::path::Path;
extern crate csv;
extern crate rustc_serialize;

/// Store one data point's (or one cluster centroid's) x and y co-ordinates
#[derive(Copy, Clone, Debug, PartialEq, RustcDecodable)]
pub struct DataPoint {
    pub x: f64,
    pub y: f64,
}

impl DataPoint {
    fn zero() -> DataPoint {
        DataPoint {
            x: 0.0,
            y: 0.0,
        }
    }

    pub fn squared_euclidean_distance(&self, other: &DataPoint) -> f64 {
        (other.x - self.x).powi(2) + (other.y - self.y).powi(2)
    }
}

impl std::ops::Add for DataPoint {
    type Output = DataPoint;

    fn add(self, other: DataPoint) -> DataPoint {
        DataPoint {
            x: self.x + other.x,
            y: self.y + other.y,
        }
    }
}

/// Structure for holding data point's assignments to clusters
#[derive(Clone, Debug)]
pub struct Assignment<'a> {
    data_point: &'a DataPoint,
    cluster_ind: usize,
}

pub fn read_data<P>(file_path: P) -> Vec<DataPoint>
where P: AsRef<Path>
{
    let mut reader = csv::Reader::from_file(file_path).unwrap();
    reader.decode().map(|point| point.unwrap()).collect()
}

pub fn index_of_min_val<I>(floats: I) -> Option<usize>
where I: IntoIterator<Item = f64>,
{
    let mut iter = floats.into_iter().enumerate();
    iter.next().map(|(i, min)| {
        iter.fold((i, min), |(min_i, min_val), (i, val)| {
            if val < min_val {
                (i, val)
            } else {
                (min_i, min_val)
            }
        }).0
    })
}

/// Assign points to clusters
fn expectation<'a>(data: &'a [DataPoint],
                   cluster_centroids: &[DataPoint]) -> Vec<Assignment<'a>>
{
    data.iter().map(|point| {
        let distances = cluster_centroids.iter().map(|cluster| point.squared_euclidean_distance(cluster));
        let index = index_of_min_val(distances).expect("No minimum value found");
        Assignment { data_point: point, cluster_ind: index }
    }).collect()
}

pub fn count_assignments(assignments: &[Assignment],
                         cluster_ind: usize) -> usize
{
    points_in_cluster(assignments, cluster_ind).count()
}

pub fn points_in_cluster<'a>(assignments: &'a [Assignment],
                             expected_cluster_ind: usize) -> Box<Iterator<Item = Assignment<'a>> + 'a>
{
    let i = assignments.into_iter()
        .cloned()
        .filter(move |&Assignment { cluster_ind, .. }| expected_cluster_ind == cluster_ind);
    Box::new(i)
}

pub fn sum_assigned_values(assignments: &[Assignment],
                           cluster_ind: usize) -> DataPoint
{
    points_in_cluster(assignments, cluster_ind)
        .into_iter()
        .fold(DataPoint::zero(), |acc, point| acc + *point.data_point)
}

/// Update cluster centres
fn maximisation(cluster_centroids: &mut [DataPoint],
                assignments: &[Assignment])
{
    for (i, centroid) in cluster_centroids.iter_mut().enumerate() {
        let num_points = count_assignments(&assignments, i);
        let sum_points = // …

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_squared_euclidean_distance_simple_case() {
        let origin = DataPoint { x: 0.0, y: 0.0 };
        let point = DataPoint { x: 1.0, y: 1.0 };
        assert_eq!(2.0, origin.squared_euclidean_distance(&point))
    }

    #[test]
    fn test_squared_euclidean_distance_gives_0_for_same_point() {
        let point_a = DataPoint { x: -999.3, y: 10.5 };
        assert_eq!(0.0, point_a.squared_euclidean_distance(&point_a));
    }

    #[test]
    fn test_index_of_min_val() {
        let floats = vec![0.0_f64, 1.0_f64, 3.0_f64, -5.5_f64];
        assert_eq!(Some(3), index_of_min_val(floats))
    }

    #[test]
    fn test_count_assignments_returns_0_when_no_occurences() {
        let dp = DataPoint { x: 0.0, y: 0.0 };
        let assignments = [Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 1 },
                           Assignment { data_point: &dp, cluster_ind: 5 },
                           Assignment { data_point: &dp, cluster_ind: 0 }];
        assert_eq!(0, count_assignments(&assignments, 4))
    }

    #[test]
    fn test_count_assignments_returns_3_when_3_occurences() {
        let dp = DataPoint { x: 0.0, y: 0.0 };
        let assignments = [Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 1 },
                           Assignment { data_point: &dp, cluster_ind: 5 },
                           Assignment { data_point: &dp, cluster_ind: 0 }];
        assert_eq!(3, count_assignments(&assignments, 0));
    }

    #[test]
    fn test_sum_assigned_values_returns_0_when_none_assigned() {
        let dp = DataPoint { x: 5.0, y: 5.0 };
        let assignments = [Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 1 },
                           Assignment { data_point: &dp, cluster_ind: 5 },
                           Assignment { data_point: &dp, cluster_ind: 0 }];
        assert_eq!(DataPoint { x: 0.0, y: 0.0 }, sum_assigned_values(&assignments, 2))
    }

    #[test]
    fn test_sum_assigned_values_returns_correctly_when_some_assigned() {
        let dp = DataPoint { x: 1.0, y: 1.0 };
        let assignments = [Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 0 },
                           Assignment { data_point: &dp, cluster_ind: 1 },
                           Assignment { data_point: &dp, cluster_ind: 5 },
                           Assignment { data_point: &dp, cluster_ind: 0 }];
        assert_eq!(DataPoint { x: 3.0, y: 3.0 }, sum_assigned_values(&assignments, 0));
    }
}
```

Context
StackExchange Code Review Q#126303, answer score: 4