-
Notifications
You must be signed in to change notification settings - Fork 2
/
SOM.m
79 lines (76 loc) · 1.93 KB
/
SOM.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
function [ClusterIm, CCIm] = SOM(data,nodes,ImageType)
%SOM Cluster the pixels of an image with a 2-D self-organizing map.
%   [ClusterIm, CCIm] = SOM(data, nodes, ImageType)
%
%   Inputs:
%     data      - image array with size(data,1) rows, size(data,2) columns
%                 and size(data,3) bands.
%     nodes     - number of map nodes. Must be a perfect square, because the
%                 nodes are arranged on a sqrt(nodes)-by-sqrt(nodes) grid.
%     ImageType - 'RGB'   : bands 1..3 are used as the per-pixel features.
%                 'Hyper' : data is passed to PCAbyDG(data,20) and its first
%                           output is used as the feature matrix.
%                 other   : every band of each pixel is used as a feature.
%
%   Outputs:
%     ClusterIm - (rows*cols)-by-1 vector holding, for every pixel, the index
%                 of the winning node found during the last epoch.
%     CCIm      - result of ConnectedComponent(ClusterIm,nodes,rows,cols).
%
%   Side effect: displays imgaussfilt(CCIm,1) with imagesc.

m = size(data,1);
n = size(data,2);
b = size(data,3);
numSamples = m*n;

% --- Build the (samples x features) matrix -------------------------------
if strcmp('RGB',ImageType)
    R = double(data(:,:,1));
    G = double(data(:,:,2));
    B = double(data(:,:,3));
    data = [R(:), G(:), B(:)];
elseif strcmp('Hyper',ImageType)
    b = 20;
    % NOTE(review): assumes PCAbyDG returns Y as a (rows*cols)-by-b matrix;
    % confirm against PCAbyDG.
    [Y, U, Lambda, Mu] = PCAbyDG(data,b); %#ok<ASGLU>
    data = Y;
else
    % Any other image type: one row per pixel, one column per band.
    data = reshape(double(data), numSamples, b);
end

% --- Map layout -----------------------------------------------------------
totalW   = nodes;
gridSide = sqrt(totalW);
if ~isscalar(totalW) || totalW < 1 || gridSide ~= round(gridSide)
    error('SOM:invalidNodes', ...
          'nodes must be a positive perfect square (got %g).', totalW);
end
% 2-D grid coordinates of every node (row k = position of node k).
[I,J]   = ind2sub([gridSide, gridSide], 1:totalW);
gridPos = [I(:), J(:)];

% --- Training parameters --------------------------------------------------
numEpochs = 10;
w    = rand(totalW,b).*256;   % random initial weights in [0,256)
eta0 = 0.01;                  % initial learning rate
etaN = eta0;                  % current learning rate (updated every epoch)
tau2 = 1000;                  % time constant of the learning-rate decay
sig0 = 2;                     % initial neighbourhood width
sigN = sig0;                  % current neighbourhood width
tau1 = 1000/log(sig0);        % time constant of the neighbourhood decay

out = zeros(numSamples,1);

for itr = 1:numEpochs
    % Present the samples in a fresh random order every epoch so that the
    % map is not biased by the raster order of the image.
    for j = randperm(numSamples)
        x = data(j,:);

        % Winner = node with the smallest Euclidean distance to x. The
        % squared distance is used because it has the same minimiser.
        d2 = sum(bsxfun(@minus, w, x).^2, 2);
        [~, ind] = min(d2);
        out(j) = ind;

        % Gaussian neighbourhood around the winner on the node grid,
        % scaled by the current learning rate.
        gridDist2 = sum(bsxfun(@minus, gridPos, gridPos(ind,:)).^2, 2);
        h = etaN/(sqrt(2*pi)*sigN) .* exp(-gridDist2./(2*sigN^2));

        % Move every node towards x in proportion to its neighbourhood
        % weight (bsxfun keeps this working without implicit expansion).
        w = w + bsxfun(@times, h, bsxfun(@minus, x, w));
    end

    % Exponential decay of learning rate and neighbourhood width.
    etaN = eta0 * exp(-itr/tau2);
    sigN = sig0 * exp(-itr/tau1);
end

ClusterIm = out;
CCIm = ConnectedComponent(ClusterIm,nodes,m,n);
imagesc( imgaussfilt(CCIm,1));