trainLDA.cc

  1 const char *help = "\
  2 progname: trainLDA.cc\n\
  3 code2html: This program computes a LDA projection sub-space.\n\
  4 version: Torch3 vision2.0, 2003-2005\n\
  5 (c) Sebastien Marcel (marcel@idiap.ch)\n"; // Usage text handed to cmd.info() in main(). NOTE(review): the "tag:" lines look like input for a doc generator (code2html) -- keep this layout unless that tool is checked.
  6 
  7 #include "FileBinDataSet.h"
  8 #include "DiskBinDataSet.h"
  9 #include "MMCTrainer.h"
 10 #include "FisherLDATrainer.h"
 11 #include "FileListCmdOption.h"
 12 #include "CmdLine.h"
 13 
 14 using namespace Torch;
 15 			
 16 // Prints up to the first three entries of a vector as "<label>v0 v1 v2 ...]",
 17 // reading at most n entries so short vectors are never over-read.
 18 static void printFirstValues(const char *label, real *values, int n)
 19 {
 20 	print("%s", label);
 21 	for(int j = 0 ; j < n && j < 3 ; j++)
 22 		print("%s%2.3f", (j == 0) ? "" : " ", values[j]);
 23 	print(" ...]\n");
 24 }
 25
 26 // Trains a Fisher LDA (default) or MMC (-mmc) projection on the listed binary
 27 // data files, optionally projects every example through it (-forward) and
 28 // saves the resulting LDAMachine (-save).
 29 int main(int argc, char **argv)
 30 {
 31 	int n_inputs;
 32 	int offset_window;
 33 	int n_inputs_window;
 34 	char *model_file;
 35 	bool verbose;
 36 	int verbose_level;
 37 	bool classprovided;
 38 	bool within_is_identity;
 39 	bool forward;
 40 	bool use_disk;
 41 	bool saveSbSw;
 42 	bool mmc;
 43
 44 	Allocator *allocator = new Allocator;
 45 	DiskXFile::setLittleEndianMode();
 46
 47 	//=================== The command-line ==========================
 48 	FileListCmdOption filelist("file name", "the list files or one data file");
 49 	filelist.isArgument(true);
 50
 51 	// Construct the command line
 52 	CmdLine cmd;
 53 	cmd.setBOption("write log", false);
 54
 55 	// Put the help line at the beginning
 56 	cmd.info(help);
 57
 58 	cmd.addText("\nArguments:");
 59 	cmd.addCmdOption(&filelist);
 60 	cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data", true);
 61 	cmd.addText("\nOptions:");
 62 	cmd.addICmdOption("-offset_window", &offset_window, 0, "offset window", true);
 63 	cmd.addICmdOption("-n_inputs_window", &n_inputs_window, -1, "input dimension of the window", true);
 64 	cmd.addBCmdOption("-verbose", &verbose, false, "verbose", true);
 65 	cmd.addICmdOption("-verbose_level", &verbose_level, 0, "verbose level", true);
 66 	cmd.addBCmdOption("-forward", &forward, false, "project all data into lda", true);
 67 	cmd.addBCmdOption("-classprovided", &classprovided, false, "assign the targets", true);
 68 	cmd.addBCmdOption("-within_id", &within_is_identity, false, "the within scatter matrix is an identity matrix", true);
 69 	cmd.addSCmdOption("-save", &model_file, "", "model file", true);
 70 	cmd.addBCmdOption("-use_disk", &use_disk, false, "use disk");
 71 	cmd.addBCmdOption("-saveSbSw", &saveSbSw, false, "save Sb and Sw");
 72 	cmd.addBCmdOption("-mmc", &mmc, false, "computes the Maximum Margin Criterion");
 73
 74 	// Read the command line
 75 	cmd.read(argc, argv);
 76
 77 	// -n_inputs_window -1 (the default) selects the whole input vector
 78 	if(n_inputs_window == -1) n_inputs_window = n_inputs;
 79
 80 	// The window [offset_window, offset_window + n_inputs_window) must lie inside
 81 	// the n_inputs-dimensional input: the LDA machine and the projection buffer
 82 	// below are both sized n_inputs_window.
 83 	if(offset_window < 0 || n_inputs_window <= 0 || offset_window + n_inputs_window > n_inputs)
 84 		error("invalid window: offset_window (%d) + n_inputs_window (%d) must fit in n_inputs (%d)", offset_window, n_inputs_window, n_inputs);
 85
 86 	// DiskBinDataSet is constructed below without window arguments, so its frames
 87 	// presumably keep all n_inputs values; combined with a smaller window that
 88 	// mismatches the LDA machine and overflows the projection buffer.
 89 	if(use_disk && (offset_window != 0 || n_inputs_window != n_inputs))
 90 		error("-use_disk does not support -offset_window / -n_inputs_window");
 91
 92 	if(verbose)
 93 	{
 94 		print(" + n_filenames = %d\n", filelist.n_files);
 95 		for(int i = 0 ; i < filelist.n_files ; i++)
 96 			print("   filename[%d] = %s\n", i, filelist.file_names[i]);
 97 	}
 98
 99 	// One class per input file
100 	int n_classes = filelist.n_files;
101
102 	//
103 	// The LDA Machine, working on the (possibly windowed) input vector
104 	LDAMachine *lda_machine = new(allocator) LDAMachine(n_inputs_window);
105 	lda_machine->setIOption("verbose_level", verbose_level);
106
107 	//
108 	// The MMC/LDA Trainer
109 	Trainer *trainer = NULL;
110
111 	if(mmc)
112 	{
113 		print("Using MMC\n");
114
115 		// "within identity" is only forwarded to the Fisher trainer
116 		if(within_is_identity) warning("-within_id is ignored when -mmc is set");
117
118 		trainer = new(allocator) MMCTrainer(n_classes, lda_machine);
119 	}
120 	else
121 	{
122 		print("Using Fisher LDA\n");
123
124 		trainer = new(allocator) FisherLDATrainer(n_classes, lda_machine);
125 		trainer->setBOption("within identity", within_is_identity);
126 	}
127 	trainer->setIOption("verbose_level", verbose_level);
128 	trainer->setBOption("class provided", classprovided);
129 	trainer->setBOption("save", saveSbSw);
130
131 	//
132 	// Load all the data in the same dataset; with -classprovided the dataset
133 	// receives targets[i] = i alongside file i.
134 	real *targets = NULL;
135
136 	if(classprovided)
137 	{
138 		targets = (real *) allocator->alloc(sizeof(real) * filelist.n_files);
139 		for(int i = 0 ; i < filelist.n_files ; i++) targets[i] = (real) i;
140 	}
141
142 	DataSet *bindata = NULL;
143
144 	if(use_disk)
145 	{
146 		DiskBinDataSet *bindata_;
147
148 		if(classprovided) bindata_ = new(allocator) DiskBinDataSet(filelist.file_names, filelist.n_files, n_inputs, targets, -1);
149 		else bindata_ = new(allocator) DiskBinDataSet(filelist.file_names, filelist.n_files, n_inputs, -1);
150 		bindata_->info(false);
151
152 		bindata = bindata_;
153 	}
154 	else
155 	{
156 		FileBinDataSet *bindata_;
157
158 		if(classprovided) bindata_ = new(allocator) FileBinDataSet(filelist.file_names, filelist.n_files, n_inputs, offset_window, n_inputs_window, targets);
159 		else bindata_ = new(allocator) FileBinDataSet(filelist.file_names, filelist.n_files, n_inputs, offset_window, n_inputs_window);
160 		bindata_->info(false);
161
162 		bindata = bindata_;
163 	}
164
165 	//
166 	// Computes LDA
167 	trainer->train(bindata, NULL);
168
169 	//
170 	// Projects data into LDA sub-space
171 	if(forward)
172 	{
173 		// seq is built over &realinput with 1 frame of n_inputs_window values
174 		real *realinput = new real [n_inputs_window];
175 		Sequence *seq = new Sequence(&realinput, 1, n_inputs_window);
176
177 		lda_machine->setROption("variance", 1.0);
178 		lda_machine->init();
179
180 		for(int i=0; i< bindata->n_examples; i++)
181 		{
182 			if(verbose)
183 				print("[%d]:\n", i);
184
185 			bindata->setExample(i);
186
187 			if(verbose)
188 				printFirstValues(" Input =   [", bindata->inputs->frames[0], bindata->inputs->frame_size);
189
190 			bindata->inputs->copyTo(realinput);
191
192 			if(verbose)
193 				printFirstValues(" Seq =     [", realinput, n_inputs_window);
194
195 			lda_machine->forward(seq);
196
197 			if(verbose)
198 				printFirstValues(" Output =   [", lda_machine->outputs->frames[0], lda_machine->outputs->frame_size);
199 		}
200
201 		// Release the wrapper before the buffer it was built over
202 		delete seq;
203 		delete [] realinput;
204 	}
205
206 	//
207 	// Save the model (skipped when -save is left at its "" default)
208 	if(model_file[0] != '\0')
209 	{
210 		print("Saving LDA model ...\n");
211
212 		DiskXFile *file = new DiskXFile(model_file, "w");
213 		lda_machine->saveXFile(file);
214 		delete file;
215 	}
216
217 	delete allocator;
218
219 	return(0);
220 }