Trovare K-vicini più vicini e la sua implementazione

Sto lavorando alla classificazione di dati semplici usando KNN con distanza euclidea. Ho visto un esempio di ciò che vorrei fare che viene eseguito con la funzione knnsearch MATLAB come mostrato di seguito:

 load fisheriris x = meas(:,3:4); gscatter(x(:,1),x(:,2),species) newpoint = [5 1.45]; [n,d] = knnsearch(x,newpoint,'k',10); line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10) 

Il codice precedente prende un nuovo punto, ovvero [5 1.45] e trova i 10 valori più vicini al nuovo punto. Qualcuno può mostrarmi un algoritmo MATLAB con una spiegazione dettagliata di ciò che fa la funzione knnsearch ? C’è un altro modo di fare questo?

La base dell’algoritmo K-Nearest Neighbor (KNN) è che si dispone di una matrice di dati costituita da N righe e M colonne dove N è il numero di punti dati che abbiamo, mentre M è la dimensionalità di ciascun punto dati. Ad esempio, se poniamo le coordinate cartesiane all’interno di una matrice di dati, di solito si tratta di una matrice N x 2 o N x 3 . Con questa matrice di dati, si fornisce un punto di interrogazione e si cercano i punti k più vicini all’interno di questa matrice di dati che sono i più vicini a questo punto di ricerca.

Di solito usiamo la distanza euclidea tra la query e il resto dei tuoi punti nella tua matrice di dati per calcolare le nostre distanze. Tuttavia, vengono utilizzate anche altre distanze come la L1 o la distanza City-Block / Manhattan. Dopo questa operazione, avrai le distanze Euclidee o Manhattan che simboleggiano le distanze tra la query e ogni punto corrispondente nel set di dati. Una volta individuati, è sufficiente cercare i k punti più vicini alla query ordinando le distanze in ordine crescente e recuperando i punti k che hanno la distanza minima tra il set di dati e la query.

Supponendo che la tua matrice di dati sia stata memorizzata in x e newpoint sia un punto di campionamento in cui ha colonne M (cioè 1 x M ), questa è la procedura generale che seguiresti in forma di punto:

  1. Trova la distanza tra Euclide e Manhattan tra newpoint e ogni punto in x .
  2. Ordina queste distanze in ordine crescente.
  3. Restituisci i k punti di dati in x che sono più vicini a newpoint .

Facciamo ogni passo lentamente.


Passo 1

Un modo in cui qualcuno può farlo è forse in un ciclo for modo:

 N = size(x,1); dists = zeros(N,1); for idx = 1 : N dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2)); end 

Se volessi implementare la distanza di Manhattan, questo sarebbe semplicemente:

 N = size(x,1); dists = zeros(N,1); for idx = 1 : N dists(idx) = sum(abs(x(idx,:) - newpoint)); end 

dists sarebbe un vettore di elementi N che contiene le distanze tra ciascun punto di dati in x e newpoint . Facciamo una sottrazione elemento per elemento tra newpoint e un punto dati in x , piazza le differenze, quindi newpoint tutte insieme. Questa sum è quindi radicata in quadrato, che completa la distanza euclidea. Per la distanza di Manhattan, dovresti eseguire un elemento per sottrazione di elementi, prendere i valori assoluti, quindi sumre tutti i componenti insieme. Questa è probabilmente la più semplice delle implementazioni da comprendere, ma potrebbe essere probabilmente la più inefficiente … soprattutto per i set di dati di dimensioni maggiori e una maggiore dimensionalità dei dati.

Un’altra soluzione ansible sarebbe quella di replicare newpoint e rendere questa matrice delle stesse dimensioni di x , quindi eseguire una sottrazione elemento per elemento di questa matrice, quindi sumre su tutte le colonne per ogni riga e eseguire la radice quadrata. Pertanto, possiamo fare qualcosa di simile:

 N = size(x, 1); dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2)); 

Per la distanza di Manhattan, dovresti fare:

 N = size(x, 1); dists = sum(abs(x - repmat(newpoint, N, 1)), 2); 

repmat prende una matrice o un vettore e li ripete una certa quantità di volte in una determinata direzione. Nel nostro caso, vogliamo prendere il nostro vettore newpoint e impilare questo N volte l’uno sopra l’altro per creare una matrice N x M , dove ogni riga è lunga M elementi. Sottragiamo queste due matrici insieme, quindi quadriamo ogni componente. Una volta fatto questo, sum tutte le colonne per ogni riga e infine prendiamo la radice quadrata di tutti i risultati. Per la distanza di Manhattan, facciamo la sottrazione, prendiamo il valore assoluto e poi sommiamo.

Tuttavia, il modo più efficiente per farlo a mio avviso sarebbe usare bsxfun . Questo essenzialmente fa la replica di cui abbiamo parlato sotto il cofano con una singola chiamata di funzione. Pertanto, il codice sarebbe semplicemente questo:

 dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); 

