K-nearest neighbours in MATLAB


Problem

I implemented the K-Nearest Neighbours algorithm, but I have little experience with MATLAB. Please review this short piece of code and tell me what can be improved or changed. I hope it is a correct implementation of the algorithm.

function test_data = knn(test_data, tr_data,k)

numoftestdata = size(test_data,1);
numoftrainingdata = size(tr_data,1);

for sample=1:numoftestdata

   %Step 1: Computing euclidean distance for each testdata
   R = repmat(test_data(sample,:),numoftrainingdata,1) ;
   euclideandistance  = (R(:,1) - tr_data(:,1)).^2;

   %Step 2: compute k nearest neighbors and store them in an array
    [dist position] = sort(euclideandistance,'ascend');
    knearestneighbors=position(1:k);
    knearestdistances=dist(1:k);

    % Step 3 : Voting 
    for i=1:k
        A(i) = tr_data(knearestneighbors(i),2);  
    end

    M = mode(A);

    if (M~=1)
        test_data(sample,2) = mode(A);
    else 
        test_data(sample,2) = tr_data(knearestneighbors(1),2);
    end
end


To test it, you can use:

  • test_data = [6,0; 2,0; 5,0]
  • tr_data = [1,1; 0,2; 3,2; 4,4; 5,3]

Solution


  • Use consistent indentation.



  • You switch from extremely verbose all lower case variable names like numoftrainingdata to single letter capitalized variable names like A. Make your variable names descriptive and no longer than necessary, and be consistent.



  • Use consistent white space between operators.



  • knn() doesn't need the second column of test_data, and the calling function doesn't need the first column of test_data returned to it.

Rather than calling the function like this:

test_data = knn(test_data,tr_data,k);


Call it like this:

test_data(:,2) = knn(test_data(:,1),tr_data,k);


  • Do you want to handle error conditions such as k <= 0 or k > size(tr_data,1)?

  • Rather than squaring the difference, you can use abs(). With a single feature column, both give the same neighbour ordering.



  • Remove the 'ascend' parameter from sort(); ascending order is the default.



  • knearestdistances is unused.



  • You call mode() a second time rather than using M.
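
The question about invalid values of k in the list above could be handled with a small guard at the top of knn(). The following is only a minimal sketch and not part of the original review: the helper name is_valid_k and the error identifier knn:invalidK are assumptions made for illustration.

```matlab
% Hypothetical helper: true only for an integer scalar k with 1 <= k <= n.
% The && operators short-circuit, so a non-scalar k never reaches the comparisons.
is_valid_k = @(k, n) isscalar(k) && k == floor(k) && k >= 1 && k <= n;

tr_data = [1,1; 0,2; 3,2; 4,4; 5,3];   % training data from the question
k = 3;

if ~is_valid_k(k, size(tr_data,1))
    % 'knn:invalidK' is an assumed identifier, not one defined by MATLAB
    error('knn:invalidK', 'k must be an integer between 1 and %d.', size(tr_data,1));
end
```

Whether to raise an error or silently clamp k to the number of training rows is a design choice; raising an error makes a mistaken call visible immediately.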



Simplify:

[dist position] = sort(euclideandistance,'ascend');
knearestneighbors = position(1:k);
knearestdistances = dist(1:k);
for i=1:k
    A(i) = tr_data(knearestneighbors(i),2);
end
M = mode(A);
if (M~=1)
    test_data(sample,2) = mode(A);
else 
    test_data(sample,2) = tr_data(knearestneighbors(1),2);
end


to

[~,position] = sort(euclideandistance);
A = tr_data(position(1:k),2);
M = mode(A);
if (M~=1)
    test_data(sample,2) = M;
else
    test_data(sample,2) = tr_data(position(1),2);
end


After applying the above suggestions and vectorizing the function, you could write it as follows (test_data is now just the column of test values):

function out_data = knn(test_data,tr_data,k)
    test_data_n = size(test_data,1);
    tr_data_n = size(tr_data,1);

    % absolute distance between all test and training data
    dist = abs(repmat(test_data,1,tr_data_n) - repmat(tr_data(:,1)',test_data_n,1));

    % indices of nearest neighbors
    [~,nearest] = sort(dist,2);
    % k nearest
    nearest = nearest(:,1:k);

    % mode of k nearest
    val = reshape(tr_data(nearest,2),[],k);
    out_data = mode(val,2);
    % if mode is 1, output nearest instead
    out_data(out_data==1) = val(out_data==1,1);
end
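
The two repmat() calls in the distance computation can probably be dropped on newer releases. The sketch below assumes MATLAB R2016b or later, where subtracting a row vector from a column vector expands implicitly, and shows a bsxfun() alternative for older releases. The variable names test_x, dist_repmat, dist_implicit and dist_bsxfun are illustrative and do not appear in the original answer.

```matlab
% Data from the question; only the first column of test_data is needed.
test_x  = [6; 2; 5];
tr_data = [1,1; 0,2; 3,2; 4,4; 5,3];

% repmat version, as in the vectorized function above
dist_repmat = abs(repmat(test_x,1,size(tr_data,1)) - repmat(tr_data(:,1)',size(test_x,1),1));

% implicit expansion (assumes R2016b+): column minus row gives the full matrix
dist_implicit = abs(test_x - tr_data(:,1)');

% bsxfun alternative that also works on older releases
dist_bsxfun = abs(bsxfun(@minus, test_x, tr_data(:,1)'));
```

All three expressions should produce the same 3-by-5 matrix, with one row per test value and one column per training row.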


Edit

Regarding correctness, I'm not sure why you check whether the mode is 1. There is nothing special about a mode of 1 in general.
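
One possible reading, which the question does not confirm, is that the check was meant to catch the case where no label occurs more than once among the k neighbours, so that the vote is uninformative. If that was the intent, the second output of mode() (the frequency of the modal value) can be tested instead of the label itself. The following is a sketch under that assumption; test_x, freq and no_repeat are illustrative names.

```matlab
test_x  = [6; 2; 5];                     % first column of the question's test_data
tr_data = [1,1; 0,2; 3,2; 4,4; 5,3];     % [value, label] rows from the question
k = 3;

% distances and the k nearest training rows for every test value
dist = abs(bsxfun(@minus, test_x, tr_data(:,1)'));
[~, nearest] = sort(dist, 2);
nearest = nearest(:, 1:k);

% labels of the k nearest neighbours, one row per test value
val = reshape(tr_data(nearest, 2), [], k);

% the second output of mode() is how often the modal label occurs
[out_data, freq] = mode(val, 2);

% assumed intent: if no label repeats, fall back to the single nearest neighbour
no_repeat = (freq == 1);
out_data(no_repeat) = val(no_repeat, 1);
```

For the question's data with k = 3, this gives out_data = [3; 2; 3]. The first and third test values have three distinct neighbour labels and therefore take the label of their nearest neighbour, where a plain mode() would have returned the smallest label, 2.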

Context

StackExchange Code Review Q#46027, answer score: 5
