
function res = learn_msda(prm)
%
% function res = learn_msda(prm)
%
% Learn the projection matrix and other MSDA parameters (e.g. best subclass
% partition, mean for zero mean sample normalization, etc.) using a training
% set and an appropriate criterion [1,2]. Candidate subclass partitions with
% C+1, C+2, ..., Hm total subclasses (C = number of classes) are built
% incrementally: at each step one class is selected and split into one more
% subclass, using a nongaussianity criterion based on the negentropy
% increment [3,4]. The quality of each candidate partition is assessed
% either with the DA stability criterion or with the correct classification
% rate (CCR) averaged over random cross-validation (CV) splits [1,2,5].
% Note that we use a slightly different CV approach: for each candidate
% partition the MSDA projection matrix is computed using the entire
% training set, and the training set is subsequently split into learning
% and validation sets only to validate that projection. For classification
% we use the nearest neighbour classifier in the MSDA subspace.
%
% IN
%
% prm : input parameter structure with the following fields:
%
%  fea_Train: training feature vectors N x F
%
%  gnd_Train: ground truth labels N x 1
%
%  useCV: criterion used to select the optimum subclass partition:
%       1: cross-validation (CV) criterion (maximum mean CCR),
%       any other value: DA stability criterion (minimum Phi)
%
%  Nmin_sub: the minimum permitted number of observations in a subclass
%  (clipped to at most floor(largest class size / 2) and at least 1)
%
%  Hmx: user defined maximum allowed number of subclasses
%
%  normType: method used for normalizing the observations:
%       'zeroMean': subtract the mean training sample,
%       'unityNorm': scale each observation to unit Euclidean norm,
%       anything else (e.g. 'NoNorm'): no normalization
%
%  distType: similarity (or dissimilarity) measure used for comparing
%  observations, e.g. 'euclidean' or 'cosine'. It is passed to knnsearch
%  as its 'Distance' option and forwarded to the partitioning helpers.
%
%  WithinScatter: type of matrix used for measuring the within class
%  scatter (forwarded to cmp_msda_mat), i.e.
%       1: total matrix St,
%       2: within subclass scatter matrix Sws,
%       3: stable within subclass scatter matrix Sws + Sbsb
%
%  noOfRandSplitsLearn: number of CV splits for learning the optimal MSDA
%  parameters (used only when useCV == 1)
%
%  learnRatio: ratio of samples used at each CV split for learning, e.g.,
%  if learnRatio is 0.7 the validation ratio for testing the computed
%  parameters is 0.3 (at each CV split 70% of the training observations
%  of every class are used for learning and 30% for validation). Must lie
%  strictly between 0 and 1 when useCV == 1.
%
%
% OUT
%
% res: output parameter structure with the following fields:
%
%  fea_Mean: mean sample vector 1 x F for 'zeroMean' normalization, empty []
%  for any other normType
%
%  W: optimum projection matrix (according to optimum subclass partition)
%
%  H: optimum number of total subclasses
%
% SIDE EFFECTS
%
%  Appends a trace to the text file
%  log_learn_msda_<normType>_<distType>_<WithinScatter>_<learnRatio*10>.txt
%  in the current folder. The file is closed on normal return and on error.
%
% ERRORS
%
%  learn_msda:logFile       the log file cannot be opened
%  learn_msda:learnRatio    useCV == 1 and learnRatio is not in (0,1)
%  learn_msda:noCandidates  no additional subclass is allowed (Hm <= C)
%  learn_msda:tooFewSamples the selected class cannot be split further
%
%
% Related references:
%
% 1. N. Gkalelis, V. Mezaris, I. Kompatsiaris, "Mixture subclass
% discriminant analysis", IEEE Signal Processing Letters, vol. 18, no. 5,
% pp. 319-322, May 2011
%
% 2. N. Gkalelis, V. Mezaris, I. Kompatsiaris, T. Stathaki, "Mixture
% subclass discriminant analysis link to restricted Gaussian model and
% other generalizations", IEEE Transactions on Neural Networks and Learning
% Systems, vol. 24, no. 1, pp. 8-21, January 2013.
%
% 3. Luis F. Lago-Fernandez, Fernando J. Corbacho: Using the Negentropy
% Increment to Determine the Number of Clusters. IWANN (1) 2009: 448-455
%
% 4. Luis F. Lago-Fernandez, Fernando J. Corbacho: Normality-based
% validation for crisp clustering. Pattern Recognition 43(3): 782-795
% (2010)
%
% 5. M. Zhu and A.M. Martinez, "Subclass Discriminant Analysis", IEEE
% Transactions on Pattern Analysis and Machine Intelligence, Vol. 28,
% No. 8, pp. 1274-1286, 2006
%
%
% Author: Nikolaos Gkalelis - CERTH-ITI
% Email: gkalelis@iti.gr
%
% Created 01 Aug 2013.
%


