|
@@ -0,0 +1,289 @@
|
|
|
+%%% pca + mismatch network
|
|
|
+
|
|
|
+% y comes into pca network, which runs one iter. (immediately clusters)
|
|
|
+% output of pca network is unary context signal
|
|
|
+% y and context signal are then fed into mismatch network
|
|
|
+
|
|
|
+% pca network starts with relatively low excitability, and excitability
|
|
|
+% increases every time total mismatch is over some threshold
|
|
|
+
|
|
|
+
|
|
|
+%%% Important notes:
|
|
|
+%%% -this specific implementation could fail if we're unlucky with the
|
|
|
+%%% initial directions of the weights (i.e. if no weight is close enough
|
|
|
+%%% to input vector, we'd get into trouble)
|
|
|
+%%% -this version has been edited to work with very high dim inputs!
|
|
|
+%%% -v2 tries to introduce normalizations to mm weights
|
|
|
+%%%
|
|
|
+%%%
|
|
|
+
|
|
|
+%%% Things to check!
|
|
|
+%%% -Mismatch error should be decreasing over time, although that may
|
|
|
+%%% take a while
|
|
|
+%%% -Rank of W should initially be close to 100 (or whatever
|
|
|
+%%% initialized), but should end up close to number of clusters
|
|
|
+%%% -Weight magnitudes (for W and M) shouldn't be too crazy
|
|
|
+%%% -Cluster IDs should be consistent - after running network once,
|
|
|
+%%% re-running on same points should give same clusters (unless network
|
|
|
+%%% takes very long to converge)
|
|
|
+%%% -Mismatch error should be much lower if same data run on trained network
|
|
|
+%%% -Context signal should be relatively sparse!
|
|
|
+
|
|
|
% Reset the workspace and close any open figures before the run.
clear
close all

% Pre-open the two diagnostic figures and pin them to fixed screen
% positions. NOTE(review): the coordinates assume a specific multi-monitor
% layout — adjust for your own display setup.
h1 = figure(1);
set(h1, 'Position', [2209 -31 512 384])
h2 = figure(5);
set(h2, 'Position', [1370 -319 822 672])
% figure(10)
% figure(25)
% Block until a keypress so the figures can be arranged before the run starts.
pause
|
|
|
%% Configs
% RNG seed: affects the initial weight directions, which in turn shift the
% peaks of the error plots (see note below).
seed = 12345; %10 was working for many of these demos %56; %23, 10 for 100D 10 clusters
 % seed choice affects initial weight direction, which affects peaks of
 % error plots
%seed = 1123;
% When 1, per-timestep voltage/firing/weight histories are stored (see the
% trackVars branch in the mismatch setup) — this slows the run WAY DOWN.
trackVars = 1; %turning on slows things WAY DOWN
|
|
|
+
|
|
|
+
|
|
|
% Select Dataset
% Exactly one load should be active at a time. The chosen .mat file must
% provide allPts_dm_unshuff (de-meaned points) and allPts_unshuff (raw
% points), each nPoints x nDims — presumably generated by a companion
% data-generation script; verify variable names against the chosen file.
%load('genData1.mat')
% load('genData2D_4c_200ppc.mat')
load('genData2D_3c_500ppc.mat')
% load('genData100D_10.mat')
% load('genData_realSpec_try2.mat')
% load('genData_realSpec_5c.mat')
|
|
|
+
|
|
|
+
|
|
|
% Define how data will be presented
% First 100 rows pulled out as a small subset for partial-data experiments.
% NOTE(review): part1_dm/part1 are built here but the active bigY/mmY
% definitions below use the full allPts_* arrays directly, so these
% subsets are currently unused leftovers from earlier experiments.
twoClust_dm_unshuff = allPts_dm_unshuff(1:100, :);
twoClust_unshuff = allPts_unshuff(1:100, :);

% Identity "shuffle" — swap in randperm to actually randomize presentation order.
shuffIdx = 1:100;%randperm(500);

part1_dm = twoClust_dm_unshuff(shuffIdx,:);
part1 = twoClust_unshuff(shuffIdx, :);

% bigY is input fed into clustering network - note that it's de-meaned
% (transposed to nDims x nPoints: one column per presentation)
bigY = [allPts_dm_unshuff']; %[part1_dm' allPts_dm'];%_unshuff'];% allPts_dm' allPts_dm'];%(1:500,:)'];% allPts_dm'];

