bindata2ilda.cc

  1 #include "ImageGray.h"
  2 
  3 #include "LDAMachine.h"
  4 #include "matrix.h"
  5 
  6 #include "CmdLine.h"
  7 
  8 using namespace Torch;
  9 
 10 const char *help = "This program reads a bindata file and and LDA model and projects patterns into it\
 11 (c) Sebastien Marcel 2003-2004\n";
 12 
 13 int main(int argc, char *argv[])
 14 {
 15 	char *bindata_filename;
 16 	char *model_filename;
 17 	char *ilda_filename;
 18 	int n_input;
 19 	int n_output;
 20 	char *filename_out;
 21 	real variance;
 22 	bool verbose;
 23 	int verbose_level;
 24 	bool norm_d_mean;
 25 	CmdLine cmd;
 26 	cmd.setBOption("write log", false);
 27 	
 28 	cmd.info(help);
 29 	cmd.addText("\nArguments:");
 30 	cmd.addSCmdArg("bindata filename", &bindata_filename, "bindata filename");
 31 	cmd.addSCmdArg("model filename", &model_filename, "LDA model filename");
 32 	cmd.addICmdArg("n_input", &n_input, "number of inputs");
 33 	cmd.addText("\nOptions:");
 34 	cmd.addSCmdOption("-o", &filename_out, "ilda.bindata", "bindata output file");
 35 	cmd.addRCmdOption("-variance", &variance, -1.0, "variance (-1 100\%)");
 36 	cmd.addICmdOption("-n_output", &n_output, -1, "number of outputs (-1 auto)");
 37 	cmd.addBCmdOption("-verbose", &verbose, false, "verbose");
 38 	cmd.addICmdOption("-verbose_level", &verbose_level, 1, "level of verbose");
 39 	cmd.addSCmdOption("-ilda", &ilda_filename, "", "iLDA model filename");
 40 	cmd.addBCmdOption("-dnorm", &norm_d_mean, false, "norm d mean");
 41 	cmd.read(argc, argv);
 42 
 43 	if(verbose == false) verbose_level = 0;
 44 
 45 	if(n_output > n_input) error("n_output > n_input");
 46 
 47 	//
 48 	LDAMachine *lda_machine = NULL;
 49 	lda_machine = new LDAMachine(n_input, norm_d_mean);
 50 
 51 	if(verbose_level >= 1)
 52 		print("Loading LDA model: %s ...\n", model_filename);
 53 
 54 	DiskXFile *file = NULL;
 55 	file = new DiskXFile(model_filename, "r");
 56 	lda_machine->loadXFile(file);
 57 	delete file;
 58 	   
 59 	//
 60 	lda_machine->setIOption("verbose_level", verbose_level);
 61 	lda_machine->setROption("variance", variance);
 62 	lda_machine->init();
 63 
 64 	if(n_output != -1) lda_machine->n_outputs = n_output;
 65 	else n_output = lda_machine->n_outputs;
 66 
 67 	//
 68 	Mat *E_inv = NULL;
 69 	E_inv = new Mat(n_input, n_input);
 70 	DiskXFile *pfMxInverse = NULL;
 71 
 72 	if(strcmp(ilda_filename, "") == 0)
 73 	{
 74 		print("Copying eigenvectors ...\n");
 75 	   
 76 		Mat *eigenvectors = NULL;
 77         	eigenvectors = new Mat(n_input, n_input);
 78 
 79 		for(int i = 0 ; i < n_input ; i++)
 80        			for(int j = 0 ; j < n_input ; j++)
 81 				eigenvectors->ptr[j][i] = lda_machine->eigenvectors[j][i];
 82 			
 83 		print("Inversion of eigenvectors matrix ...\n");
 84 	
 85 		mxInverse(eigenvectors, E_inv);
 86 
 87 		print("Saving matrix inverse ...\n");
 88 		pfMxInverse = new DiskXFile("ilda.matrix", "w");
 89 		pfMxInverse->write(&n_input, sizeof(int), 1);
 90        		for(int j = 0 ; j < n_input ; j++)
 91 			pfMxInverse->write(E_inv->ptr[j], sizeof(double), n_input);
 92 	
 93 		delete eigenvectors;
 94 	}
 95 	else
 96 	{
 97 		print("Loading matrix inverse %s ...\n", ilda_filename);
 98 		pfMxInverse = new DiskXFile(ilda_filename, "r");
 99 		int n_input_;
100 		pfMxInverse->read(&n_input_, sizeof(int), 1);
101 		if(n_input_ != n_input)
102 			error("Number of inputs %d != %d incorrect\n", n_input_, n_input);
103        		for(int j = 0 ; j < n_input ; j++)
104 			pfMxInverse->read(E_inv->ptr[j], sizeof(double), n_input);
105 	}
106 
107 	delete pfMxInverse;
108 
109 
110 	//
111 	int dimIn;
112 	int n_patterns;
113 
114 	DiskXFile *pf = NULL;
115 	pf = new DiskXFile(bindata_filename, "r");
116 
117 	pf->read(&n_patterns, sizeof(int), 1);
118 	pf->read(&dimIn, sizeof(int), 1);
119 
120 	if(verbose_level >= 1)
121 	{
122 		print("n_inputs : %d\n", dimIn);
123 		print("n_patterns : %d\n", n_patterns);  
124 	}
125 
126 	if(n_input > dimIn)
127 	{
128 		error("Number of inputs specified (%d) bigger than into the file (%d)", n_input, dimIn);
129 	   
130 		delete pf;
131 
132 		return 0;
133 	}
134 
135 
136 	//
137 	float *realinput = NULL;
138         Sequence *seq;
139 	float *ilda = NULL;
140 
141         realinput = new float [n_input];
142         seq = new Sequence(&realinput, 1, n_input);	
143         ilda = new float [n_input];
144 
145 	//
146 	DiskXFile *pfOutput = NULL;
147 	pfOutput = new DiskXFile(filename_out, "w");
148 
149 	//
150 	//
151 	if(verbose_level >= 1)
152 		print("Projection bindata file into PCA space (%d -> %d) ...\n", n_input, n_output);
153 
154 	int P = n_patterns;
155 
156 	pfOutput->write(&P, sizeof(int), 1);
157 	pfOutput->write(&dimIn, sizeof(int), 1);
158 
159 	real MSE = 0.0;
160 
161 	for(int p = 0 ; p < P ; p++)
162 	{
163 	   	//
164 	  	for(int i = 0 ; i < n_input ; i++)
165 			pf->read(&realinput[i], sizeof(float), 1);   
166 
167 		if(verbose_level >= 2) 
168 			print(" Seq =     [%2.3f %2.3f %2.3f ...]\n", realinput[0], realinput[1], realinput[2]);
169 
170 		//
171 		lda_machine->forward(seq);
172 			
173 		if(verbose_level >= 2) 
174 			print(" Output =   [%2.3f %2.3f %2.3f ...]\n", lda_machine->outputs->frames[0][0], lda_machine->outputs->frames[0][1], lda_machine->outputs->frames[0][2]);
175 
176 		float *machine_output = lda_machine->outputs->frames[0];
177 		
178 		if(norm_d_mean == true)
179 			for(int j = 0 ; j < n_output ; j++) machine_output[j] += lda_machine->d_m_mean_[j];
180 
181 		for(int i = 0 ; i < n_input ; i++)
182 		{
183 			ilda[i] = 0.0;
184 
185 			for(int j = 0 ; j < n_output ; j++)
186 				ilda[i] += E_inv->ptr[j][i] * machine_output[j];
187 		}
188 		
189 		//
190 		float data;
191 		float mse = 0;
192 		for(int i = 0 ; i < n_input ; i++)
193 		{
194 		   	data = ilda[i];
195 			
196 			real z = realinput[i] - data;
197 			mse += z*z;
198 
199 			pfOutput->write(&data, sizeof(float), 1);   
200 		}
201 		
202 		mse /= (float) n_input;
203 
204 		print(" mse = %g\n", mse);
205 		
206 		MSE += mse;
207 	}
208 	
209 	MSE /= (real) P;
210 
211 	print("MSE = %g\n", MSE);
212 
213 	//
214 	delete pf;
215 	delete pfOutput;
216 	
217 	//
218 	delete seq;
219 	delete [] realinput;
220 	delete [] ilda;
221 	delete lda_machine;
222 	
223 
224 	
225   	return 0;
226 }
227