%% read input parameters
learnRatio = prm.learnRatio;
noOfRandSplitsLearn = prm.noOfRandSplitsLearn;
normType = prm.normType;
distType = prm.distType;
useCV = prm.useCV;
Nmin_sub = prm.Nmin_sub;
Hmx = prm.Hmx;
WithinScatter = prm.WithinScatter;
fea = prm.fea_Train;
gnd = prm.gnd_Train;
clear prm;

valRatio = 1 - learnRatio; % validation ratio

%% trace
logFile = char( strcat('log', '_', 'learn_msda', '_' , normType, '_', distType, '_', int2str(WithinScatter), '_', num2str(learnRatio*10), '.txt'));
fd = fopen(logFile, 'a');
if fd < 0
    error('learn_msda:logFile', 'Cannot open log file %s', logFile);
end
% close the log file however this function exits (normal return or error)
closeLog = onCleanup(@() fclose(fd)); %#ok<NASGU>

fprintf(fd , 'learn_msda>> Entering\n');

% cvpartition needs a holdout fraction strictly inside (0,1)
if useCV == 1 && ~(valRatio > 0 && valRatio < 1)
    error('learn_msda:learnRatio', ...
        'learnRatio must lie strictly between 0 and 1 when the CV criterion is used');
end

%% sort data according to their ground truth label
[fea, gnd] = sortFeaLabels(fea, gnd); % fea is samples x features
N = size(fea, 1);

%% normalize training set and learn total mean (using all training set !!!)
res.fea_Mean = []; % dummy unless zero mean normalization is used
if strcmp(normType, 'zeroMean')
    fea_Mean = mean(fea, 1); % mean of ALL TRAIN samples
    fea = fea - repmat(fea_Mean, N, 1);
    res.fea_Mean = fea_Mean; % keep parameter for normalizing test data !!!
elseif strcmp(normType, 'unityNorm')
    for n = 1:N
        % eps guards against division by zero for all-zero observations
        fea(n,:) = fea(n,:) ./ max(eps, norm(fea(n,:)));
    end
end
% any other normType (e.g. 'NoNorm') leaves the observations unchanged

%% learning
% learn optimal subclass partition; the projection matrix of every
% candidate partition is learned from the overall training set (fea) !!!
% while the classifier used for CV validation uses the CV learning split

% compute maximum allowable number of subclasses
classLbl = unique(gnd); % sorted distinct class labels
C = length(classLbl);
Nci = zeros(1, C); % number of observations per class
for i = 1:C
    % compare with the actual label value: labels need not be 1..C
    Nci(i) = sum(gnd == classLbl(i));
end
% at least 1, so that the per-class bound below never divides by zero
Nmin_sub = max(1, min(Nmin_sub, floor(max(Nci)/2)));
Hip = max(1, floor(Nci / Nmin_sub)); % maximum allowable subclasses per class
Hm = min(sum(Hip), Hmx); % maximum allowable total number of subclasses
Hp = Hm - C; % maximum allowable additional subclasses
if Hp < 1
    error('learn_msda:noCandidates', ...
        'No additional subclasses allowed (Hm = %d, C = %d); increase Hmx or decrease Nmin_sub', Hm, C);
end

gnd_subclass = ones(N, 1); % we first initialize one subclass per class
CCR_h = zeros(1, Hp); % mean CV CCR of each candidate subclass partition
Phi = zeros(1, Hp); % DA stability measure of each candidate subclass partition
W_h = cell(1, Hp); % projection matrix of each candidate subclass partition
options = []; options.WithinScatter = WithinScatter;

