// trainMLP.cc

  1 const char *help = "\
  2 progname: trainMLP.cc\n\
  3 code2html: This program trains a MLP with sigmoid outputs for 2 class classification.\n\
  4 version: Torch3 vision2.0, 2003-2005\n\
  5 (c) Sebastien Marcel (marcel@idiap.ch)\n";
#include <cstring>

/** Torch
*/
#include "Random.h"
#include "DiskXFile.h"

// criterions
#include "ClassNLLCriterion.h"
#include "TwoClassNLLCriterion.h"
#include "MSECriterion.h"

// class formats
#include "TwoClassFormat.h"

// measurers
#include "ClassMeasurer.h"
#include "MSEMeasurer.h"
#include "NLLMeasurer.h"

// trainers
#include "StochasticGradient.h"
#include "KFold.h"

// command-lines
#include "CmdLine.h"
#include "FileListCmdOption.h"


/** Torch3vision
*/

// datasets
#include "FileBinDataSet.h"

// custommachines
#include "MyMLP.h"
#include "MeanVarNorm.h"

// image processing
#include "ipHistoEqual.h"
#include "ipSmoothGaussian3.h"

using namespace Torch;

 50 int main(int argc, char **argv)
 51 {
 52    	//
 53   	int n_inputs;
 54 
 55 	//
 56 	real the_target;
 57 
 58 	//
 59   	int n_hu;
 60 	
 61 	//
 62 	int width_pattern;
 63 	int height_pattern;
 64 
 65 	//
 66   	int max_load;
 67   	int the_seed;
 68 
 69 	//
 70   	real accuracy;
 71   	real learning_rate;
 72   	real decay;
 73   	int max_iter;
 74   	int k_fold;
 75   	real weight_decay;
 76 
 77 	//
 78 	bool use_mse;
 79 	bool use_nll;
 80 	bool use_linear_output;
 81 	bool image_normalize;
 82 
 83 	//
 84   	char *dir_name;
 85   	char *model_file;
 86 
 87 	//
 88   	Allocator *allocator = new Allocator;
 89   	DiskXFile::setLittleEndianMode();
 90 
 91   	//=================== The command-line ==========================
 92 
 93         FileListCmdOption filelist_class1("file name", "the list files or one data file of positive patterns");
 94         filelist_class1.isArgument(true);
 95   
 96         FileListCmdOption filelist_class0("file name", "the list files or one data file of negative patterns");
 97         filelist_class0.isArgument(true);
 98 
 99   	// Construct the command line
100   	CmdLine cmd;
101 	cmd.setBOption("write log", false);
102 
103   	// Put the help line at the beginning
104   	cmd.info(help);
105 
106   	// Train mode
107   	cmd.addText("\nArguments:");
108 	cmd.addCmdOption(&filelist_class1);
109         cmd.addCmdOption(&filelist_class0);
110   	cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data");
111 
112   	cmd.addText("\nModel Options:");
113   	cmd.addICmdOption("-nhu", &n_hu, 25, "number of hidden units");
114 
115   	cmd.addText("\nLearning Options:");
116   	cmd.addICmdOption("-iter", &max_iter, 25, "max number of iterations");
117   	cmd.addRCmdOption("-lr", &learning_rate, 0.01, "learning rate");
118   	cmd.addRCmdOption("-e", &accuracy, 0.00001, "end accuracy");
119   	cmd.addRCmdOption("-lrd", &decay, 0, "learning rate decay");
120   	cmd.addICmdOption("-kfold", &k_fold, -1, "number of folds, if you want to do cross-validation");
121   	cmd.addRCmdOption("-wd", &weight_decay, 0, "weight decay");
122   	cmd.addBCmdOption("-mse", &use_mse, false, "use MSE criterion");
123   	cmd.addBCmdOption("-nll", &use_nll, false, "use NLL criterion");
124   	cmd.addBCmdOption("-linear", &use_linear_output, false, "use linear output (tanh otherwise)");
125   	cmd.addRCmdOption("-target", &the_target, 0.6, "the target value (overided if NLL)");
126 
127   	cmd.addText("\nMisc Options:");
128   	cmd.addICmdOption("-seed", &the_seed, -1, "the random seed");
129   	cmd.addICmdOption("-load", &max_load, -1, "max number of examples to load for train");
130   	cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");
131   	cmd.addSCmdOption("-save", &model_file, "", "the model file");
132 
133   	cmd.addText("\nImage Options:");
134   	cmd.addICmdOption("-width", &width_pattern, 19, "the width of the pattern");
135   	cmd.addICmdOption("-height", &height_pattern, 19, "the height of the pattern");
136   	cmd.addBCmdOption("-imagenorm", &image_normalize, false, "considers the input pattern as an image and performs a photometric normalization");
137 
138   	// Read the command line
139   	cmd.read(argc, argv);
140 
141 	//
142 	if(use_mse == use_nll) error("choose between MSE or NLL criterion.");
143 
144 	//
145 	if(use_mse) print("Using MSE criterion ...\n");
146 	if(use_nll) print("Using NLL criterion ...\n");
147 	if(image_normalize)
148 	{
149 		print("Perform photometric normalization ...\n");
150 
151 		if(width_pattern * height_pattern != n_inputs) error("incorrect image size.");
152 		
153 		print("The input pattern is an %dx%d image.\n", width_pattern, height_pattern);
154 	}
155 
156 	//
157     	if(the_seed == -1) Random::seed();
158     	else Random::manualSeed((long)the_seed);
159   
160 	//
161 	cmd.setWorkingDirectory(dir_name);
162 	
163 	//
164         print(" + class 1:\n");
165         print("   n_filenames = %d\n", filelist_class1.n_files);
166         for(int i = 0 ; i < filelist_class1.n_files ; i++)
167                 print("   filename[%d] = %s\n", i, filelist_class1.file_names[i]);
168 
169         print(" + class 0:\n");
170         print("   n_filenames = %d\n", filelist_class0.n_files);
171         for(int i = 0 ; i < filelist_class0.n_files ; i++)
172                 print("   filename[%d] = %s\n", i, filelist_class0.file_names[i]);
173 
174   	// Create the MLP
175 	MyMLP *mlp = NULL;
176 
177 	if(n_hu != 0)
178 	{
179 		if(use_linear_output) mlp = new(allocator) MyMLP(2, n_inputs, "tanh", n_hu, "linear", 1);
180 		else mlp = new(allocator) MyMLP(2, n_inputs, "tanh", n_hu, "tanh", 1);
181 	}
182 	else
183 	{
184 		if(use_linear_output) mlp = new(allocator) MyMLP(n_inputs, "linear", 1);
185 		else mlp = new(allocator) MyMLP(n_inputs, "tanh", 1);
186 	}
187 
188 	mlp->setWeightDecay(weight_decay);
189   	mlp->setPartialBackprop();
190 	mlp->info();
191 
192   	//
193   	// Create the training dataset (normalize inputs)
194   	MeanVarNorm *mv_norm = NULL;
195 	 
196 	if(use_nll) the_target = 1.0;
197 
198     	FileBinDataSet *bindata = NULL;
199 
200 	bindata = new(allocator) FileBinDataSet(filelist_class1.file_names, filelist_class1.n_files, the_target, 
201 			filelist_class0.file_names, filelist_class0.n_files, -the_target, n_inputs);
202 	
203 	bindata->info(false);
204 
205 	if(image_normalize)
206 	{
207 		ipHistoEqual *enhancing = new(allocator) ipHistoEqual(width_pattern, height_pattern, "float");
208 		ipCore *smoothing = new(allocator) ipSmoothGaussian3(width_pattern, height_pattern, "gray", 0.25);
209 		
210 		for(int i=0; i< bindata->n_examples; i++)
211 		{
212 	   		bindata->setExample(i);
213 
214 			enhancing->process(bindata->inputs);
215 			smoothing->process(enhancing->seq_out);
216 
217 			for(int j = 0 ; j < width_pattern * height_pattern ; j++)
218 				bindata->inputs->frames[0][j] = smoothing->seq_out->frames[0][j];
219 		}
220 	}
221 
222     	mv_norm = new(allocator) MeanVarNorm(bindata);
223     	bindata->preProcess(mv_norm);
224    
225   	// The list of measurers...
226   	MeasurerList measurers;
227 
228   	// The class format
229   	TwoClassFormat *class_format = NULL;
230     	class_format = new(allocator) TwoClassFormat(bindata);
231 
232   	// Measurers on the training dataset
233     	ClassMeasurer *class_meas = NULL;
234 
235     	class_meas = new(allocator) ClassMeasurer(mlp->outputs, bindata, class_format, cmd.getXFile("classerror.measure"));
236     	measurers.addNode(class_meas);
237 
238 	// the measurer
239     	MSEMeasurer *mse_meas = NULL;
240     	NLLMeasurer *nll_meas = NULL;
241     	
242 	mse_meas = new(allocator) MSEMeasurer(mlp->outputs, bindata, cmd.getXFile("mse.measure"));
243     	measurers.addNode(mse_meas);
244     	if(use_nll) 
245 	{
246 		nll_meas = new(allocator) NLLMeasurer(mlp->outputs, bindata, cmd.getXFile("nll.measure"));
247 		measurers.addNode(nll_meas);
248 	}
249   
250   	//=================== The Trainer ===============================
251   
252   	// The criterion for the StochasticGradient (MSE criterion or NLL criterion)
253   	Criterion *criterion = NULL;
254 
255 	if(use_mse) criterion = new(allocator) MSECriterion(1);
256     	//if(use_nll) criterion = new(allocator) ClassNLLCriterion(class_format);
257     	if(use_nll) criterion = new(allocator) TwoClassNLLCriterion(0.0);
258 
259   	// The Gradient Machine Trainer
260   	StochasticGradient trainer(mlp, criterion);
261     
262 	trainer.setIOption("max iter", max_iter);
263     	trainer.setROption("end accuracy", accuracy);
264     	trainer.setROption("learning rate", learning_rate);
265     	trainer.setROption("learning rate decay", decay);
266 
267 	//
268   	// Print the number of parameter of the MLP (just for fun)
269 	message("Number of parameters: %d", mlp->params->n_params);
270 
271     	if(k_fold <= 0)
272     	{
273       		trainer.train(bindata, &measurers);
274     
275       		if(strcmp(model_file, "")) mlp->save(model_file, mv_norm);
276     	}
277     	else
278     	{
279 	   	print("Go go KFold.\n");
280 		
281       		KFold k(&trainer, k_fold);
282       		k.crossValidate(bindata, NULL, &measurers);
283     	}
284 
285   	delete allocator;
286 
287 	return(0);
288 }