function [acc,confusm,lhood,models,sy,gt] = trntest_1fold(trainset, testset, ...
                                                  datapath, dataext, ...
                                                  ngmm, nsamp, dims, ...
                                                  beatlevel, verb)
% [acc,confusm,lhood,models,sy,gt] = trntest_1fold(trainset, testset,
%                                                  datapath, dataext, ngmm,
%                                                  nsamp, dims,
%                                                  beatlevel, verb)
%
%   Example baseline artist ID task: train one model per label on the
%   training list, score every test track against every model, and
%   report accuracy and a confusion matrix.
%
%   Inputs (all optional; defaults in parentheses):
%     <trainset>  ('a20-trn-tracks.list') name of a file listing all the
%                 training examples (labels are obtained via
%                 labelsfor(trainset); a corresponding label file is
%                 assumed).
%     <testset>   ('a20-val-tracks.list') name of a file that lists all
%                 the test files.
%     <datapath>  ('../mfccs', or '../beatmfcc' if beatlevel) path to
%                 prefix to all data file names.
%     <dataext>   ('.htk', or '.mat' if beatlevel) extension to append to
%                 all data file names.
%     <ngmm>      (1) number of gaussians to use in each mixture model
%                 (==1 is a special case for full-covariance single
%                 gaussians).
%     <nsamp>     (1000) number of randomly-selected samples to train with
%                 (also passed through to test_track).
%     <dims>      ([] = all) vector of columns in the feature vectors to
%                 use.
%     <beatlevel> (0) says to load beat-level features instead of HTK
%                 files.
%     <verb>      (0) selects verbose progress reporting; also enables the
%                 results plot and seeds the RNG for repeatable runs.
%
%   Outputs:
%     <acc>     overall accuracy (0..1), acc = mean(sy==gt).
%     <confusm> confusion count matrix; rows are true class, columns are
%               reported (model) class.
%     <lhood>   raw scores for all tracks across all models (one column
%               per test track).
%     <models>  array of per-label model structures.
%     <sy>      vector of the raw system classes (integer indices into
%               the sorted unique training labels).
%     <gt>      the corresponding ground truth class indices.
%
%   An error is raised if a test track carries a label for which no
%   model was trained.
%
% 2007-04-04 Dan Ellis dpwe@ee.columbia.edu
% $Header: /homes/drspeech/data/uspop2002/baseline/RCS/do_expt.m,v 1.1 2007/04/06 15:44:05 dpwe Exp dpwe $

% Copyright (c) 2007 Columbia University.
%
% This file is part of LabROSA-artist20-baseline
%
% artist20-baseline is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License version 2 as
% published by the Free Software Foundation.
%
% artist20-baseline is distributed in the hope that it will be useful, but
% WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
% General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with artist20-baseline; if not, write to the Free Software
% Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
% 02110-1301 USA
%
% See the file "COPYING" for the text of the license.

%%%% Input parameters
if nargin < 1;  trainset = 'a20-trn-tracks.list'; end
if nargin < 2;  testset = 'a20-val-tracks.list'; end
% How many mixture components?
if nargin < 5;  ngmm = 1; end
% Train on how many samples per file?
if nargin < 6;  nsamp = 1000; end
% Using which cepstral dimensions?
if nargin < 7;  dims = []; end
% load beat-level feature files?
if nargin < 8;  beatlevel = 0; end
% progress messages?
if nargin < 9;  verb = 0; end

% Data location defaults depend on which kind of features is loaded
if beatlevel
  if nargin < 3;  datapath = '../beatmfcc'; end
  if nargin < 4;  dataext = '.mat'; end
else
  if nargin < 3;  datapath = '../mfccs'; end
  if nargin < 4;  dataext = '.htk'; end
end

% plot confusion matrix?
doplot = 0;

if verb
  doplot = 1;
  % seed the RNG when doing a verbose run for consistent results
  % (legacy seeding syntax kept for compatibility with older MATLAB)
  rand('state', 0);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Train models
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

train_files = textread(trainset, '%s\n');
train_labs = labelsfor(trainset);

% Make the train files into full relative paths
for i = 1:length(train_files)
  train_files{i} = fullfile(datapath, [train_files{i}, dataext]);
end

% List of all unique labels; a class index is a position in this list
ulabs = unique(train_labs);
nlabs = length(ulabs);

% Train models one by one
for model = 1:nlabs
  if verb;  disp(['training for ',ulabs{model},' ...']);  end
  % Select filenames that have this label as ground truth
  files = train_files(strcmp(train_labs, ulabs{model}));
  models(model) = train_model(files, ngmm, nsamp, dims, beatlevel);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Test models
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

test_files = textread(testset,'%s\n');
test_labs = labelsfor(testset);

ntest = length(test_files);

% Preallocate the ground-truth index vector (one entry per test track)
gt = zeros(1, ntest);

for file = 1:ntest
  testfile = test_files{file};
  if verb;  disp(['testing ',testfile,'...']);  end
  lhood(:,file) = test_track(fullfile(datapath,[testfile,dataext]), ...
                             models, nsamp, dims, beatlevel);
  % Which model should it have been according to ground truth label?
  truth = find(strcmp(ulabs, test_labs{file}));
  if isempty(truth)
    % Fail with a meaningful message instead of an opaque
    % assignment-size error when the label was never seen in training.
    error('trntest_1fold:unknownLabel', ...
          'Test file "%s" has label "%s", which has no trained model.', ...
          testfile, test_labs{file});
  end
  gt(file) = truth;
end

% Which is the most likely model for each test item?
% Reduce explicitly over dimension 1 (models) so that a single-model
% run (1 x ntest scores) still yields one decision per track.
[maxv,sy] = max(lhood, [], 1);

% Overall accuracy (fraction of tracks assigned to the true class)
acc = mean(sy==gt);

if verb;  disp(['Classification accuracy = ', num2str(100*acc),'%']);  end

% Indicator matrices of which track was classified to which class, and
% ground truth.  Built with zeros() rather than 0*lhood, since the latter
% yields NaN wherever lhood holds Inf or NaN scores.
cm = zeros(size(lhood));
gtm = zeros(size(lhood));
for i = 1:ntest
  cm(sy(i),i) = 1;
  gtm(gt(i),i) = 1;
end

% So confusion matrix
confusm = gtm*cm';
% rows are true class, columns are reported (model) class

if doplot
  trntest_plot_results(lhood, test_labs, ulabs, confusm);
end