% Binary classification demo
% Discriminant function: linear function
%
% Copyright (c) 2009, Okito Yamashita, ATR CNS, oyamashi@atr.jp.

clear;        % start from an empty workspace
close all;    % and with every figure window closed

% Dataset selector:
%   1 or 2 -> simulated two-class Gaussian data (generated by gen_simudata)
%   3      -> real EEG data loaded from ../TESTDATA/real_binary, which needs
%             the separately downloaded test data
data = 1;

fprintf('This code demonstrates how a binary classification problem is solved ...\n');
%----------------------------
% Generate Data
%----------------------------
% Every branch defines:
%   ttr, xtr : training labels / features (one sample per row, Ntr x D)
%   tte, xte : test labels / features (Nte rows)
%   D        : feature dimension
%   Ntr, Nte : number of training / test samples
% Any other value of `data` is rejected immediately, instead of failing later
% on an undefined variable.
switch data
    case 1
        % Two simulated classes whose means differ only in dimension 1,
        % while dimensions 1 and 2 are correlated (coefficient ro).
        D = 400;
        Ntr = 200;
        Nte = 100;
        % class means
        mu1 = zeros(D,1);
        mu2 = [1.5; 0; zeros(D-2,1)];
        % covariance passed to gen_simudata: identity plus correlation ro
        % between dimensions 1 and 2
        S = eye(D);
        ro = 0.8;
        S(1,2) = ro;
        S(2,1) = ro;

        [ttr, xtr, tte, xte, g] = gen_simudata([mu1 mu2], S, Ntr, Nte);

        fprintf('\nThe data is generated from 2 Gaussian Mixture model of which centers (mean) are different.\n');
        fprintf('But only the first dimension has difference in mean value between two classes,\n');
        fprintf('and the other dimension has same mean value.\n');
        fprintf('Therefore only the first dimension is detected as a meaningful feature,\n')
        fprintf('if you select features by feature-wise t-value ranking method.\n');
        fprintf('However due to the correlation between the second dimension and the first dimension,\n')
        fprintf('inclusion of the second dimension makes classification more accurate.\n');
        fprintf('For comparison, this demo also computes the classification performance of linear RVM, \n');
        fprintf('which is Bayesian counterpart of support vector machine (SVM).\n');
        fprintf('Input feature dimension is %d. \n', D);
        fprintf('The number of training samples is %d. \n', Ntr);

    case 2
        % Two simulated classes with identity covariance. mu2 decreases
        % linearly from 1 to 0 over dimensions 1..51 (1:-0.02:0 has 51
        % elements and the last one is 0), so the means differ in the first
        % 50 dimensions only.
        % NOTE(review): the message below says the means change "in class 1
        % from 0 to 1"; the code builds mu2 decreasing from 1 to 0 -- confirm
        % which class mu2 corresponds to in gen_simudata.
        D = 100;
        Ntr = 200;
        Nte = 100;
        % class means
        mu1 = zeros(D,1);
        mu2 = [(1:-0.02:0)'; zeros(D-51,1)];
        % covariance passed to gen_simudata
        S = eye(D);

        [ttr, xtr, tte, xte, g] = gen_simudata([mu1 mu2], S, Ntr, Nte);

        fprintf('\nThe data is generated from 2 Gaussian Mixture model of which centers (mean) are different.\n');
        fprintf('The mean value of the first 50 dimension is slightly different between two classes,\n');
        fprintf('whereas the remaining dimension has the same mean value.\n');
        fprintf('The degree of difference in the first 50 dimensions are manipulated\n')
        fprintf('by gradually changing the mean values in class 1 from 0 to 1.\n')
        fprintf('For comparison, this demo also computes the classification performance of linear RVM, \n');
        fprintf('which is Bayesian counterpart of support vector machine (SVM).\n');
        fprintf('Input feature dimension is %d. \n', D);
        fprintf('The number of training samples is %d. \n', Ntr);

    case 3
        % Pre-processed real EEG data; the path is relative to the current folder.
        load('../TESTDATA/real_binary', 'TRAIN_DATA', 'TEST_DATA', 'TRAIN_LABEL', 'TEST_LABEL');
        ttr = TRAIN_LABEL;
        tte = TEST_LABEL;
        xtr = TRAIN_DATA;
        xte = TEST_DATA;
        [Ntr, D] = size(TRAIN_DATA);
        Nte = size(TEST_DATA, 1);

        fprintf('\nThe data is generated from a real experimental EEG data.\n');
        fprintf('In the experiment a subject executed either left or right finger tapping.\n');
        fprintf('The data has already been processed appropriately for classification.\n');
        fprintf('Input feature dimension is %d. \n', D);
        fprintf('The number of training samples is %d. \n', Ntr);

    otherwise
        % Fail fast: without this, later code would hit undefined variables.
        error('Unknown data selection: data = %s (must be 1, 2 or 3).', num2str(data));
end

%-----------------------------------------------------------
% Show the training set (the original notes: first 2 dims)
%-----------------------------------------------------------
slr_view_data(ttr, xtr);
axis equal;
title('Training Data');

% Hold here until the user presses a key, then move on to training.
fprintf('\n\nPress any key to proceed \n\n');
pause;

%--------------------------------
% Learn Parameters
%--------------------------------
% Training settings shared by the biclsfy_* calls below, defined once so they
% can be changed in a single place instead of in every call.
nlearn    = 300;     % 'nlearn' option, used by all four methods
nstep     = 100;     % 'nstep' option, used by the RVM and RLR calls
amax      = 1e8;     % 'amax' option, used by the RVM and RLR calls
norm_mode = 'none';  % 'mean_mode' and 'scale_mode' option for all four methods

% Each call is timed with tic/toc. The call order and the argument order of
% the name/value pairs are the same as in the original hard-coded version.
tic
fprintf('\nOLD version (ARD-Laplace)!!\n')
[ww_o, ix_eff_o, errTable_tr_o, errTable_te_o] = biclsfy_slrlap(xtr, ttr, xte, tte, ...
    'wdisp_mode', 'off', 'nlearn', nlearn, 'mean_mode', norm_mode, 'scale_mode', norm_mode);
toc

tic
fprintf('\nFast version (ARD-Variational)!!\n')
[ww_f, ix_eff_f, errTable_tr_f, errTable_te_f] = biclsfy_slrvar(xtr, ttr, xte, tte, ...
    'nlearn', nlearn, 'mean_mode', norm_mode, 'scale_mode', norm_mode);
toc

% NOTE(review): the positional 0 presumably selects a linear kernel -- confirm
% in biclsfy_rvm. g_rvm / g_r are not used afterwards but are still requested
% so the callees see the same nargout as before.
tic
fprintf('\n\nLinear RVM (ARD-Variational)!!\n')
[ww_rvm, ix_eff_rvm, errTable_tr_rvm, errTable_te_rvm, g_rvm] = biclsfy_rvm(xtr, ttr, xte, tte, 0, ...
    'nlearn', nlearn, 'nstep', nstep, 'mean_mode', norm_mode, 'scale_mode', norm_mode, 'amax', amax);
toc

tic
fprintf('\n\nRLR (ARD-Variational)!!\n')
[ww_r, ix_eff_r, errTable_tr_r, errTable_te_r, g_r] = biclsfy_rlrvar(xtr, ttr, xte, tte, ...
    'nlearn', nlearn, 'nstep', nstep, 'mean_mode', norm_mode, 'scale_mode', norm_mode, 'amax', amax);
toc



% Weight vector conversion from kernel (per-sample) space to feature space.
% NOTE(review): assumes ww_rvm holds Ntr per-training-sample weights followed
% by one bias term at index Ntr+1 -- confirm against biclsfy_rvm.
% With one training sample per row of xtr (Ntr x D), the feature-space weight
% sum_n ww_rvm(n) * xtr(n,:)' is the matrix product xtr' * ww_rvm(1:Ntr).
% The bias is appended as element D+1.
% reshape(...) forces a column, so a row-vector ww_rvm also works.
ww_rvm_col = reshape(ww_rvm(1:Ntr+1), [], 1);
ww_rv = [xtr' * ww_rvm_col(1:Ntr); ww_rvm_col(Ntr+1)];

%------------------------------------------------------------
% Plot test data with each method's weights (dimensions 1, 2)
%------------------------------------------------------------
% One subplot per method. Each subplot passes the test set, the displayed
% dimensions [1 2] and the first column of that method's weights to
% slr_view_data.
ww_plot    = {ww_o(:,1), ww_f(:,1), ww_rv(:,1), ww_r(:,1)};
title_plot = {'SLR-Laplace version', 'SLR-Variational version', ...
              'Linear RVM', 'RLR-Variational version'};

figure;
for ii = 1 : numel(ww_plot)
    subplot(2,2,ii)
    slr_view_data(tte, xte, [1 2], ww_plot{ii})
    axis equal;
    title(title_plot{ii});
end

fprintf('Finish demo !\n');
