K-nearest neighbours in MATLAB
Problem
I implemented the K-nearest neighbours algorithm, but my experience with MATLAB is limited. Could you check this small piece of code and tell me what could be improved or changed? I hope it is a correct implementation of the algorithm.
function test_data = knn(test_data, tr_data,k)
numoftestdata = size(test_data,1);
numoftrainingdata = size(tr_data,1);
for sample=1:numoftestdata
    %Step 1: Computing euclidean distance for each testdata
    R = repmat(test_data(sample,:),numoftrainingdata,1) ;
    euclideandistance = (R(:,1) - tr_data(:,1)).^2;
    %Step 2: compute k nearest neighbors and store them in an array
    [dist position] = sort(euclideandistance,'ascend');
    knearestneighbors=position(1:k);
    knearestdistances=dist(1:k);
    % Step 3 : Voting
    for i=1:k
        A(i) = tr_data(knearestneighbors(i),2);
    end
    M = mode(A);
    if (M~=1)
        test_data(sample,2) = mode(A);
    else
        test_data(sample,2) = tr_data(knearestneighbors(1),2);
    end
end
To test it you can use:
- test_data = [6,0; 2,0; 5,0]
- tr_data = [1,1;0,2;3,2; 4,4; 5,3]
Solution
- Use consistent indentation.
- You switch from extremely verbose, all-lowercase variable names like numoftrainingdata to single-letter capitalized variable names like A. Make your variable names descriptive and no longer than necessary, and be consistent.
- Use consistent whitespace around operators.
- knn() doesn't need the second column of test_data, and the caller only needs the second (label) column back. Rather than calling the function like this:
test_data = knn(test_data,tr_data,k);
call it like this:
test_data(:,2) = knn(test_data(:,1),tr_data,k);
- Do you want to handle error conditions such as 0 >= k or k > size(tr_data,1)?
- Rather than squaring the distance, you can use abs().
- Remove the 'ascend' argument from sort(); ascending order is the default.
- knearestdistances is unused.
- You call mode() a second time rather than using M.
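As a minimal sketch of Steps 1 and 2 with the k check and the abs() suggestion applied, here is a hypothetical helper knn_nearest (it is not part of the question or the answer). It assumes, like the original code, a single feature in column 1 of tr_data and a scalar test value x. Because squaring is monotonic on non-negative numbers, abs(t - x) and (t - x).^2 sort the training rows in the same order, so the selected neighbours do not change.

```matlab
function idx = knn_nearest(x, tr_data, k)
% KNN_NEAREST  Indices of the k training rows closest to the scalar x.
% Hypothetical helper; assumes the feature is in column 1 of tr_data.
n = size(tr_data, 1);
% Reject k outside 1..n (covers both 0 >= k and k > size(tr_data,1))
if ~isscalar(k) || k ~= fix(k) || k < 1 || k > n
    error('knn_nearest:badK', 'k must be an integer between 1 and %d.', n);
end
% 1-D distance: abs() gives the same ordering as squaring
d = abs(tr_data(:, 1) - x);
% sort() is ascending by default, so no 'ascend' argument is needed
[~, order] = sort(d);
idx = order(1:k);
end
```

Saved as knn_nearest.m and used with the question's tr_data, knn_nearest(6, tr_data, 3) selects rows 5, 4 and 3, while knn_nearest(6, tr_data, 0) raises an error instead of failing later on an empty index.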
Simplify:
[dist position] = sort(euclideandistance,'ascend');
knearestneighbors = position(1:k);
knearestdistances = dist(1:k);
for i=1:k
    A(i) = tr_data(knearestneighbors(i),2);
end
M = mode(A);
if (M~=1)
    test_data(sample,2) = mode(A);
else
    test_data(sample,2) = tr_data(knearestneighbors(1),2);
end
to
[~,position] = sort(euclideandistance);
A = tr_data(position(1:k),2);
M = mode(A);
if (M~=1)
    test_data(sample,2) = M;
else
    test_data(sample,2) = tr_data(position(1),2);
end
After applying the above suggestions and vectorizing the function, you could write it as:
function out_data = knn(test_data,tr_data,k)
    % test_data: column vector of test feature values (first column only)
    test_data_n = size(test_data,1);
    tr_data_n = size(tr_data,1);
    % absolute distance between all test and training data
    dist = abs(repmat(test_data,1,tr_data_n) - repmat(tr_data(:,1)',test_data_n,1));
    % indices of nearest neighbors
    [~,nearest] = sort(dist,2);
    % k nearest
    nearest = nearest(:,1:k);
    % mode of k nearest
    val = reshape(tr_data(nearest,2),[],k);
    out_data = mode(val,2);
    % if mode is 1, output nearest instead
    out_data(out_data==1) = val(out_data==1,1);
end
Edit
Regarding correctness, I'm not sure why you check whether the mode is 1. There is nothing special about a mode of 1 in general.
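One possible reading of the M ~= 1 check, and it is only an assumption about the asker's intent, is that they wanted to fall back to the single nearest neighbour when no label occurs more than once among the k neighbours; when all labels are distinct, mode() simply returns the smallest one. mode() also returns the frequency of the mode as its second output, so that condition can be tested on the frequency rather than on the value. The minimal sketch below applies this to the vectorized version and replaces the two repmat calls with bsxfun; the name knn_vec is hypothetical, and test_x is assumed to hold only the feature values (the first column of test_data).

```matlab
function labels = knn_vec(test_x, tr_data, k)
% KNN_VEC  Vectorized 1-D k-nearest-neighbour vote (hypothetical sketch).
%   test_x  : vector of test feature values
%   tr_data : rows of [feature, label]
%   k       : number of neighbours to vote
% Pairwise absolute distances: rows = test points, columns = training rows
dist = abs(bsxfun(@minus, test_x(:), tr_data(:, 1)'));
% Training-row indices ordered nearest first, one row per test point
[~, nearest] = sort(dist, 2);
nearest = nearest(:, 1:k);
% Labels of the k nearest neighbours, same layout as nearest
val = reshape(tr_data(nearest, 2), [], k);
% freq(i) is how often labels(i) occurs among the k labels of row i
[labels, freq] = mode(val, 2);
% Assumed intent: if every label is distinct, use the nearest neighbour's label
noMajority = (freq == 1);
labels(noMajority) = val(noMajority, 1);
end
```

With the question's tr_data and k = 3, knn_vec([6; 0.4], tr_data, 3) returns [3; 2]: the three labels nearest to 6 are all different (3, 4 and 2), so the nearest neighbour's label 3 is used, while label 2 wins the vote for 0.4. On R2016b or later, implicit expansion also allows abs(test_x(:) - tr_data(:,1)') without bsxfun.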
Context
StackExchange Code Review Q#46027, answer score: 5