trainHaarStump.cc

  1 const char *help = "\
  2 progname: trainHaarStump.cc\n\
  3 code2html: This program trains a linear combinaison of Haar-like stump classifiers.\n\
  4 version: Torch3 vision2.0, 2003-2005\n\
  5 (c) Sebastien Marcel (marcel@idiap.ch) and Yann Rodriguez (rodrig@idiap.ch)\n";
  6 
  7 // Torch
  8 
  9 // trainers
 10 #include "TwoClassFormat.h"
 11 #include "Boosting.h"
 12 #include "BoostingReal.h"
 13 
 14 // measurers
 15 #include "ClassMeasurer.h"
 16 
 17 // command-lines
 18 #include "FileListCmdOption.h"
 19 #include "CmdLine.h"
 20 
 21 /** Torch3vision
 22 */
 23 
 24 // feature maker
 25 #include "HaarFeatureMaker.h"
 26 
 27 // Stump machines and trainers
 28 #include "HaarStumpMachine.h"
 29 #include "StumpTrainer.h"
 30 #include "HaarRealStumpMachine.h"
 31 #include "RealStumpTrainer.h"
 32 #include "ImageWeightedSumMachine.h"
 33 
 34 // datasets
 35 #include "FileBinDataSet.h"
 36 
 37 // image processing
 38 #include "ipIntegralImage.h"
 39 #include "ipNormMeanStdvLight.h"
 40 
 41 using namespace Torch;
 42 
 43 int main(int argc, char **argv)
 44 {
 45    	//
 46    	int width;
 47 	int height;
 48 
 49 	//
 50 	int n_trainers = 10;
 51 	
 52 	//
 53 	char *model_filename;
 54 
 55 	//
 56 	bool image_normalize;
 57 	bool realboost;
 58 	
 59 	Allocator *allocator = new Allocator;
 60 
 61 	FileListCmdOption filelist_class1("file name", "the list files or one data file of positive patterns");
 62 	filelist_class1.isArgument(true);
 63 
 64 	FileListCmdOption filelist_class0("file name", "the list files or one data file of negative patterns");
 65 	filelist_class0.isArgument(true);
 66 
 67 	//
 68 	// Prepare the command-line
 69 	CmdLine cmd;
 70 	cmd.setBOption("write log", false);
 71 	cmd.info(help);
 72 	cmd.addText("\nArguments:");
 73 	cmd.addCmdOption(&filelist_class1);
 74 	cmd.addCmdOption(&filelist_class0);
 75 	cmd.addICmdArg("width", &width, "width");
 76 	cmd.addICmdArg("height", &height, "height");
 77 	cmd.addText("\nOptions:");
 78 	cmd.addBCmdOption("-realboost", &realboost, false, "use real boosting");
 79 	cmd.addBCmdOption("-imagenorm", &image_normalize, false, "considers the input pattern as an image and performs a photometric normalization");
 80 	cmd.addICmdOption("-n", &n_trainers, 10, "number of classifiers to train");
 81 	cmd.addSCmdOption("-o", &model_filename, "model.wsm", "model filename");
 82 
 83 	//
 84 	// Read the command-line
 85 	cmd.read(argc, argv);
 86 
 87 	//
 88 	print(" + class 1:\n");
 89         print("   n_filenames = %d\n", filelist_class1.n_files);
 90         for(int i = 0 ; i < filelist_class1.n_files ; i++)
 91                 print("   filename[%d] = %s\n", i, filelist_class1.file_names[i]);
 92 
 93         print(" + class 0:\n");
 94         print("   n_filenames = %d\n", filelist_class0.n_files);
 95         for(int i = 0 ; i < filelist_class0.n_files ; i++)
 96                 print("   filename[%d] = %s\n", i, filelist_class0.file_names[i]);
 97 
 98 	int n_inputs = width * height;
 99 
100 	real the_target = 1.0;
101 
102 	FileBinDataSet *data = NULL;
103 	data = new(allocator) FileBinDataSet(
104 	      			filelist_class1.file_names, filelist_class1.n_files, the_target,
105         			filelist_class0.file_names, filelist_class0.n_files, -the_target, n_inputs);
106 
107         data->info(false);
108 
109 	//
110 	print("Pre-processing ...\n");
111 	
112 	ipCore *i_machine = NULL;
113 	ipCore *inorm_machine = NULL;
114 	
115 	i_machine = new(allocator) ipIntegralImage(width, height, "gray");
116 
117 	if(image_normalize)
118 		inorm_machine = new(allocator) ipNormMeanStdvLight(width, height, "float");
119 	
120 //#define TRACE
121 	
122 	for(int i=0; i< data->n_examples; i++)
123 	{
124 		data->setExample(i);
125 		real *input_ = data->inputs->frames[0];
126 #ifdef TRACE
127 		printf(" ORGI = [ ");
128 		for(int j = 0 ; j < width * height ; j++)
129 			printf("%g ", input_[j]);
130 		printf("]\n");
131 #endif
132 
133 		Sequence *seqin = data->inputs;
134 		
135 		if(image_normalize)
136 		{
137 		   	// normalize the image
138 			inorm_machine->process(seqin);
139 
140 			input_ = inorm_machine->seq_out->frames[0];			
141 		
142 #ifdef TRACE
143 			printf(" NORM = [ ");
144 			for(int j = 0 ; j < width * height ; j++)
145 				printf("%g ", input_[j]);
146 			printf("]\n");
147 #endif
148 			seqin = inorm_machine->seq_out;			
149 		}
150 			
151 		// computes its integral image
152 		i_machine->process(seqin);
153 
154 		real *output_ = i_machine->seq_out->frames[0];
155 		
156 #ifdef TRACE
157 		printf(" INTI = [ ");
158 #endif
159 		for(int j = 0 ; j < width * height ; j++)
160 		{
161 			data->inputs->frames[0][j] = output_[j];
162 #ifdef TRACE
163 			printf("%g ", output_[j]);
164 #endif
165 		}
166 #ifdef TRACE
167 		printf("]\n");
168 #endif
169 	}
170 	
171 	//
172 	print("Haar feature maker ...\n");
173 	
174 	HaarFeatureMaker *haar = new(allocator) HaarFeatureMaker(5, width, height, 2, 19);
175 	
176 	print(" + number of features = %d\n", haar->n_features);
177 	
178 
179 	//
180 	Trainer **trainers = (Trainer **)allocator->alloc(n_trainers*sizeof(Trainer *));
181 	for(int j = 0 ; j < n_trainers ; j++)
182 	{
183 	   	if(realboost)
184 		{
185 			HaarRealStumpMachine *s_machine = new(allocator) HaarRealStumpMachine(haar->n_features, haar->mask);
186 			trainers[j] = new(allocator) RealStumpTrainer(s_machine);
187 			trainers[j]->setBOption("verbose", false);
188 		}
189 		else
190 		{
191 			HaarStumpMachine *s_machine = new(allocator) HaarStumpMachine(haar->n_features, haar->mask);
192 			trainers[j] = new(allocator) StumpTrainer(s_machine);
193 			trainers[j]->setBOption("verbose", true);
194 		}
195 	}
196 	 
197 	//
198 	ImageWeightedSumMachine *iwsm = new(allocator) ImageWeightedSumMachine(trainers, n_trainers, NULL);
199 
200 	//
201 	TwoClassFormat *class_format = new(allocator) TwoClassFormat(data);
202 
203 	Trainer *boost = NULL;
204 	
205 	if(realboost)
206 		boost = new(allocator) BoostingReal(iwsm, class_format);
207 	else boost = new(allocator) Boosting(iwsm, class_format);
208 
209 	//
210 	MeasurerList measurers;
211         ClassMeasurer *class_meas = new(allocator) ClassMeasurer(iwsm->outputs, data, class_format, cmd.getXFile("the_class_err"));
212         measurers.addNode(class_meas);
213 
214 	//
215 	boost->train(data, &measurers);
216 
217 	//
218 	DiskXFile *model = new(allocator) DiskXFile(model_filename, "w");
219 	model->taggedWrite(&n_inputs, sizeof(int), 1, "N_INPUTS");
220 	model->taggedWrite(&n_trainers, sizeof(int), 1, "N_TRAINERS");
221 	iwsm->saveXFile(model);
222 
223 	//
224   	delete allocator;
225 
226   	return(0);
227 }