% mmY is input fed into mismatch network - just a non-de-meaned version of
% bigY
mmY = [allPts_unshuff']; %[part1' allPts'];%_unshuff'];% allPts' allPts'];%(1:500,:)'];% allPts'];
|
|
|
+
|
|
|
% PCA CONFIGS
pcaNet.changeThresh = 1e-4; % for output convergence
% magnitude of the initial (random) feedforward weight directions
initWeightMag = 0.1; %0.8 for some demos % should be 0.1
pcaNet.capW = 10;%500; %max weight to each output cell -- CHANGING THIS DIDN'T MAKE A DIFFERENCE
pcaNet.inhibCap = 10; %NEED 10 FOR SYNTHETIC, 25 FOR REAL... -- CHANGING THIS DOES MATTER

% learning rates for feedforward (W) and lateral (M) weights
pcaNet.etaW = 0.05; %originally each was 0.5 -- MAKING THIS TOO HIGH DOESN'T MATTER MUCH
pcaNet.etaM = 0.05; % originally 0.5 -- CHANGING THIS DOESN'T MATTER MUCH
% hard caps on weight magnitudes (presumably applied inside pcaIter_v6)
pcaNet.maxM = 1;
pcaNet.maxW = 1;

% MISMATCH CONFIGS
mmNet.eta = 0.025; % CHANGING THIS DRASTICALLY AFFECTS SHAPE OF ERROR PLOTS
mmNet.thresh = 0.05; % CHANGING THIS AFFECTS SHAPE OF ERROR PLOTS
|
|
|
+
|
|
|
% MISMATCH STRUCTURE CONFIGS
mmNet.signed_synapses = 1; %force positive weight coeffs
mmNet.c_plastic = 1; % presumably: context-input synapses learn (verify in mismatchIter_v2)
mmNet.y_plastic = 0; % presumably: sensory-input synapses frozen (verify in mismatchIter_v2)

% MISMATCH ARCHITECTURE CONFIGS
% How each weight group is initialized: uniform ('rand'), Gaussian
% ('randnorm'), or identity ('direct' — relay weights only).
yE_type = 'rand'; %'rand' or 'randnorm'
cE_type = 'rand';
yI_type = 'rand';
cI_type = 'rand';
yR_type = 'direct'; %'rand' or 'direct'
cR_type = 'direct'; %'rand' or 'direct'

% MISMATCH NETWORK CONFIGS
nCells_y = size(mmY,1); %one input per input dim
nCells_c = 10; % (500 works on synthetic) 10 clusters in 100D case (nb 100d is a lot)
nCells_Ny = 5; %not super sure how high or low this needs to be
nCells_Nc = 5; %but 100 and 100 should be enough to give us convex cone
% one loop iteration per data point (column of bigY)
iterations = size(bigY,2);
|
|
|
+
|
|
|
+
|
|
|
% Total Configs -- this doesn't matter, we're always learning
% (learningThresh = 0)
% mismatch-error threshold that drives the learning signal up vs. down
sigmaThresh = 2.5; %1 works for seed 10 %2.5;%2.5;%%0.25;%25; %1 worked well
deltaLearn = 0.2; %(0.1) % no buffer -> incorrect clustering (not obvious why needed for shuffled data)
learningThresh = 0;%1; %make it deltaLearn to basically eliminate deltaLearn %0.5; %always learning -> incorrect clustering
pcaNet.sigmaThresh = sigmaThresh;

% plasticity gate for the clustering network (toggled in the main loop)
pcaNet.learning = 1; %this is set to 0 if mismatch is low, 1 if high
% learningSig is a step-accumulator kept in [0, 1] by the main loop
learningSig = 1;
% seed RNG so the PCA weight initialization below is reproducible
rng(seed)
|
|
|
%% PCA - Setup
% Initialize the clustering ("PCA") network:
%   W: feedforward weights (nCells_c x nCells_y), small random directions
%   M: lateral weights (nCells_c x nCells_c), starts at zero
%   D: per-cell offset/threshold terms, starts at zero
W_init = initWeightMag*randn(nCells_c, nCells_y);
M_init = 0*initWeightMag*rand(nCells_c); %initially 0 (the 0* factor keeps the rand draw so it can be re-enabled)

% Neurons don't drive themselves: clear the diagonal via linear indexing
% (replaces a per-element for loop). Redundant while M_init is zeroed by
% the 0* factor above, but kept so re-enabling the rand init stays safe.
M_init(1:nCells_c+1:end) = 0;

