trainLDA.cc

  1 const char *help = "\
  2 progname: trainLDA.cc\n\
  3 code2html: This program computes a LDA projection sub-space.\n\
  4 version: Torch3 vision2.0, 2003-2005\n\
  5 (c) Sebastien Marcel (marcel@idiap.ch)\n";
  6 
  7 #include "FileBinDataSet.h"
  8 #include "DiskBinDataSet.h"
  9 #include "DirectLDATrainer.h"
 10 //#include "QZLDATrainer.h"
 11 #include "FileListCmdOption.h"
 12 #include "CmdLine.h"
 13 
 14 using namespace Torch;
 15 			
 16 int main(int argc, char **argv)
 17 {
 18   	int n_inputs;
 19 	int offset_window;
 20 	int n_inputs_window;
 21   	char *model_file;
 22 	bool verbose;
 23 	int verbose_level;
 24 	bool classprovided;
 25 	bool within_is_identity;
 26 	bool forward;
 27 	bool use_disk;
 28 	bool saveSbSw;
 29 	bool qz;
 30 	//bool kcf;
 31 	//double eps;
 32 	
 33   	Allocator *allocator = new Allocator;
 34   	DiskXFile::setLittleEndianMode();
 35 
 36 
 37 	
 38   	//=================== The command-line ==========================
 39 	FileListCmdOption filelist("file name", "the list files or one data file");
 40         filelist.isArgument(true);
 41 
 42   	// Construct the command line
 43   	CmdLine cmd;
 44 	cmd.setBOption("write log", false);
 45 	
 46   	// Put the help line at the beginning
 47   	cmd.info(help);
 48 
 49   	cmd.addText("\nArguments:");
 50   	cmd.addCmdOption(&filelist);
 51   	cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data", true);
 52   	cmd.addText("\nOptions:");
 53 	cmd.addICmdOption("-offset_window", &offset_window, 0, "offset window", true);
 54 	cmd.addICmdOption("-n_inputs_window", &n_inputs_window, -1, "input dimension of the window", true);
 55   	cmd.addBCmdOption("-verbose", &verbose, false, "verbose", true);
 56 	cmd.addICmdOption("-verbose_level", &verbose_level, 0, "verbose level", true);
 57   	cmd.addBCmdOption("-forward", &forward, false, "project all data into lda", true);
 58   	cmd.addBCmdOption("-classprovided", &classprovided, false, "assign the targets", true);
 59   	cmd.addBCmdOption("-within_id", &within_is_identity, false, "the within scatter matrix is an identity matrix", true);
 60   	cmd.addSCmdOption("-save", &model_file, "", "model file", true);
 61 	cmd.addBCmdOption("-use_disk", &use_disk, false, "use disk");
 62 	cmd.addBCmdOption("-saveSbSw", &saveSbSw, false, "save Sb and Sw");
 63 	cmd.addBCmdOption("-qz", &qz, false, "use QZ algorithm to computes LDA");
 64 	//cmd.addBCmdOption("-kcf", &kcf, false, "computes KCF first");
 65 	//cmd.addDCmdOption("-eps", &eps, 0., "precision for QZ");
 66 
 67   	// Read the command line
 68   	cmd.read(argc, argv);
 69 
 70 	if(n_inputs_window == -1) n_inputs_window = n_inputs;	
 71 
 72 	//
 73 	if(verbose)
 74 	{
 75 		print(" + n_filenames = %d\n", filelist.n_files);
 76 		for(int i = 0 ; i < filelist.n_files ; i++)
 77 			print("   filename[%d] = %s\n", i, filelist.file_names[i]);
 78 	}
 79 
 80 	int n_classes = filelist.n_files;
 81 
 82   	//
 83 	// The LDA Machine
 84 	LDAMachine *lda_machine = NULL;
 85 	lda_machine = new(allocator) LDAMachine(n_inputs_window);
 86 	lda_machine->setIOption("verbose_level", verbose_level);
 87 
 88 	//
 89 	// The LDA Trainer
 90 	LDATrainer *lda_trainer = NULL;
 91 
 92 	if(qz)
 93 	{
 94 	   	print("QZ LDA is not available here as it requieres the guptri package\n");
 95 
 96 		//lda_trainer = new(allocator) QZLDATrainer(n_classes, lda_machine, eps, kcf);
 97 	}
 98 	else
 99 	{
100 	   	print("Using Direct LDA\n");
101 
102 		lda_trainer = new(allocator) DirectLDATrainer(n_classes, lda_machine);
103 	}
104 	lda_trainer->setIOption("verbose_level", verbose_level);
105 	lda_trainer->setBOption("class provided", classprovided);
106 	lda_trainer->setBOption("within identity", within_is_identity);
107 	lda_trainer->setBOption("save", saveSbSw);
108 	
109 	//
110 	// Load all the data in the same dataset
111 	real *targets = NULL;
112 
113 	if(classprovided)
114 	{
115 		targets = (real *) allocator->alloc(sizeof(real) * filelist.n_files);
116 		for(int i = 0 ; i < filelist.n_files ; i++) targets[i] = (float) i;
117 	}
118 
119 	DataSet *bindata = NULL;
120 
121 	if(use_disk)
122 	{
123 	   	DiskBinDataSet *bindata_;
124 		
125   		if(classprovided) bindata_ = new(allocator) DiskBinDataSet(filelist.file_names, filelist.n_files, n_inputs, targets, -1);
126 		else bindata_ = new(allocator) DiskBinDataSet(filelist.file_names, filelist.n_files, n_inputs, -1);
127 	  	bindata_->info(false);
128 
129 		bindata = bindata_;
130 	}
131 	else
132 	{
133 	   	FileBinDataSet *bindata_;
134 		
135 		if(classprovided) bindata_ = new(allocator) FileBinDataSet(filelist.file_names, filelist.n_files, n_inputs, offset_window, n_inputs_window, targets);
136 	    	else bindata_ = new(allocator) FileBinDataSet(filelist.file_names, filelist.n_files, n_inputs, offset_window, n_inputs_window);
137 	  	bindata_->info(false);
138 
139 		bindata = bindata_;
140 	}
141 
142 	//
143 	// Computes LDA
144 	lda_trainer->train(bindata);
145 
146 
147 	//
148 	// Projects data into LDA sub-space
149 	if(forward)
150 	{
151 		real *realinput = NULL;
152         	Sequence *seq;
153 
154         	realinput = new real [n_inputs_window];
155         	seq = new Sequence(&realinput, 1, n_inputs_window);
156 	
157 		//
158 		lda_machine->setROption("variance", 1.0);
159 		lda_machine->init();
160 	
161 		for(int i=0; i< bindata->n_examples; i++)
162 		{
163 			if(verbose) 
164 	   			print("[%d]:\n", i);
165 
166 			//
167 			bindata->setExample(i);
168 		
169 			if(verbose) 
170 				print(" Input =   [%2.3f %2.3f %2.3f ...]\n", bindata->inputs->frames[0][0], bindata->inputs->frames[0][1], bindata->inputs->frames[0][2]);
171 
172 			//
173 			bindata->inputs->copyTo(realinput);
174 
175 			if(verbose) 
176 				print(" Seq =     [%2.3f %2.3f %2.3f ...]\n", realinput[0], realinput[1], realinput[2]);
177 
178 			//
179 			lda_machine->forward(seq);
180 			
181 			if(verbose) 
182 				print(" Output =   [%2.3f %2.3f %2.3f ...]\n", lda_machine->outputs->frames[0][0], lda_machine->outputs->frames[0][1], lda_machine->outputs->frames[0][2]);
183 		}
184 	
185 		delete [] realinput;
186 		delete seq;
187 	}
188 
189 
190 	//
191 	// Save the model
192 	if(strcmp(model_file, "") != 0)
193 	{
194 	   	print("Saving LDA model ...\n");
195 
196 		DiskXFile *file = NULL;
197 
198 		file = new DiskXFile(model_file, "w");
199 	
200 		lda_machine->saveXFile(file);
201 	
202 		delete file;
203 	}
204 	
205 	//
206   	delete allocator;
207 
208   	return(0);
209 }