Per me questo sembra molto più pulito e al punto. Per la distanza di Manhattan, dovresti fare:

 dists = sum(abs(bsxfun(@minus, x, newpoint)), 2); 

Passo 2

Ora che abbiamo le nostre distanze, li ordiniamo semplicemente. Possiamo usare sort per ordinare le nostre distanze:

 [d,ind] = sort(dists); 

d dovrebbe contenere le distanze ordinate in ordine ascendente, mentre ind ti dice per ogni valore nell’array non ordinato dove appare nel risultato ordinato . Dobbiamo usare ind , estrarre i primi k elementi di questo vettore, quindi usare ind per indicizzare nella nostra matrice di dati x per restituire quei punti che erano i più vicini a newpoint .

Passaggio n. 3

Il passaggio finale è ora di restituire quei k punti di dati che sono più vicini a newpoint . Possiamo farlo molto semplicemente:

 ind_closest = ind(1:k); x_closest = x(ind_closest,:); 

ind_closest dovrebbe contenere gli indici nella matrice di dati originale x che sono i più vicini a newpoint . In particolare, ind_closest contiene le righe da campionare in x per ottenere i punti più vicini a newpoint . x_closest conterrà quei punti dati effettivi.


Per il tuo piacere di copiare e incollare, questo è il codice:

 dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); %// Or do this for Manhattan % dists = sum(abs(bsxfun(@minus, x, newpoint)), 2); [d,ind] = sort(dists); ind_closest = ind(1:k); x_closest = x(ind_closest,:); 

Esaminando il tuo esempio, vediamo il nostro codice in azione:

 load fisheriris x = meas(:,3:4); newpoint = [5 1.45]; k = 10; %// Use Euclidean dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); [d,ind] = sort(dists); ind_closest = ind(1:k); x_closest = x(ind_closest,:); 

Ispezionando ind_closest e x_closest , questo è ciò che otteniamo:

 >> ind_closest ind_closest = 120 53 73 134 84 77 78 51 64 87 >> x_closest x_closest = 5.0000 1.5000 4.9000 1.5000 4.9000 1.5000 5.1000 1.5000 5.1000 1.6000 4.8000 1.4000 5.0000 1.7000 4.7000 1.4000 4.7000 1.4000 4.7000 1.5000 

Se hai eseguito knnsearch , vedrai che la tua variabile n corrisponde a ind_closest . Tuttavia, la variabile d restituisce le distanze da newpoint a ciascun punto x , non i punti di dati effettivi stessi. Se vuoi le distanze effettive, fai semplicemente quanto segue dopo il codice che ho scritto:

 dist_sorted = d(1:k); 

Si noti che la risposta di cui sopra utilizza solo un punto di query in un batch di N esempi. Molto spesso KNN viene utilizzato su più esempi contemporaneamente. Supponendo che abbiamo Q punti di domanda che vogliamo testare nella KNN. Ciò risulterebbe in una matrice kx M x Q dove per ogni esempio o ogni fetta, restituiamo i punti k più vicini con una dimensionalità di M In alternativa, possiamo restituire gli ID dei punti k più vicini risultando in una matrice Q xk . Calcoliamo entrambi.

Un modo ingenuo per farlo sarebbe quello di applicare il codice sopra in un ciclo e ripetere su ogni esempio.

Qualcosa del genere funzionerebbe dove bsxfun una matrice Q xk e applichiamo l’approccio basato su bsxfun per impostare ogni riga della matrice di output sui punti k più vicini nel set di dati, dove useremo il set di dati Fisher Iris proprio come quello che avevamo prima. Manterremo anche la stessa dimensionalità che avevamo nell’esempio precedente e userò quattro esempi, quindi Q = 4 e M = 2 :

 %// Load the data and create the query points load fisheriris; x = meas(:,3:4); newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5]; %// Define k and the output matrices Q = size(newpoints, 1); M = size(x, 2); k = 10; x_closest = zeros(k, M, Q); ind_closest = zeros(Q, k); %// Loop through each point and do logic as seen above: for ii = 1 : Q %// Get the point newpoint = newpoints(ii, :); %// Use Euclidean dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); [d,ind] = sort(dists); %// New - Output the IDs of the match as well as the points themselves ind_closest(ii, :) = ind(1 : k).'; x_closest(:, :, ii) = x(ind_closest(ii, :), :); end 

Anche se questo è molto bello, possiamo fare ancora meglio. C’è un modo per calcolare in modo efficiente la distanza Euclidea quadrata tra due serie di vettori. Lo lascerò come esercizio se vuoi farlo con Manhattan. Consultando questo blog , dato che A è una matrice Q1 x M dove ogni riga è un punto di dimensionalità M con punti Q1 e B è una matrice Q2 x M dove ogni riga è anche un punto di dimensionalità M con punti Q2 , possiamo efficientemente calcolare una matrice di distanze D(i, j) dove l’elemento alla riga i e la colonna j denota la distanza tra la riga i di A e la riga j di B usando la seguente formulazione di matrice:

 nA = sum(A.^2, 2); %// Sum of squares for each row of A nB = sum(B.^2, 2); %// Sum of squares for each row of B D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Compute distance matrix D = sqrt(D); %// Compute square root to complete calculation 