D_init = zeros(nCells_c,1);
% bigC records the clustering network's output activity at every timestep
pcaNet.bigC = zeros(nCells_c, iterations);

pcaNet.W = W_init;
pcaNet.M = M_init;
pcaNet.D = D_init;
|
|
|
+
|
|
|
% Run one iteration of the clustering network on the first input so the
% mismatch network has a context signal from timestep 1.
c = updateC_v5(...
    pcaNet.W, pcaNet.M, bigY(:,1), pcaNet.changeThresh);

% Winner-take-all readout: the most active output cell is the cluster ID.
[maxx, thisCluster] = max(c);

% cT is the sparse/unary context signal at timestep 1: only the winning
% cell's activity, and only when that activity is positive.
cT = zeros(size(c));
if maxx > 0
    pcaNet.clusters = thisCluster;
    pcaNet.bigC(thisCluster, 1) = c(thisCluster);
    cT(thisCluster) = c(thisCluster); % c at timestep t
    % FIX: cT was previously assigned outside this branch, so a
    % non-positive winner could leak a negative value into the context
    % signal even though clusters was set to 0 ("no cluster").
else
    pcaNet.clusters = 0; % 0 = no cluster assigned
end
y = bigY(:,1);

% Re-seed so the mismatch-network initialization below is reproducible
% regardless of how much randomness updateC_v5 consumed.
rng(seed);
|
|
|
+%% Mismatch Setup
|
|
|
+switch yE_type
|
|
|
+ case 'rand'
|
|
|
+ mmNet.we_yn = (rand(nCells_Ny, nCells_y))./nCells_y;
|
|
|
+ case 'randnorm'
|
|
|
+ mmNet.we_yn = (randn(nCells_Ny, nCells_y))./nCells_y;
|
|
|
+end
|
|
|
+
|
|
|
+switch yI_type
|
|
|
+ case 'rand'
|
|
|
+ mmNet.wi_yn = (rand(nCells_Nc, nCells_y))./nCells_y;
|
|
|
+ case 'randnorm'
|
|
|
+ mmNet.wi_yn = (randn(nCells_Nc, nCells_y))./nCells_y;
|
|
|
+end
|
|
|
+
|
|
|
+switch yR_type
|
|
|
+ case 'rand'
|
|
|
+ mmNet.r_yn = rand(nCells_y);
|
|
|
+ case 'direct'
|
|
|
+ mmNet.r_yn = eye(nCells_y);
|
|
|
+end
|
|
|
+
|
|
|
+switch cE_type
|
|
|
+ case 'rand'
|
|
|
+ mmNet.we_cn = (rand(nCells_Nc, nCells_c));%./nCells_c;
|
|
|
+ %
|
|
|
+ case 'randnorm'
|
|
|
+ mmNet.we_cn = (randn(nCells_Nc, nCells_c));%./nCells_c;
|
|
|
+end
|
|
|
+
|
|
|
+switch cI_type
|
|
|
+ case 'rand'
|
|
|
+ mmNet.wi_cn = (rand(nCells_Ny, nCells_c));%./nCells_c;
|
|
|
+ %
|
|
|
+ case 'randnorm'
|
|
|
+ mmNet.wi_cn = (randn(nCells_Ny, nCells_c));%./nCells_c;
|
|
|
+end
|
|
|
+
|
|
|
+switch cR_type
|
|
|
+ case 'rand'
|
|
|
+ mmNet.r_cn = rand(nCells_c);
|
|
|
+ case 'direct'
|
|
|
+ mmNet.r_cn = eye(nCells_c);
|
|
|
+end
|
|
|
+
|
|
|
% Number of presentations; bigY is nDims x timesteps, so this equals
% `iterations`. (Was length(bigY(1,:)) — size() avoids copying a row.)
timesteps = size(bigY, 2);

