mbindata2lda.cc

  1 const char *help = "\
  2 progname: mbindata2lda.cc\n\
  3 code2html: This program reads multiple bindata files and a LDA model, projects patterns into it.\n\
  4 version: Torch3 vision2.0, 2003-2005\n\
  5 (c) Sebastien Marcel (marcel@idiap.ch)\n";
  6 
  7 #include "string_utils.h"
  8 #include "FileListCmdOption.h"
  9 #include "str_utils.h"
 10 #include "LDAMachine.h"
 11 #include "CmdLine.h"
 12 
 13 using namespace Torch;
 14 
 15 int main(int argc, char *argv[])
 16 {
 17 	char *dir_name;
 18 	char *model_filename;
 19 	int n_input;
 20 	int n_output;
 21 	char *filename_out;
 22 	real variance;
 23 	bool verbose;
 24 	int verbose_level;
 25 	bool norm_d_mean;
 26 	
 27 	//
 28 	FileListCmdOption filelist("file name", "the list files or one data file");
 29         filelist.isArgument(true);
 30 
 31 	//
 32 	CmdLine cmd;
 33 	cmd.setBOption("write log", false);
 34 	
 35 	cmd.info(help);
 36 	cmd.addText("\nArguments:");
 37   	cmd.addCmdOption(&filelist);
 38 	cmd.addSCmdArg("model filename", &model_filename, "LDA model filename");
 39 	cmd.addICmdArg("n_input", &n_input, "number of inputs");
 40 	cmd.addText("\nOptions:");
 41 	cmd.addSCmdOption("-dir", &dir_name, ".", "dir name");
 42 	cmd.addSCmdOption("-o", &filename_out, "lda.bindata", "bindata output file");
 43 	cmd.addRCmdOption("-variance", &variance, -1.0, "variance (-1 100\%)");
 44 	cmd.addICmdOption("-n_output", &n_output, -1, "number of outputs (-1 auto)");
 45 	cmd.addBCmdOption("-verbose", &verbose, false, "verbose");
 46 	cmd.addICmdOption("-verbose_level", &verbose_level, 1, "level of verbose");
 47 	cmd.addBCmdOption("-dnorm", &norm_d_mean, false, "norm d mean");
 48 	cmd.read(argc, argv);
 49 
 50 	if(verbose == false) verbose_level = 0;
 51 
 52 	if(n_output > n_input) error("n_output > n_input");
 53 
 54 	//
 55 	LDAMachine *lda_machine = NULL;
 56 	lda_machine = new LDAMachine(n_input, norm_d_mean);
 57 
 58 	if(verbose_level >= 1)
 59 		print("Loading LDA model: %s ...\n", model_filename);
 60 
 61 	DiskXFile *file = NULL;
 62 	file = new DiskXFile(model_filename, "r");
 63 	lda_machine->loadXFile(file);
 64 	delete file;
 65 	   
 66 	//
 67 	lda_machine->setIOption("verbose_level", verbose_level);
 68 	lda_machine->setROption("variance", variance);
 69 	lda_machine->init();
 70 
 71 	if(n_output != -1) lda_machine->n_outputs = n_output;
 72 	else n_output = lda_machine->n_outputs;
 73 
 74 	//
 75 	float *realinput = NULL;
 76         Sequence *seq;
 77 
 78         realinput = new float [n_input];
 79         seq = new Sequence(&realinput, 1, n_input);	
 80 
 81 	//
 82 	int dimIn;
 83 	int n_patterns;
 84 
 85 	DiskXFile *pf = NULL;
 86 
 87 	//
 88 	for(int i = 0 ; i < filelist.n_files ; i++)
 89 	{
 90 		char *temp = strBaseName(filelist.file_names[i]);
 91 		char *file_name = strRemoveSuffix(temp);
 92 
 93 		if(verbose)
 94 			print("Processing file %s\n", file_name);
 95 
 96 		pf = new DiskXFile(filelist.file_names[i], "r");
 97 
 98 		pf->read(&n_patterns, sizeof(int), 1);
 99 		pf->read(&dimIn, sizeof(int), 1);
100 
101 		if(verbose_level >= 1)
102 		{
103 			print("n_inputs : %d\n", dimIn);
104 			print("n_patterns : %d\n", n_patterns);  
105 		}
106 
107 		if(n_input > dimIn)
108 		{
109 			error("Number of inputs specified (%d) bigger than into the file (%d)", n_input, dimIn);
110 	   
111 			delete pf;
112 
113 			return 0;
114 		}
115 
116 
117 		//
118 		DiskXFile *pfOutput = NULL;
119 		char filename_out[250];
120 
121 		sprintf(filename_out, "%s/%s.bindata", dir_name, file_name);
122 		
123 		pfOutput = new DiskXFile(filename_out, "w");
124 
125 		//
126 		//
127 		if(verbose_level >= 1)
128 			print("Projection bindata file into PCA space (%d -> %d) ...\n", n_input, n_output);
129 
130 		int P = n_patterns;
131 
132 		pfOutput->write(&P, sizeof(int), 1);
133 		pfOutput->write(&n_output, sizeof(int), 1);
134 
135 		for(int p = 0 ; p < P ; p++)
136 		{
137 	   		//
138 	  		for(int i = 0 ; i < n_input ; i++)
139 				pf->read(&realinput[i], sizeof(float), 1);   
140 
141 			if(verbose_level >= 2) 
142 				print(" Seq =     [%2.3f %2.3f %2.3f ...]\n", realinput[0], realinput[1], realinput[2]);
143 
144 			//
145 			lda_machine->forward(seq);
146 			
147 			if(verbose_level >= 2) 
148 				print(" Output =   [%2.3f %2.3f %2.3f ...]\n", lda_machine->outputs->frames[0][0], lda_machine->outputs->frames[0][1], lda_machine->outputs->frames[0][2]);
149 
150 			float *machine_output = lda_machine->outputs->frames[0];
151 
152 			//
153 			float data;
154 			for(int i = 0 ; i < n_output ; i++)
155 			{
156 		   		data = machine_output[i];
157 
158 				pfOutput->write(&data, sizeof(float), 1);   
159 			}
160 		}
161 
162 		//
163 		delete pf;
164 		delete pfOutput;
165 	}
166 	
167 	//
168 	delete seq;
169 	delete [] realinput;
170 	delete lda_machine;
171 	
172 
173 	
174   	return 0;
175 }
176