function [model,mn,st] = model_train(trnset,mn,st,ngmm,nsamp,dims)
% [model,mn,st] = model_train(trnset,mn,st,ngmm,nsamp,dims)
%   Pool randomly-sampled frames from the files of one class, standardize
%   them, and hand the selected dimensions to model_train_data.
%   trnset is a list (cell array) of data files defining a class
%   mn, st are normalization constants (empty -> derive from data)
%   ngmm is the number of gaussian components (default 20),
%   nsamp is the number of randomly-selected frames to train on
%     (default 2000).
%   dims is a vector indexing which dimensions of the data to use
%     (empty -> all)
%   model is the model we return with fields mean, sigma, prior, nmix
%   mn, st are returned (e.g. to pass to next invocation)
% 2007-04-03 Dan Ellis dpwe@ee.columbia.edu artist ID demo system

% Standardization parameters are taken from data if not specified
if nargin < 2; mn = []; end
if nargin < 3; st = []; end
% How many gaussian components?
if nargin < 4; ngmm = 20; end
% Randomly sample how many data points?
if nargin < 5; nsamp = 2000; end
% Which dimensions to use in the data?  Default to all
if nargin < 6; dims = []; end

guarddef = 1000;  % default guard spacing from ends

% Load all the data.  Per-file samples are collected in a cell array and
% concatenated once, rather than growing a matrix inside the loop.
nfiles = length(trnset);
pool = cell(nfiles, 1);
for i = 1:nfiles
  % Transposed so that each row of d is one frame
  d = readdatafile(trnset{i})';
  nframes = size(d,1);
  % Only add the middle part of each file into the pool.
  % ri consists of nsamp random indices (with replacement) at least
  % <guard> frames away from either end, where guard is the SMALLER of
  % guarddef and one quarter of the file length, so that 2*guard never
  % exceeds the number of frames available.
  guard = min(round(nframes/4), guarddef);
  ri = guard + ceil((nframes - 2*guard)*rand(1,nsamp));
  pool{i} = d(ri,:);
end
data = vertcat(pool{:});

% Fail clearly when there is nothing to sample from (e.g. empty trnset);
% previously this surfaced as an index-out-of-range error below.
if size(data,1) < nsamp
  error('model_train:noData', ...
        'Only %d pooled frames available but nsamp = %d (is trnset empty?)', ...
        size(data,1), nsamp);
end

% Select random subset of frames from the pool
ri = randperm(size(data,1));
data = data(ri(1:nsamp),:);

% If no dimensions explicitly selected, use them all.  The count is taken
% from the data itself: mn may still be empty at this point (it is only
% derived below), so 1:length(mn) would select no dimensions at all.
if isempty(dims)
  dims = 1:size(data,2);
end

% Derive normalization constants if not specified
if isempty(mn)
  [data,mn,st] = standardize(data');
  data = data';
else
  % Normalization by global constants (helps gmm??)
  data = standardize(data',mn,st)';
end

% Remove unwanted dimensions now we've got all the standardization coeffs
data = data(:,dims);

model = model_train_data(data, ngmm);