K-Means in Rust
Problem
I have implemented, for learning purposes, a simple K-Means clustering algorithm in Rust. For those who are not familiar: you are given N points, say in the plane, and you want to group them into n clusters of nearby points.
To do so, you start with n random points, for instance the first n of the given ones. Call these centroids. At each iteration:
- you group the N points based on the nearest centroid
- you produce a new set of centroids as the average of the groups of the preceding step
You can stop after a fixed number of iterations, or after some convergence.
Here is my implementation, with the help of SO. For some reason, the code runs slower than an equivalent algorithm written in Scala. I think I might be introducing some unnecessary copying or other hidden overhead, but I am not familiar enough with Rust to tell.
Just to be clear: I am not interested in changing the algorithm (I want to compare apples to apples), and I would rather have idiomatic Rust than hyper-optimized code.
```
use std::collections::TreeMap;
use point::Point;

fn dist(v: Point, w: Point) -> f64 { (v - w).norm() }

fn avg(points: & Vec<Point>) -> Point {
    let Point(x, y) = points.iter().fold(Point(0.0, 0.0), |p, &q| p + q);
    let k = points.len() as f64;
    Point(x / k, y / k)
}

fn closest(x: Point, ys: & Vec<Point>) -> Point {
    let y0 = ys[0];
    let d0 = dist(y0, x);
    let (_, y) = ys.iter().fold((d0, y0),
        |(m, p), &q| {
            let d = dist(q, x);
            if d < m { (d, q) } else { (m, p) }
        }
    );
    y
}

fn clusters(xs: & Vec<Point>, centroids: & Vec<Point>) -> Vec<Vec<Point>> {
    let mut groups: TreeMap<Point, Vec<Point>> = TreeMap::new();
    for x in xs.iter() {
        let y = closest(*x, centroids);
        let should_insert = match groups.find_mut(&y) {
            Some(val) => {
                val.push(*x);
                false
            },
            None => true
        };
        if should_insert {
            groups.insert(y, vec![*x]);
        }
    }
    groups.into_iter().map(|(_, v)| v).collect::<Vec<Vec<Point>>>()
}

pub fn run(points: & Vec<Point>, n: uint, iters: uint) -> Vec<Vec<Point>> {
    let mut centroids: Vec<Point> = Vec::from_fn(n, |i| points[i]);
    for _ in range(0, iters) {
        centroids = clusters(points, & centroids).iter().map(|g| avg(g)).collect();
    }
    clusters(points, & centroids)
}
```
Solution
I took your code and got it to compile with my version of Rust (`rustc 0.13.0-dev (29ad8539b 2014-12-24 16:21:23 +0000)`).
I ran with your parameters (I hope I understood them correctly) and got an average time of 207.8 ms.
I made a few changes here and there, but the main thing was changing `TreeMap` to `HashMap`. `TreeMap` doesn't exist anymore, only `BTreeMap`. At the same time, I switched to using `entry`, which avoids doing multiple lookups in the hash to add a value if it is missing.
The main problem with using a `HashMap` is that `f64` isn't hashable. This is for good reason: floating point numbers have lots of edge cases (an `f64` has roughly 2^53 distinct NaN bit patterns, and NaN never compares equal to itself). I don't know how Scala or Python deal with these values, so I did the simplest thing: ignored them. We just reinterpret the bits of each 64-bit float as a 64-bit unsigned integer and hash that.
In your case, I think it's safe to assume that all `Point`s will have non-infinite and non-NaN values (and maybe we can ignore every other detail of floating point?). However, it might be worth adding a constructor that validates these assumptions and dies if they aren't true.
I didn't make any other changes, as 20% of your previous time seems like a good start.
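For readers on a current Rust toolchain, the two ideas above (hashing a pair of floats by their bit patterns, and inserting through `entry` so each point costs one map lookup) might be sketched as follows. This is not the code that was benchmarked: the constructor `Point::new` and the function `group_by_centroid` are hypothetical names introduced for illustration, and the `-0.0` normalization is an extra assumption added so that coordinates which compare equal also hash equally.

```rust
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

/// A 2-D point; `Point::new` guarantees both coordinates are finite.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Point(f64, f64);

impl Point {
    /// Hypothetical validating constructor: panics on NaN or infinity.
    pub fn new(x: f64, y: f64) -> Point {
        assert!(x.is_finite() && y.is_finite(), "Point coordinates must be finite");
        // Adding 0.0 maps -0.0 to +0.0, so values that compare equal
        // also share a single bit pattern (and therefore a single hash).
        Point(x + 0.0, y + 0.0)
    }
}

// Reasonable only because `new` rules out NaN, the one value with x != x.
impl Eq for Point {}

impl Hash for Point {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // `f64::to_bits` is the safe replacement for transmuting to u64.
        self.0.to_bits().hash(state);
        self.1.to_bits().hash(state);
    }
}

fn dist(v: Point, w: Point) -> f64 {
    (v.0 - w.0).hypot(v.1 - w.1)
}

/// Returns the centroid nearest to `p`. Panics if `centroids` is empty.
fn closest(p: Point, centroids: &[Point]) -> Point {
    let mut best = centroids[0];
    let mut best_d = dist(best, p);
    for &c in &centroids[1..] {
        let d = dist(c, p);
        if d < best_d {
            best = c;
            best_d = d;
        }
    }
    best
}

/// Groups every point under its nearest centroid.
pub fn group_by_centroid(points: &[Point], centroids: &[Point]) -> HashMap<Point, Vec<Point>> {
    let mut groups: HashMap<Point, Vec<Point>> = HashMap::new();
    for &p in points {
        let c = closest(p, centroids);
        // One lookup: `entry` finds the slot, `or_default` fills it if vacant.
        groups.entry(c).or_default().push(p);
    }
    groups
}
```

Calling `group_by_centroid(&points, &centroids)` returns one entry per centroid that attracted at least one point, which matches the behaviour of `clusters` in the listing before its values are collected into a `Vec`.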
```
extern crate time;

use std::collections::HashMap;
use std::collections::hash_map::Entry::{Occupied,Vacant};
use std::hash::Hash;
use std::hash::sip::SipState;
use std::mem;
use std::num::Float;
use std::rand::{Rng,StdRng,SeedableRng};

#[deriving(Show, PartialEq, Copy, Clone)]
pub struct Point(pub f64, pub f64);

fn sq(x: f64) -> f64 { x * x }

// This needs to have a guarantee that x and y will never be an
// Infinity or NaN
impl Point {
    pub fn norm(self: &Point) -> f64 {
        let Point(x, y) = *self;
        (sq(x) + sq(y)).sqrt()
    }
}

impl Hash for Point {
    fn hash(&self, state: &mut SipState) {
        // Perform a bit-wise transform, relying on the fact that we
        // are never Infinity or NaN
        let Point(x, y) = *self;
        let x: u64 = unsafe { mem::transmute(x) };
        let y: u64 = unsafe { mem::transmute(y) };
        x.hash(state);
        y.hash(state);
    }
}

impl Add<Point, Point> for Point {
    fn add(self, other: Point) -> Point {
        let Point(a, b) = self;
        let Point(c, d) = other;
        Point(a + c, b + d)
    }
}

impl Sub<Point, Point> for Point {
    fn sub(self, other: Point) -> Point {
        let Point(a, b) = self;
        let Point(c, d) = other;
        Point(a - c, b - d)
    }
}

impl Eq for Point {}

fn dist(v: Point, w: Point) -> f64 { (v - w).norm() }

fn avg(points: & Vec<Point>) -> Point {
    let Point(x, y) = points.iter().fold(Point(0.0, 0.0), |p, &q| p + q);
    let k = points.len() as f64;
    Point(x / k, y / k)
}

fn closest(x: Point, ys: & Vec<Point>) -> Point {
    let y0 = ys[0];
    let d0 = dist(y0, x);
    let (_, y) = ys.iter().fold((d0, y0),
        |(m, p), &q| {
            let d = dist(q, x);
            if d < m { (d, q) } else { (m, p) }
        }
    );
    y
}

fn clusters(xs: & Vec<Point>, centroids: & Vec<Point>) -> Vec<Vec<Point>> {
    let mut groups: HashMap<Point, Vec<Point>> = HashMap::new();
    for x in xs.iter() {
        let y = closest(*x, centroids);
        // Notable change: avoid double hash lookups
        match groups.entry(y) {
            Occupied(entry) => entry.into_mut().push(*x),
            Vacant(entry) => { entry.set(vec![*x]); () },
        }
    }
    groups.into_iter().map(|(_, v)| v).collect::<Vec<Vec<Point>>>()
}

pub fn run(points: & Vec<Point>, n: uint, iters: uint) -> Vec<Vec<Point>> {
    let mut centroids: Vec<Point> = Vec::from_fn(n, |i| points[i]);
    for _ in range(0, iters) {
        centroids = clusters(points, & centroids).iter().map(|g| avg(g)).collect();
    }
    clusters(points, & centroids)
}

fn main() {
    let seed: &[_] = &[1, 2, 3, 4];
    let mut rng: StdRng = SeedableRng::from_seed(seed);
    let points = Vec::from_fn(100000, |_| Point(rng.gen(), rng.gen()));
    println!("Made {} points: {}", points.len(), points.slice_to(3));
    let repeat_count = 20u;
    let mut total = 0;
    for _ in range(0, repeat_count) {
        let start = time::precise_time_ns();
        let res = run(&points, 10, 15);
        let end = time::precise_time_ns();
        total += end - start
    }
    let avg_ns: f64 = total as f64 / repeat_count as f64;
    let avg_ms = avg_ns / 1.0e6;
    println!("{} runs, avg {}", repeat_count, avg_ms);
}
```
Comparison numbers
Here are the numbers I got from running various versions from your suite. Hopefully this gives some base comparison across machines.
```
$ cargo run --release
The average time is 209.44
$ cargo run
The average time is 2103.81
$ python kmeans.py
Made 100 iterations with an average of 13504.76 milliseconds
$ pypy kmeans.py
Made 100 iterations with an average of 608.97 milliseconds
$ lein trampoline run
The average time was 2799.54 ms
$ node kmeans.js
Running 100 iterations as required 4741.34 ms
```
Context
StackExchange Code Review Q#67577, answer score: 11