if trackVars == 1
    % Per-timestep histories of mismatch-cell voltages (Vs) and firing (Fs).
    mmNet.Vs_y = zeros(nCells_Ny, timesteps);
    mmNet.Vs_c = zeros(nCells_Nc, timesteps);
    mmNet.Fs_y = zeros(nCells_Ny, timesteps);
    mmNet.Fs_c = zeros(nCells_Nc, timesteps);

    % Weight snapshots: timesteps+1 slices (initial weights + one per step).
    mmNet.yWs_e = zeros(nCells_Ny, nCells_y, timesteps+1);
    mmNet.cWs_e = zeros(nCells_Nc, nCells_c, timesteps+1);
    mmNet.yWs_i = zeros(nCells_Nc, nCells_y, timesteps+1);
    mmNet.cWs_i = zeros(nCells_Ny, nCells_c, timesteps+1);

    % Per-timestep weight changes (one slice per update).
    mmNet.wyChanges_e = zeros(nCells_Ny, nCells_y, timesteps);
    mmNet.wcChanges_e = zeros(nCells_Nc, nCells_c, timesteps);
    mmNet.wyChanges_i = zeros(nCells_Nc, nCells_y, timesteps);
    mmNet.wcChanges_i = zeros(nCells_Ny, nCells_c, timesteps);

    % Slice 1 stores the initial weights.
    mmNet.yWs_e(:,:,1) = mmNet.we_yn;
    mmNet.cWs_e(:,:,1) = mmNet.we_cn;
    mmNet.yWs_i(:,:,1) = mmNet.wi_yn;
    mmNet.cWs_i(:,:,1) = mmNet.wi_cn;
end

% Per-timestep mismatch errors: y-driven, context-driven, and combined.
mmNet.errors_y = zeros(timesteps, 1);
mmNet.errors_c = zeros(timesteps, 1);
mmNet.allErrors = zeros(timesteps,1);

% we've already done 1 iter of pca, so here do 1 iter of mm
mmNet = mismatchIter_v2(cT, mmY, 1, mmNet, trackVars);

disp('setup complete')
|
|
|
+%% Run through algorithm (make functions for PCAIter and MNIter)
|
|
|
+for ts_idx = 2:iterations
|
|
|
+ pcaNet = pcaIter_v6(bigY, ts_idx, pcaNet);
|
|
|
+
|
|
|
+
|
|
|
+ pcaC = pcaNet.bigC(:, ts_idx);
|
|
|
+
|
|
|
+ pcaC = pcaC>eps;
|
|
|
+
|
|
|
+ mmNet = mismatchIter_v2(pcaC, mmY, ts_idx, mmNet, trackVars);
|
|
|
+
|
|
|
+ sigmaNode = mmNet.allErrors(ts_idx);
|
|
|
+ if sigmaNode > sigmaThresh
|
|
|
+ learningSig = learningSig + deltaLearn;
|
|
|
+ if learningSig > 1
|
|
|
+ learningSig = 1;
|
|
|
+ end
|
|
|
+ else
|
|
|
+ learningSig = learningSig - deltaLearn;
|
|
|
+ if learningSig < 0
|
|
|
+ learningSig = 0;
|
|
|
+ end
|
|
|
+ end
|
|
|
+
|
|
|
+ if learningSig >= learningThresh
|
|
|
+ pcaNet.learning = 1;
|
|
|
+ else
|
|
|
+ pcaNet.learning = 0;
|
|
|
+ end
|
|
|
+
|
|
|
+% agnoPlotting_2D(bigY, mmNet, pcaNet, ts_idx, iterations)
|
|
|
+% agnoPlotting_100D(bigY, mmNet, pcaNet, ts_idx, iterations)
|
|
|
+end
|
|
|
+
|
|
|
+agnoPlotting_2D(bigY, mmNet, pcaNet, ts_idx, iterations)
|
|
|
+
|
|
|
+%% Plotting
|
|
|
+%
|
|
|
+% % Cluster ID Distribution
|
|
|
+% figure(2)
|
|
|
+% subplot(2,2,[1 2])
|
|
|
+% histogram(pcaNet.clusters(1:iterations), [-0.5:1:nCells_c+0.5])
|
|
|
+% title('Inputs per Cluster - Total Dataset')
|
|
|
+% xlabel('Cluster ID')
|
|
|
+% ylabel('Number of Inputs')
|
|
|
+%
|
|
|
+% subplot(2,2,3)
|
|
|
+% histogram(pcaNet.clusters(1:iterations/2), [-0.5:1:nCells_c+0.5])
|
|
|
+% title('Inputs per Cluster - First Half')
|
|
|
+% xlabel('Cluster ID')
|
|
|
+% ylabel('Number of Inputs')
|
|
|
+%
|
|
|
+% subplot(2,2,4)
|
|
|
+% histogram(pcaNet.clusters((iterations/2)+1:end), [-0.5:1:nCells_c+0.5])
|
|
|
+% title('Inputs per Cluster - Second Half')
|
|
|
+% xlabel('Cluster ID')
|
|
|
+% ylabel('Number of Inputs')
|