% build and evaluate candidate partitions with C+1, ..., C+Hp subclasses
for h = 1:Hp % total number of additional subclasses

    fprintf(fd , '\n');
    fprintf(fd , 'learn_msda>> Validating: H(%d)\n', C+h);

    % select the class to re-partition (negentropy-based criterion)
    c_id = selectClassToRePartitionPost(fea, gnd, gnd_subclass, distType, Nmin_sub);
    inClass = (gnd == c_id);
    Hi = length(unique(gnd_subclass(inClass, 1))); % its current number of subclasses

    % increment the number of subclasses for the selected class
    [fea_i, gnd_i, ~, stopPartitioningThisClass] = ...
        incrementSubclassesOfClass(fea(inClass, :), Hi+1, distType, Nmin_sub);
    if stopPartitioningThisClass
        error('learn_msda:tooFewSamples', ...
            'Terminating partitioning as reached to subclass with very few samples');
    end
    % reorder class observations and ground truth according to the new subclass labelling
    fea(inClass, :) = fea_i;
    gnd_subclass(inClass, :) = gnd_i;

    % compute projection matrix using all training set
    [W_h{h}, Phi(h)] = cmp_msda_mat(gnd, gnd_subclass, options, fea);
    fprintf(fd , 'learn_msda>> H(%d) Phi (%g)\n', C+h, Phi(h));

    % record CCR for this subclass partition if CV criterion is used
    if useCV == 1
        CCR_h(h) = cv_ccr_msda(fea, gnd, classLbl, W_h{h}, valRatio, ...
            noOfRandSplitsLearn, distType, fd);
        fprintf(fd , 'learn_msda>> H(%d) CCR (%g)\n', C+h, CCR_h(h));
    end

end

%% record identified parameters
if useCV == 1 % CV criterion is used
    fprintf(fd , 'learn_msda>> CV criterion is used\n');
    [~, hp_opt] = max(CCR_h);
else % DA stability criterion: smaller Phi is better
    [~, hp_opt] = min(Phi);
    fprintf(fd , 'learn_msda>> Stability criterion is used\n');
end

res.W = W_h{hp_opt}; % record identified parameters
res.H = C + hp_opt;

fprintf(fd , 'learn_msda>> Ho(%d) CCRo (%g) Phio (%g)\n', C+hp_opt, CCR_h(hp_opt), Phi(hp_opt));

fprintf(fd , 'learn_msda>> Exiting\n');
% the log file is closed by the onCleanup object when this function returns


function ccr = cv_ccr_msda(fea, gnd, classLbl, W, valRatio, noOfSplits, distType, fd)
%
% ccr = cv_ccr_msda(fea, gnd, classLbl, W, valRatio, noOfSplits, distType, fd)
%
% Mean nearest-neighbour correct classification rate (CCR) of the
% projection matrix W over noOfSplits random splits of (fea, gnd). At each
% split a fraction valRatio of every class (per cvpartition 'holdout') is
% held out for validation; the rest is the reference set of a 1-NN
% classifier in the projected space. Progress is traced to file id fd.
%
N = size(fea, 1);
distName = char(distType); % knnsearch expects a character vector
CCR = zeros(1, noOfSplits);
for cv = 1:noOfSplits

    fprintf(fd , '\n------------------------------------------------------------------------------------------\n');
    fprintf(fd , 'learn_msda: split: (%d)\n', cv);
    fprintf(fd , '\n------------------------------------------------------------------------------------------\n');

    % per-class holdout split for this CV cycle
    trainIdx = false(N, 1);
    testIdx = false(N, 1);
    for i = 1:numel(classLbl)
        inClass = (gnd == classLbl(i));
        c = cvpartition(sum(inClass), 'holdout', valRatio);
        trainIdx(inClass) = training(c);
        testIdx(inClass) = test(c);
    end

    % sort each subset according to its ground truth labels
    [fea_Train, gnd_Train] = sortFeaLabels(fea(trainIdx, :), gnd(trainIdx));
    [fea_Test, gnd_Test] = sortFeaLabels(fea(testIdx, :), gnd(testIdx));

    % project data and classify with the 1-NN rule in the MSDA subspace
    nnIdx = knnsearch(fea_Train * W, fea_Test * W, 'Distance', distName);
    recognized = gnd_Train(nnIdx);
    CCR(cv) = sum(gnd_Test(:) == recognized(:)) / numel(gnd_Test);

end
ccr = mean(CCR);


