trainPixStump.cc

  1 const char *help = "\
  2 progname: trainPixStump.cc\n\
  3 code2html: This program trains a linear combination of pixel-based stump classifiers.\n\
  4 version: Torch3 vision2.0, 2003-2005\n\
  5 (c) Sebastien Marcel (marcel@idiap.ch) and Yann Rodriguez (rodrig@idiap.ch)\n";
  6 
  7 // Torch
  8 
  9 // trainers
 10 #include "TwoClassFormat.h"
 11 #include "Boosting.h"
 12 
 13 // measurers
 14 #include "ClassMeasurer.h"
 15 
 16 // command-lines
 17 #include "FileListCmdOption.h"
 18 #include "CmdLine.h"
 19 
 20 /** Torch3vision
 21 */
 22 
 23 // Stump machines and trainers
 24 #include "DiscreteStumpMachine.h"
 25 #include "DiscreteStumpTrainer.h"
 26 #include "RealStumpMachine.h"
 27 #include "RealStumpTrainer.h"
 28 #include "ImageWeightedSumMachine.h"
 29 
 30 // datasets
 31 #include "FileBinDataSet.h"
 32 
 33 // image processing
 34 #include "ipHistoEqual.h"
 35 #include "ipNormMeanStdvLight.h"
 36 
 37 using namespace Torch;
 38 
 39 int main(int argc, char **argv)
 40 {
 41    	//
 42    	int width;
 43 	int height;
 44 
 45 	//
 46 	int n_trainers = 10;
 47 	
 48 	//
 49 	char *model_filename;
 50 
 51 	//
 52 	bool image_normalize;
 53 	bool equal_histo;
 54 	bool realstump;
 55 	
 56 	Allocator *allocator = new Allocator;
 57 
 58 	FileListCmdOption filelist_class1("file name", "the list files or one data file of positive patterns");
 59 	filelist_class1.isArgument(true);
 60 
 61 	FileListCmdOption filelist_class0("file name", "the list files or one data file of negative patterns");
 62 	filelist_class0.isArgument(true);
 63 
 64 	//
 65 	// Prepare the command-line
 66 	CmdLine cmd;
 67 	cmd.setBOption("write log", false);
 68 	cmd.info(help);
 69 	cmd.addText("\nArguments:");
 70 	cmd.addCmdOption(&filelist_class1);
 71 	cmd.addCmdOption(&filelist_class0);
 72 	cmd.addICmdArg("width", &width, "width");
 73 	cmd.addICmdArg("height", &height, "height");
 74 	cmd.addText("\nOptions:");
 75 	cmd.addBCmdOption("-imagenorm", &image_normalize, false, "considers the input pattern as an image and performs a photometric normalization");
 76 	cmd.addBCmdOption("-equalh", &equal_histo, false, "perform histogram equalization");
 77 	cmd.addICmdOption("-n", &n_trainers, 10, "number of classifiers to train");
 78 	cmd.addSCmdOption("-o", &model_filename, "model.wsm", "model filename");
 79 	cmd.addBCmdOption("-real", &realstump, false, "uses real weak classifiers");
 80 
 81 	//
 82 	// Read the command-line
 83 	cmd.read(argc, argv);
 84 
 85 	//
 86 	print(" + class 1:\n");
 87         print("   n_filenames = %d\n", filelist_class1.n_files);
 88         for(int i = 0 ; i < filelist_class1.n_files ; i++)
 89                 print("   filename[%d] = %s\n", i, filelist_class1.file_names[i]);
 90 
 91         print(" + class 0:\n");
 92         print("   n_filenames = %d\n", filelist_class0.n_files);
 93         for(int i = 0 ; i < filelist_class0.n_files ; i++)
 94                 print("   filename[%d] = %s\n", i, filelist_class0.file_names[i]);
 95 
 96 	int n_inputs = width * height;
 97 
 98 	real the_target = 1.0;
 99 
100 	FileBinDataSet *data = NULL;
101 	data = new(allocator) FileBinDataSet(
102 	      			filelist_class1.file_names, filelist_class1.n_files, the_target,
103         			filelist_class0.file_names, filelist_class0.n_files, -the_target, n_inputs);
104 
105         data->info(false);
106 
107 	//
108 	if(image_normalize)
109 	{
110 	   	ipCore *imachine = NULL;
111 
112 		if(equal_histo)
113 			imachine = new(allocator) ipHistoEqual(width, height, "float");
114 		else 
115 			imachine = new(allocator) ipNormMeanStdvLight(width, height, "float");
116 	
117 		for(int i=0; i< data->n_examples; i++)
118                 {
119                         data->setExample(i);
120 
121                         imachine->process(data->inputs);
122                 }
123 
124 	}
125 	
126 	//
127 	Trainer **trainers = (Trainer **)allocator->alloc(n_trainers*sizeof(Trainer *));
128 	for(int j = 0 ; j < n_trainers ; j++)
129 	{
130 	   	if(realstump)
131 		{
132 			RealStumpMachine *s_machine = new(allocator) RealStumpMachine(n_inputs);
133 			trainers[j] = new(allocator) RealStumpTrainer(s_machine);
134 		}
135 		else
136 		{
137 			DiscreteStumpMachine *s_machine = new(allocator) DiscreteStumpMachine(n_inputs);
138 			trainers[j] = new(allocator) DiscreteStumpTrainer(s_machine);
139 		}
140 		trainers[j]->setBOption("verbose", true);
141 	}
142 	 
143 	//
144 	ImageWeightedSumMachine *iwsm = new(allocator) ImageWeightedSumMachine(trainers, n_trainers, NULL);
145 
146 	//
147 	TwoClassFormat *class_format = new(allocator) TwoClassFormat(data);
148 	Boosting *boost = new(allocator) Boosting(iwsm, class_format);
149 
150 	//
151 	MeasurerList measurers;
152         ClassMeasurer *class_meas = new(allocator) ClassMeasurer(iwsm->outputs, data, class_format, cmd.getXFile("the_class_err"));
153         measurers.addNode(class_meas);
154 
155 	//
156 	boost->train(data, &measurers);
157 
158 	//
159 	DiskXFile *model = new(allocator) DiskXFile(model_filename, "w");
160 	model->taggedWrite(&n_inputs, sizeof(int), 1, "N_INPUTS");
161 	model->taggedWrite(&n_trainers, sizeof(int), 1, "N_TRAINERS");
162 	iwsm->saveXFile(model);
163 
164 	//
165   	delete allocator;
166 
167   	return(0);
168 }