trainHaarStump.cc

  1 const char *help = "\
  2 progname: trainHaarStump.cc\n\
  3 code2html: This program trains a linear combinaison of Haar-like stump classifiers.\n\
  4 version: Torch3 vision2.0, 2003-2005\n\
  5 (c) Sebastien Marcel (marcel@idiap.ch) and Yann Rodriguez (rodrig@idiap.ch)\n";
  6 
  7 // Torch
  8 
  9 // trainers
 10 #include "TwoClassFormat.h"
 11 #include "Boosting.h"
 12 
 13 // measurers
 14 #include "ClassMeasurer.h"
 15 
 16 // command-lines
 17 #include "FileListCmdOption.h"
 18 #include "CmdLine.h"
 19 
 20 /** Torch3vision
 21 */
 22 
 23 // feature maker
 24 #include "HaarFeatureMaker.h"
 25 
 26 // Stump machines and trainers
 27 #include "HaarStumpMachine.h"
 28 #include "DiscreteStumpTrainer.h"
 29 #include "HaarRealStumpMachine.h"
 30 #include "RealStumpTrainer.h"
 31 #include "ImageWeightedSumMachine.h"
 32 
 33 // datasets
 34 #include "FileBinDataSet.h"
 35 
 36 // image processing
 37 #include "ipIntegralImage.h"
 38 #include "ipNormMeanStdvLight.h"
 39 
 40 using namespace Torch;
 41 
 42 int main(int argc, char **argv)
 43 {
 44    	//
 45    	int width;
 46 	int height;
 47 
 48 	//
 49 	int n_trainers = 10;
 50 	
 51 	//
 52 	char *model_filename;
 53 
 54 	//
 55 	bool image_normalize;
 56 	bool realstump;
 57 	
 58 	Allocator *allocator = new Allocator;
 59 
 60 	FileListCmdOption filelist_class1("file name", "the list files or one data file of positive patterns");
 61 	filelist_class1.isArgument(true);
 62 
 63 	FileListCmdOption filelist_class0("file name", "the list files or one data file of negative patterns");
 64 	filelist_class0.isArgument(true);
 65 
 66 	//
 67 	// Prepare the command-line
 68 	CmdLine cmd;
 69 	cmd.setBOption("write log", false);
 70 	cmd.info(help);
 71 	cmd.addText("\nArguments:");
 72 	cmd.addCmdOption(&filelist_class1);
 73 	cmd.addCmdOption(&filelist_class0);
 74 	cmd.addICmdArg("width", &width, "width");
 75 	cmd.addICmdArg("height", &height, "height");
 76 	cmd.addText("\nOptions:");
 77 	cmd.addBCmdOption("-real", &realstump, false, "uses real stump");
 78 	cmd.addBCmdOption("-imagenorm", &image_normalize, false, "considers the input pattern as an image and performs a photometric normalization");
 79 	cmd.addICmdOption("-n", &n_trainers, 10, "number of classifiers to train");
 80 	cmd.addSCmdOption("-o", &model_filename, "model.wsm", "model filename");
 81 
 82 	//
 83 	// Read the command-line
 84 	cmd.read(argc, argv);
 85 
 86 	//
 87 	print(" + class 1:\n");
 88         print("   n_filenames = %d\n", filelist_class1.n_files);
 89         for(int i = 0 ; i < filelist_class1.n_files ; i++)
 90                 print("   filename[%d] = %s\n", i, filelist_class1.file_names[i]);
 91 
 92         print(" + class 0:\n");
 93         print("   n_filenames = %d\n", filelist_class0.n_files);
 94         for(int i = 0 ; i < filelist_class0.n_files ; i++)
 95                 print("   filename[%d] = %s\n", i, filelist_class0.file_names[i]);
 96 
 97 	int n_inputs = width * height;
 98 
 99 	real the_target = 1.0;
100 
101 	FileBinDataSet *data = NULL;
102 	data = new(allocator) FileBinDataSet(
103 	      			filelist_class1.file_names, filelist_class1.n_files, the_target,
104         			filelist_class0.file_names, filelist_class0.n_files, -the_target, n_inputs);
105 
106         data->info(false);
107 
108 	//
109 	print("Pre-processing ...\n");
110 	
111 	ipCore *i_machine = NULL;
112 	ipCore *inorm_machine = NULL;
113 	
114 	i_machine = new(allocator) ipIntegralImage(width, height, "gray");
115 
116 	if(image_normalize)
117 		inorm_machine = new(allocator) ipNormMeanStdvLight(width, height, "float");
118 	
119 //#define TRACE
120 	
121 	for(int i=0; i< data->n_examples; i++)
122 	{
123 		data->setExample(i);
124 		real *input_ = data->inputs->frames[0];
125 #ifdef TRACE
126 		printf(" ORGI = [ ");
127 		for(int j = 0 ; j < width * height ; j++)
128 			printf("%g ", input_[j]);
129 		printf("]\n");
130 #endif
131 
132 		Sequence *seqin = data->inputs;
133 		
134 		if(image_normalize)
135 		{
136 		   	// normalize the image
137 			inorm_machine->process(seqin);
138 
139 			input_ = inorm_machine->seq_out->frames[0];			
140 		
141 #ifdef TRACE
142 			printf(" NORM = [ ");
143 			for(int j = 0 ; j < width * height ; j++)
144 				printf("%g ", input_[j]);
145 			printf("]\n");
146 #endif
147 			seqin = inorm_machine->seq_out;			
148 		}
149 			
150 		// computes its integral image
151 		i_machine->process(seqin);
152 
153 		real *output_ = i_machine->seq_out->frames[0];
154 		
155 #ifdef TRACE
156 		printf(" INTI = [ ");
157 #endif
158 		for(int j = 0 ; j < width * height ; j++)
159 		{
160 			data->inputs->frames[0][j] = output_[j];
161 #ifdef TRACE
162 			printf("%g ", output_[j]);
163 #endif
164 		}
165 #ifdef TRACE
166 		printf("]\n");
167 #endif
168 	}
169 	
170 	//
171 	print("Haar feature maker ...\n");
172 	
173 	HaarFeatureMaker *haar = new(allocator) HaarFeatureMaker(5, width, height, 2, 19);
174 	
175 	print(" + number of features = %d\n", haar->n_features);
176 	
177 
178 	//
179 	Trainer **trainers = (Trainer **)allocator->alloc(n_trainers*sizeof(Trainer *));
180 	for(int j = 0 ; j < n_trainers ; j++)
181 	{
182 	   	if(realstump)
183 		{
184 			HaarRealStumpMachine *s_machine = new(allocator) HaarRealStumpMachine(haar->n_features, haar->mask);
185 			trainers[j] = new(allocator) RealStumpTrainer(s_machine);
186 			trainers[j]->setBOption("verbose", false);
187 		}
188 		else
189 		{
190 			HaarStumpMachine *s_machine = new(allocator) HaarStumpMachine(haar->n_features, haar->mask);
191 			trainers[j] = new(allocator) DiscreteStumpTrainer(s_machine);
192 			trainers[j]->setBOption("verbose", true);
193 		}
194 	}
195 	 
196 	//
197 	ImageWeightedSumMachine *iwsm = new(allocator) ImageWeightedSumMachine(trainers, n_trainers, NULL);
198 
199 	//
200 	TwoClassFormat *class_format = new(allocator) TwoClassFormat(data);
201 
202 	Trainer *boost = NULL;
203 	
204 	boost = new(allocator) Boosting(iwsm, class_format);
205 
206 	//
207 	MeasurerList measurers;
208         ClassMeasurer *class_meas = new(allocator) ClassMeasurer(iwsm->outputs, data, class_format, cmd.getXFile("the_class_err"));
209         measurers.addNode(class_meas);
210 
211 	//
212 	boost->train(data, &measurers);
213 
214 	//
215 	DiskXFile *model = new(allocator) DiskXFile(model_filename, "w");
216 	model->taggedWrite(&n_inputs, sizeof(int), 1, "N_INPUTS");
217 	model->taggedWrite(&n_trainers, sizeof(int), 1, "N_TRAINERS");
218 	iwsm->saveXFile(model);
219 
220 	//
221   	delete allocator;
222 
223   	return(0);
224 }