Pertanto, se lasciamo che A sia una matrice di punti interrogativi e B sia il set di dati costituito dai tuoi dati originali, possiamo determinare i k punti più vicini ordinando ogni riga individualmente e determinando le posizioni k di ogni riga che erano le più piccole. Possiamo inoltre utilizzarlo anche per recuperare i punti effettivi stessi.

Perciò:

 %// Load the data and create the query points load fisheriris; x = meas(:,3:4); newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5]; %// Define k and other variables k = 10; Q = size(newpoints, 1); M = size(x, 2); nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A nB = sum(x.^2, 2); %// Sum of squares for each row of B D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Compute distance matrix D = sqrt(D); %// Compute square root to complete calculation %// Sort the distances [d, ind] = sort(D, 2); %// Get the indices of the closest distances ind_closest = ind(:, 1:k); %// Also get the nearest points x_closest = permute(reshape(x(ind_closest(:), :).', M, k, []), [2 1 3]); 

Vediamo che abbiamo usato la logica per calcolare che la matrice della distanza è la stessa, ma alcune variabili sono cambiate per adattarsi all’esempio. Inoltre, ordiniamo ogni riga in modo indipendente utilizzando la versione a due input di sort e quindi ind conterrà gli ID per riga e d conterrà le distanze corrispondenti. Individuiamo quindi quali indici sono i più vicini a ogni punto di query semplicemente troncando questa matrice su k colonne. Quindi usiamo permute e reshape per determinare quali sono i punti più vicini associati. Per prima cosa usiamo tutti gli indici più vicini e creiamo una matrice di punti che impila tutti gli ID uno sopra l’altro in modo da ottenere una matrice Q * kx M L’uso di reshape e permute ci consente di creare la nostra matrice 3D in modo che diventi una matrice kx M x Q come abbiamo specificato. Se si desidera ottenere le distanze effettive, possiamo indicizzare in e afferrare ciò di cui abbiamo bisogno. Per fare ciò, è necessario utilizzare sub2ind per ottenere gli indici lineari in modo da poter indicizzare in d in un colpo. I valori di ind_closest ci forniscono già le colonne a cui dobbiamo accedere. Le righe a cui dobbiamo accedere sono semplicemente 1, k volte, 2, k volte, ecc. Fino a Q k indica il numero di punti che volevamo restituire:

 row_indices = repmat((1:Q).', 1, k); linear_ind = sub2ind(size(d), row_indices, ind_closest); dist_sorted = D(linear_ind); 

Quando eseguiamo il codice sopra riportato per i suddetti punti di query, questi sono gli indici, i punti e le distanze che otteniamo:

 >> ind_closest ind_closest = 120 134 53 73 84 77 78 51 64 87 123 119 118 106 132 108 131 136 126 110 107 62 86 122 71 127 139 115 60 52 99 65 58 94 60 61 80 44 54 72 >> x_closest x_closest(:,:,1) = 5.0000 1.5000 6.7000 2.0000 4.5000 1.7000 3.0000 1.1000 5.1000 1.5000 6.9000 2.3000 4.2000 1.5000 3.6000 1.3000 4.9000 1.5000 6.7000 2.2000 x_closest(:,:,2) = 4.5000 1.6000 3.3000 1.0000 4.9000 1.5000 6.6000 2.1000 4.9000 2.0000 3.3000 1.0000 5.1000 1.6000 6.4000 2.0000 4.8000 1.8000 3.9000 1.4000 x_closest(:,:,3) = 4.8000 1.4000 6.3000 1.8000 4.8000 1.8000 3.5000 1.0000 5.0000 1.7000 6.1000 1.9000 4.8000 1.8000 3.5000 1.0000 4.7000 1.4000 6.1000 2.3000 x_closest(:,:,4) = 5.1000 2.4000 1.6000 0.6000 4.7000 1.4000 6.0000 1.8000 3.9000 1.4000 4.0000 1.3000 4.7000 1.5000 6.1000 2.5000 4.5000 1.5000 4.0000 1.3000 >> dist_sorted dist_sorted = 0.0500 0.1118 0.1118 0.1118 0.1803 0.2062 0.2500 0.3041 0.3041 0.3041 0.3000 0.3162 0.3606 0.4123 0.6000 0.7280 0.9055 0.9487 1.0198 1.0296 0.9434 1.0198 1.0296 1.0296 1.0630 1.0630 1.0630 1.1045 1.1045 1.1180 2.6000 2.7203 2.8178 2.8178 2.8320 2.9155 2.9155 2.9275 2.9732 2.9732 

Per confrontare questo con knnsearch , devi specificare una matrice di punti per il secondo parametro in cui ogni riga è un punto di query e vedrai che gli indici e le distanze ordinate corrispondono tra questa implementazione e knnsearch .


Spero che questo ti aiuti. In bocca al lupo!