/*************************************************************************** * copyright : (C) 2011 by Karel Vesely,UPGM,FIT,VUT,Brno * * email : iveselyk@fit.vutbr.cz * *************************************************************************** * * * This program is free software; you can redistribute it and/or modify * * it under the terms of the APACHE License as published by the * * Apache Software Foundation; either version 2.0 of the License, * * or (at your option) any later version. * * * ***************************************************************************/ #define SVN_DATE "$Date: 2011-09-26 16:48:24 +0200 (Mon, 26 Sep 2011) $" #define SVN_AUTHOR "$Author: iveselyk $" #define SVN_REVISION "$Revision: 73 $" #define SVN_ID "$Id: TNorm.cc 73 2011-09-26 14:48:24Z iveselyk $" #define MODULE_VERSION "1.0.0 "__TIME__" "__DATE__" "SVN_ID /*** KaldiLib includes */ #include "Error.h" #include "Timer.h" #include "Features.h" #include "Common.h" #include "UserInterface.h" #include "Timer.h" /*** TNet includes */ #include "Nnet.h" /*** STL includes */ #include <iostream> #include <sstream> #include <numeric> ////////////////////////////////////////////////////////////////////// // DEFINES // #define SNAME "TNORM" using namespace TNet; void usage(const char* progname) { const char *tchrptr; if ((tchrptr = strrchr(progname, '\\')) != NULL) progname = tchrptr+1; if ((tchrptr = strrchr(progname, '/')) != NULL) progname = tchrptr+1; fprintf(stderr, "\n%s version " MODULE_VERSION "\n" "\nUSAGE: %s [options] DataFiles...\n\n" " Option Default\n\n" " -A Print command line arguments Off\n" " -C cf Set config file to cf Default\n" " -D Display configuration variables Off\n" " -H mmf Load NN macro file \n" " -S file Set script file None\n" " -T N Set trace flags to N 0\n" " -V Print version information Off\n" "\n" "NATURALREADORDER PRINTCONFIG PRINTVERSION SCRIPT SOURCEMMF TARGETMMF TRACE\n" "\n" "STARTFRMEXT ENDFRMEXT CMEANDIR CMEANMASK VARSCALEDIR VARSCALEMASK VARSCALEFN TARGETKIND DERIVWINDOWS DELTAWINDOW ACCWINDOW THIRDWINDOW\n" "\n" " %s is Copyright (C) 2010-2011 Karel Vesely\n" " licensed under the APACHE License, version 2.0\n" " Bug reports, feedback, etc, to: iveselyk@fit.vutbr.cz\n" "\n", progname, progname, progname); exit(-1); } /////////////////////////////////////////////////////////////////////// // MAIN FUNCTION // int main(int argc, char *argv[]) try { const char* p_option_string = " -D n PRINTCONFIG=TRUE" " -H l SOURCEMMF" " -S l SCRIPT" " -T r TRACE" " -V n PRINTVERSION=TRUE" ; UserInterface ui; FeatureRepository features; Network network_cpu; Timer timer; const char* p_script; const char* p_source_mmf_file; const char* p_targetmmf; int traceFlag; // variables for feature repository bool swap_features; int target_kind; int deriv_order; int* p_deriv_win_lenghts; int start_frm_ext; int end_frm_ext; char* cmn_path; char* cmn_file; const char* cmn_mask; char* cvn_path; char* cvn_file; const char* cvn_mask; const char* cvg_file; // OPTION PARSING .......................................................... // use the STK option parsing if (argc == 1) { usage(argv[0]); return 1; } int args_parsed = ui.ParseOptions(argc, argv, p_option_string, SNAME); // OPTION RETRIEVAL ........................................................ // extract the feature parameters swap_features = !ui.GetBool(SNAME":NATURALREADORDER", TNet::IsBigEndian()); target_kind = ui.GetFeatureParams(&deriv_order, &p_deriv_win_lenghts, &start_frm_ext, &end_frm_ext, &cmn_path, &cmn_file, &cmn_mask, &cvn_path, &cvn_file, &cvn_mask, &cvg_file, SNAME":", 0); // extract other parameters p_source_mmf_file = ui.GetStr(SNAME":SOURCEMMF", NULL); p_targetmmf = ui.GetStr(SNAME":TARGETMMF", NULL);//< target for mean/variance p_script = ui.GetStr(SNAME":SCRIPT", NULL); traceFlag = ui.GetInt(SNAME":TRACE", 0); // process the parameters if(ui.GetBool(SNAME":PRINTCONFIG", false)) { std::cout << std::endl; ui.PrintConfig(std::cout); std::cout << std::endl; } if(ui.GetBool(SNAME":PRINTVERSION", false)) { std::cout << std::endl; std::cout << "======= TNET v"MODULE_VERSION" xvesel39 =======" << std::endl; std::cout << std::endl; } ui.CheckCommandLineParamUse(); // the rest of the parameters are the feature files for (; args_parsed < argc; args_parsed++) { features.AddFile(argv[args_parsed]); } //************************************************************************** //************************************************************************** // OPTION PARSING DONE ..................................................... //read the neural network if(NULL != p_source_mmf_file) { if(traceFlag&1) TraceLog(std::string("Reading network: ")+p_source_mmf_file); network_cpu.ReadNetwork(p_source_mmf_file); } else { Error("Source MMF must be specified [-H]"); } // initialize the feature repository features.Init( swap_features, start_frm_ext, end_frm_ext, target_kind, deriv_order, p_deriv_win_lenghts, cmn_path, cmn_mask, cvn_path, cvn_mask, cvg_file ); if(NULL != p_script) { features.AddFileList(p_script); } else { Warning("WARNING: The script file is missing [-S]"); } //********************************************************************** //********************************************************************** // INITIALIZATION DONE ................................................. // // Start training timer.Start(); std::cout << "===== TNorm STARTED =====" << std::endl; int dim = network_cpu.GetNOutputs(); Vector<double> first(dim); first.Set(0.0); Vector<double> second(dim); second.Set(0.0); unsigned long framesN = 0; //progress size_t cnt = 0; size_t step = features.QueueSize() / 100; if(step == 0) step = 1; //********************************************************************** //********************************************************************** // MAIN LOOP for(features.Rewind(); !features.EndOfList(); features.MoveNext()) { Matrix<BaseFloat> feats_host,net_out; Matrix<BaseFloat> feats_host_out; //get features features.ReadFullMatrix(feats_host); //propagate network_cpu.Propagate(feats_host,net_out); //trim the xxx_frm_ext feats_host_out.Init(net_out.Rows()-start_frm_ext-end_frm_ext,net_out.Cols()); memcpy(feats_host_out.pData(),net_out.pRowData(start_frm_ext),feats_host_out.MSize()); //accumulate first/second order statistics for(size_t m=0; m<feats_host_out.Rows(); m++) { for(size_t n=0; n<feats_host_out.Cols(); n++) { BaseFloat val = feats_host_out(m,n); first[n] += val; second[n] += val*val; if(isnan(first[n])||isnan(second[n])|| isinf(first[n])||isinf(second[n])) { std::ostringstream oss; oss << "nan/inf in accumulators\n" << "first:" << first << "\n" << "second:" << second << "\n" << "frames:" << framesN << "\n" << "utterance:" << features.Current().Logical() << "\n" << "feats_host: " << feats_host << "\n" << "feats_host_out: " << feats_host_out << "\n"; Error(oss.str()); } } } framesN += feats_host.Rows(); //progress if((cnt++ % step) == 0) std::cout << 100 * cnt / features.QueueSize() << "%, " << std::flush; } //********************************************************************** //********************************************************************** // ACCUMULATING FINISHED ................................................. // //get the mean/variance vectors Vector<double> mean(first); mean.Scale(1.0/framesN); Vector<double> variance(second); variance.Scale(1.0/framesN); for(size_t i=0; i<mean.Dim(); i++) { variance[i] -= mean[i]*mean[i]; } //get the mean normalization biase vector, //use negative mean vector Vector<double> bias(mean); bias.Scale(-1.0); //get the variance normalization window vector, //inverse of square root of variance Vector<double> window(variance); for(size_t i=0; i<window.Dim(); i++) { window[i] = 1.0/sqrt(window[i]); } //store the normalization network std::ofstream os(p_targetmmf); if(!os.good()) Error(std::string("Cannot open file for writing: ")+p_targetmmf); dim = mean.Dim(); os << "<bias> " << dim << " " << dim << "\n" << bias << "\n\n" << "<window> " << dim << " " << dim << "\n" << window << "\n\n"; os.close(); timer.End(); std::cout << "\n\n===== TNorm FINISHED ( " << timer.Val() << "s ) " << "[FPS:" << framesN / timer.Val() << ",RT:" << 1.0f / (framesN / timer.Val() / 100.0f) << "] =====" << std::endl; std::cout << "frames: " << framesN << ", max_bias: " << bias.Max() << ", max_window: " << window.Max() << ", min_window: " << window.Min() << "\n"; return 0; ///finish OK } catch (std::exception& rExc) { std::cerr << "Exception thrown" << std::endl; std::cerr << rExc.what() << std::endl; return 1; }