K-Means in Rust

Submitted by: @import:stackexchange-codereview

Problem

For learning purposes, I have implemented a simple K-Means clustering algorithm in Rust. For those who are not familiar with it: you are given N points, say in the plane, and you want to group them into n clusters of nearby points.

To do so, you start with n random points, for instance the first n of the given ones. Call these centroids. At each iteration:

  • you group the N points based on the nearest centroid
  • you produce a new set of centroids as the average of the groups of the preceding step

You can stop after a fixed number of iterations, or once the centroids stop moving (convergence).
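As an illustration of the two steps above, here is a minimal sketch in current stable Rust. It is not the code under review: it assumes a fixed iteration count, groups points by centroid index in a `Vec` rather than in a map keyed by centroid, and compares squared distances. The names `dist_sq`, `closest`, `clusters`, `avg` and `run` are chosen for the sketch only.

```rust
// A 2-D point; `Copy` keeps the helpers below free of explicit clones.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Point(pub f64, pub f64);

// Squared Euclidean distance; the square root is not needed to compare distances.
fn dist_sq(a: Point, b: Point) -> f64 {
    let (dx, dy) = (a.0 - b.0, a.1 - b.1);
    dx * dx + dy * dy
}

// Index of the centroid nearest to `p` (ties go to the lowest index).
// Assumption: `centroids` is non-empty, otherwise this panics.
fn closest(p: Point, centroids: &[Point]) -> usize {
    let mut best = 0;
    let mut best_d = dist_sq(p, centroids[0]);
    for (i, &c) in centroids.iter().enumerate().skip(1) {
        let d = dist_sq(p, c);
        if d < best_d {
            best = i;
            best_d = d;
        }
    }
    best
}

// Step 1: group the points by nearest centroid, dropping empty groups.
fn clusters(points: &[Point], centroids: &[Point]) -> Vec<Vec<Point>> {
    let mut groups: Vec<Vec<Point>> = vec![Vec::new(); centroids.len()];
    for &p in points {
        groups[closest(p, centroids)].push(p);
    }
    groups.retain(|g| !g.is_empty());
    groups
}

// Step 2: the new centroid of a non-empty group is the average of its points.
fn avg(group: &[Point]) -> Point {
    let (sx, sy) = group.iter().fold((0.0, 0.0), |(x, y), p| (x + p.0, y + p.1));
    let k = group.len() as f64;
    Point(sx / k, sy / k)
}

// Runs `iters` iterations, starting from the first `n` points as centroids.
// Assumption: 1 <= n <= points.len(), otherwise this panics.
pub fn run(points: &[Point], n: usize, iters: usize) -> Vec<Vec<Point>> {
    let mut centroids: Vec<Point> = points[..n].to_vec();
    for _ in 0..iters {
        centroids = clusters(points, &centroids).iter().map(|g| avg(g)).collect();
    }
    clusters(points, &centroids)
}
```

Calling `run(&points, 2, 5)` on two well-separated blobs returns one `Vec<Point>` per blob, in the order of the initial centroids.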

Here is my implementation, written with some help from Stack Overflow. For some reason, the code runs slower than an equivalent algorithm written in Scala. I suspect I am introducing some unnecessary copying or other hidden overhead, but I am not familiar enough with Rust to tell.

Just to be clear: I am not interested in changing the algorithm (I want to compare apples to apples), and I would rather have idiomatic Rust than hyper-optimized code.

```
use std::collections::TreeMap;
use point::Point;

fn dist(v: Point, w: Point) -> f64 { (v - w).norm() }

fn avg(points: & Vec<Point>) -> Point {
    let Point(x, y) = points.iter().fold(Point(0.0, 0.0), |p, &q| p + q);
    let k = points.len() as f64;

    Point(x / k, y / k)
}

fn closest(x: Point, ys: & Vec<Point>) -> Point {
    let y0 = ys[0];
    let d0 = dist(y0, x);
    let (_, y) = ys.iter().fold((d0, y0),
                                |(m, p), &q| {
                                    let d = dist(q, x);
                                    if d < m { (d, q) } else { (m, p) }
                                }
                                );
    y
}

fn clusters(xs: & Vec<Point>, centroids: & Vec<Point>) -> Vec<Vec<Point>> {
    let mut groups: TreeMap<Point, Vec<Point>> = TreeMap::new();

    for x in xs.iter() {
        let y = closest(*x, centroids);
        let should_insert = match groups.find_mut(&y) {
            Some(val) => {
                val.push(*x);
                false
            },
            None => true
        };
        if should_insert {
            groups.insert(y, vec![*x]);
        }
    }

    groups.into_iter().map(|(_, v)| v).collect::<Vec<Vec<Point>>>()
}

pub fn run(points: & Vec<Point>, n: uint, iters: uint) -> Vec<Vec<Point>> {
    let mut centroids: Vec<Point> = Vec::from_fn(n, |i| points[i]);

    for // ... (the rest of the listing is cut off here)
```

Solution

I took your code and got it to compile with my version of Rust (rustc 0.13.0-dev (29ad8539b 2014-12-24 16:21:23 +0000)).

I ran with your parameters (I hope I understood them correctly) and got an average time of 207.8 ms.

I made a few changes here and there, but the main one was changing TreeMap to HashMap. TreeMap doesn't exist anymore, only BTreeMap. At the same time, I switched to the entry API, which avoids doing multiple lookups in the hash map to add a value if it is missing.
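To make the difference concrete, here is a small sketch in current stable Rust (not the 2014 API used in this answer) that groups values by key both ways. The integer keys and the function names `group_two_lookups` and `group_with_entry` are made up for the illustration.

```rust
use std::collections::HashMap;

// Look the key up, then insert on a miss: a missing key costs two lookups.
pub fn group_two_lookups(items: &[(u32, i32)]) -> HashMap<u32, Vec<i32>> {
    let mut groups: HashMap<u32, Vec<i32>> = HashMap::new();
    for &(key, value) in items {
        match groups.get_mut(&key) {
            Some(bucket) => bucket.push(value),
            None => {
                groups.insert(key, vec![value]);
            }
        }
    }
    groups
}

// The entry API finds the slot once, then either reuses the existing
// bucket or creates an empty one in place before pushing.
pub fn group_with_entry(items: &[(u32, i32)]) -> HashMap<u32, Vec<i32>> {
    let mut groups: HashMap<u32, Vec<i32>> = HashMap::new();
    for &(key, value) in items {
        groups.entry(key).or_default().push(value);
    }
    groups
}
```

Both functions build the same map; the second is shorter and does not hash the key a second time when it is missing.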

The main problem with using a HashMap is that f64 isn't hashable. This is for good reason: floating-point numbers have lots of edge cases (an f32 alone has about 16 million distinct NaN bit patterns, and an f64 has vastly more!). I don't know how Scala or Python deal with these values, so I did the simplest thing: ignored them. We just reinterpret the bits of each 64-bit float as a 64-bit unsigned integer and away we go!
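On current stable Rust the same bit-level trick can be written without `unsafe`, using `f64::to_bits`. The sketch below is an assumption-laden variant rather than the code in this answer: it also defines equality on the bit patterns so that `Eq` and `Hash` agree, which means `0.0` and `-0.0` (equal as floats, different as bits) become distinct keys. The helper `count_by_point` exists only to show the type used as a map key.

```rust
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

#[derive(Debug, Clone, Copy)]
pub struct Point(pub f64, pub f64);

impl Point {
    // The raw IEEE 754 bit patterns of both coordinates.
    fn bits(self) -> (u64, u64) {
        (self.0.to_bits(), self.1.to_bits())
    }
}

// Equality on bit patterns, so that equal points always hash equally.
// The price: 0.0 and -0.0 are different keys, and a NaN only equals
// a NaN with exactly the same bits.
impl PartialEq for Point {
    fn eq(&self, other: &Point) -> bool {
        self.bits() == other.bits()
    }
}

impl Eq for Point {}

impl Hash for Point {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.bits().hash(state);
    }
}

// Hypothetical helper: counts how often each point occurs.
pub fn count_by_point(points: &[Point]) -> HashMap<Point, usize> {
    let mut counts = HashMap::new();
    for &p in points {
        *counts.entry(p).or_insert(0) += 1;
    }
    counts
}
```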

In your case, I think it's safe to assume that all Points will have non-infinite and non-NaN values (and maybe we can ignore every other detail of floating point?). However, it might be worth adding a constructor that validates these assumptions and dies if they aren't true.
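Such a validating constructor might look roughly like the following in current stable Rust. This is only a sketch: the names `new` and `try_new` and the choice to make the fields private are assumptions, not part of the code in this answer.

```rust
// Fields are private so that code outside this module has to go
// through the validating constructors below.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Point(f64, f64);

impl Point {
    /// Returns `None` if either coordinate is NaN or infinite.
    pub fn try_new(x: f64, y: f64) -> Option<Point> {
        if x.is_finite() && y.is_finite() {
            Some(Point(x, y))
        } else {
            None
        }
    }

    /// Like `try_new`, but panics on a NaN or infinite coordinate.
    pub fn new(x: f64, y: f64) -> Point {
        Point::try_new(x, y).expect("Point coordinates must be finite")
    }

    pub fn x(&self) -> f64 { self.0 }
    pub fn y(&self) -> f64 { self.1 }
}
```

With this in place, code that hashes or compares points can rely on the coordinates being ordinary finite numbers, as long as arithmetic on points also produces its results through `new` or `try_new`.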

I didn't make any other changes, as 20% of your previous time seems like a good start.

```
extern crate time;

use std::collections::HashMap;
use std::collections::hash_map::Entry::{Occupied,Vacant};
use std::hash::Hash;
use std::hash::sip::SipState;
use std::mem;
use std::num::Float;
use std::rand::{Rng,StdRng,SeedableRng};

#[deriving(Show, PartialEq, Copy, Clone)]
pub struct Point(pub f64, pub f64);

fn sq(x: f64) -> f64 { x * x }

// This needs to have a guarantee that x and y will never be an
// Infinity or NaN
impl Point {
    pub fn norm(self: &Point) -> f64 {
        let Point(x, y) = *self;
        (sq(x) + sq(y)).sqrt()
    }
}

impl Hash for Point {
    fn hash(&self, state: &mut SipState) {
        // Perform a bit-wise transform, relying on the fact that we
        // are never Infinity or NaN
        let Point(x, y) = *self;
        let x: u64 = unsafe { mem::transmute(x) };
        let y: u64 = unsafe { mem::transmute(y) };
        x.hash(state);
        y.hash(state);
    }
}

impl Add<Point, Point> for Point {
    fn add(self, other: Point) -> Point {
        let Point(a, b) = self;
        let Point(c, d) = other;

        Point(a + c, b + d)
    }
}

impl Sub<Point, Point> for Point {
    fn sub(self, other: Point) -> Point {
        let Point(a, b) = self;
        let Point(c, d) = other;

        Point(a - c, b - d)
    }
}

impl Eq for Point {}

fn dist(v: Point, w: Point) -> f64 { (v - w).norm() }

fn avg(points: & Vec<Point>) -> Point {
    let Point(x, y) = points.iter().fold(Point(0.0, 0.0), |p, &q| p + q);
    let k = points.len() as f64;

    Point(x / k, y / k)
}

fn closest(x: Point, ys: & Vec<Point>) -> Point {
    let y0 = ys[0];
    let d0 = dist(y0, x);
    let (_, y) = ys.iter().fold((d0, y0),
                                |(m, p), &q| {
                                    let d = dist(q, x);
                                    if d < m { (d, q) } else { (m, p) }
                                }
                                );
    y
}

fn clusters(xs: & Vec<Point>, centroids: & Vec<Point>) -> Vec<Vec<Point>> {
    let mut groups: HashMap<Point, Vec<Point>> = HashMap::new();

    for x in xs.iter() {
        let y = closest(*x, centroids);

        // Notable change: avoid double hash lookups
        match groups.entry(y) {
            Occupied(entry) => entry.into_mut().push(*x),
            Vacant(entry) => { entry.set(vec![*x]); () },
        }
    }

    groups.into_iter().map(|(_, v)| v).collect::<Vec<Vec<Point>>>()
}

pub fn run(points: & Vec<Point>, n: uint, iters: uint) -> Vec<Vec<Point>> {
    let mut centroids: Vec<Point> = Vec::from_fn(n, |i| points[i]);

    for _ in range(0, iters) {
        centroids = clusters(points, & centroids).iter().map(|g| avg(g)).collect();
    }
    clusters(points, & centroids)
}

fn main() {
    let seed: &[_] = &[1, 2, 3, 4];
    let mut rng: StdRng = SeedableRng::from_seed(seed);
    let points = Vec::from_fn(100000, |_| Point(rng.gen(), rng.gen()));

    println!("Made {} points: {}", points.len(), points.slice_to(3));

    let repeat_count = 20u;
    let mut total = 0;
    for _ in range(0, repeat_count) {
        let start = time::precise_time_ns();
        let res = run(&points, 10, 15);
        let end = time::precise_time_ns();
        total += end - start
    }

    let avg_ns: f64 = total as f64 / repeat_count as f64;
    let avg_ms = avg_ns / 1.0e6;

    println!("{} runs, avg {}", repeat_count, avg_ms);
}
```


Comparison numbers

Here are the numbers I got from running various versions from your suite. Hopefully this gives some base comparison across machines.

```
$ cargo run --release
The average time is 209.44

$ cargo run
The average time is 2103.81

$ python kmeans.py
Made 100 iterations with an average of 13504.76 milliseconds

$ pypy kmeans.py
Made 100 iterations with an average of 608.97 milliseconds

$ lein trampoline run
The average time was 2799.54 ms

$ node kmeans.js
Running 100 iterations as required 4741.34 ms
```
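The Rust figures come from averaging wall-clock time over repeated runs, which the `main` function above does with the external `time` crate. A rough equivalent on current stable Rust, assuming only the standard library's `std::time::Instant`, could be sketched like this; `average_ms` is a name invented for the sketch.

```rust
use std::time::Instant;

// Runs `work` `repeat_count` times and returns the mean wall-clock
// time per run in milliseconds.
pub fn average_ms<F: FnMut()>(repeat_count: u32, mut work: F) -> f64 {
    assert!(repeat_count > 0, "need at least one run");
    let mut total_ms = 0.0;
    for _ in 0..repeat_count {
        let start = Instant::now();
        work();
        total_ms += start.elapsed().as_secs_f64() * 1000.0;
    }
    total_ms / f64::from(repeat_count)
}
```

As the two `cargo run` lines show, such a measurement is only meaningful for a `--release` build; the debug build here is roughly ten times slower.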

Context

StackExchange Code Review Q#67577, answer score: 11
