function [acc,confusm,lhoods,models,sys,gts] = trntest_folds(folds, doplot)
% [acc,confusm,lhoods,models,sys,gts] = trntest_folds(folds, doplot)
%    Perform a classification experiment by cumulating results over
%    multiple folds.  folds is the name of a file; each line
%    defines one fold of the classification experiment by defining
%    a filename defining the training set, and one defining the
%    test set (i.e. two strings on each line).
%    doplot (optional, default 0): when nonzero, the cumulated
%    confusion matrix is plotted; when > 0 the RNG is also seeded
%    so that verbose runs give consistent results.
%    Returns are an overall accuracy rate (0..1), a confusion
%    matrix (counts of tracks, rows are per true label, columns
%    per assigned label), lhoods are the per-test-item likelihoods
%    (one row per class, one column per test track), and
%    models is an array of model structures, based on the last fold.
%    sys is a vector of system labels for each test item (integers)
%    gts is the corresponding ground truth classes
%    Raises an error if the HMM toolbox (mixgauss_em) is not on the
%    path, or if the folds file defines no folds.
%
% 2007-04-04 dpwe@ee.columbia.edu

%  Copyright (c) 2007 Columbia University.
%
%  This file is part of LabROSA-artist20-baseline
%
%  artist20-baseline is free software; you can redistribute it and/or modify
%  it under the terms of the GNU General Public License version 2 as
%  published by the Free Software Foundation.
%
%  artist20-baseline is distributed in the hope that it will be useful, but
%  WITHOUT ANY WARRANTY; without even the implied warranty of
%  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
%  General Public License for more details.
%
%  You should have received a copy of the GNU General Public License
%  along with artist20-baseline; if not, write to the Free Software
%  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
%  02110-1301 USA
%
%  See the file "COPYING" for the text of the license.

if nargin < 2;  doplot = 0;  end

if doplot > 0
  % seed the RNG when doing a verbose run for consistent results
  rand('state', 0);
end

% Check for needed packages
if exist('mixgauss_em') ~= 2
  % Put the install hint in the error itself so callers that catch
  % the failure (or only see the error text) know what went wrong.
  error('trntest_folds:missingHMMToolbox', '%s', ...
        ['You need to install the HMM toolbox components from ' ...
         'http://www.cs.ubc.ca/~murphyk/Software/HMM/hmm.html on ' ...
         'your path']);
end

% Read the fold definitions: one "<trainlist> <testlist>" pair per line
[trnfiles, tstfiles] = textread(folds,'%s %s\n');
ncuts = length(trnfiles);

if ncuts == 0
  % Without this guard the trnfiles{1} access below fails with an
  % unrelated index-out-of-range message.
  error('trntest_folds:noFolds', 'No folds defined in %s', folds);
end

% how many labs?  Assume all classes are represented in
% first training set
ll = labelsfor(trnfiles{1});
ulabs = unique(ll);
nlabs = length(ulabs);

% find the number of test items in each cut
cutlen = zeros(1, ncuts);
for c = 1:ncuts
  ll = textread(tstfiles{c},'%s\n');
  cutlen(c) = length(ll);
end
% cumcl(c) is the number of test items preceding cut c, i.e. the
% column offset of that cut's results in the cumulated outputs
cumcl = cumsum([0 cutlen]);
ntest = cumcl(end);

lhoods  = zeros(nlabs, ntest);
sys     = zeros(1, ntest);
gts     = zeros(1, ntest);
confusm = zeros(nlabs, nlabs);
accs    = zeros(1, ncuts);   % per-fold accuracies (not returned)

durn = 0;

for cut = 1:ncuts
  tic;
  % models is overwritten each time round, so the returned value
  % comes from the last fold
  [accs(cut),cnf,lh,models,sy,gt] = trntest_1fold(trnfiles{cut},tstfiles{cut});
  durn = durn + toc;
  confusm = confusm + cnf;
  % columns of the cumulated results belonging to this cut
  cols = cumcl(cut) + (1:cutlen(cut));
  lhoods(:,cols) = lh;
  sys(cols) = sy;
  gts(cols) = gt;
end

% overall accuracy = correctly labelled tracks / all test tracks
acc = sum(diag(confusm))/sum(confusm(:));

rstr = [datestr(rem(now,1),'HH:MM:SS'),' acc=',num2str(acc), ...
        ' durn=',num2str(durn)];
disp(rstr);
% display actual correct track counts per class?
%cstr = sprintf('%2d ',diag(confusm));
%disp(cstr);

if doplot
  % abbreviate class labels to (at most) their first two characters
  % for the x axis; clamp so one-character labels do not error
  ulb2 = cell(1, nlabs);
  for i = 1:nlabs
    lab = ulabs{i};
    ulb2{i} = lab(1:min(2, length(lab)));
  end
  imagesc(confusm);
  axis xy
  xlabel('recog');
  ylabel('true');
  set(gca,'YTick',1:nlabs);
  set(gca,'YTickLabel',ulabs);
  set(gca,'XTick',1:nlabs);
  set(gca,'XTickLabel',ulb2);
  colormap(1-gray)
  colorbar
end