diff options
author | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 |
commit | cccccbf6cca94a3eaf813b4468453160e91c332b (patch) | |
tree | 23418cb73a10ae3b0688681a7f0ba9b06424583e /src/TNet.cc | |
download | tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2 tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip |
First commit
Diffstat (limited to 'src/TNet.cc')
-rw-r--r-- | src/TNet.cc | 379 |
1 files changed, 379 insertions, 0 deletions
diff --git a/src/TNet.cc b/src/TNet.cc new file mode 100644 index 0000000..7d60e63 --- /dev/null +++ b/src/TNet.cc @@ -0,0 +1,379 @@ + +/*************************************************************************** + * copyright : (C) 2011 by Karel Vesely,UPGM,FIT,VUT,Brno * + * email : iveselyk@fit.vutbr.cz * + *************************************************************************** + * * + * This program is free software; you can redistribute it and/or modify * + * it under the terms of the APACHE License as published by the * + * Apache Software Foundation; either version 2.0 of the License, * + * or (at your option) any later version. * + * * + ***************************************************************************/ + +#define SVN_DATE "$Date: 2011-04-04 19:14:16 +0200 (Mon, 04 Apr 2011) $" +#define SVN_AUTHOR "$Author: iveselyk $" +#define SVN_REVISION "$Revision: 46 $" +#define SVN_ID "$Id: TNet.cc 46 2011-04-04 17:14:16Z iveselyk $" + +#define MODULE_VERSION "1.0.0 "__TIME__" "__DATE__" "SVN_ID + +/** + * \file TNet.cc + * \brief NNet training entry program Multi-core version + */ + +/*** TNetLib includes */ +#include "Error.h" +#include "Timer.h" +#include "Features.h" +#include "Common.h" +#include "MlfStream.h" +#include "UserInterface.h" +#include "Timer.h" + +/*** TNet includes */ +#include "Nnet.h" +#include "ObjFun.h" +#include "Platform.h" + + +/*** STL includes */ +#include <iostream> +#include <sstream> +#include <numeric> + + +#define SNAME "TNET" + +using namespace TNet; + +void usage(const char* progname) +{ + const char *tchrptr; + if ((tchrptr = strrchr(progname, '\\')) != NULL) progname = tchrptr+1; + if ((tchrptr = strrchr(progname, '/')) != NULL) progname = tchrptr+1; + fprintf(stderr, +"\n%s version " MODULE_VERSION "\n" +"\nUSAGE: %s [options] DataFiles...\n\n" +":TODO:\n\n" +" Option Default\n\n" +" -c Enable crossvalidation off\n" +" -m file Set label map of NN outputs \n" +" -n f Set learning rate to f 0.06\n" +" -o ext Set target model ext None\n" +" -A Print command line arguments Off\n" +" -C cf Set config file to cf Default\n" +" -D Display configuration variables Off\n" +" -H mmf Load NN macro file \n" +" -I mlf Load master label file mlf \n" +" -L dir Set input label (or net) dir Current\n" +" -M dir Dir to write NN macro files Current\n" +" -O fn Objective function [mse,xent] xent\n" +" -S file Set script file None\n" +" -T N Set trace flags to N 0\n" +" -V Print version information Off\n" +" -X ext Set input label file ext lab\n" +"\n" +"BUNCHSIZE CACHESIZE CONFUSIONMODE[no,max,soft,dmax,dsoft] CROSSVALIDATE FEATURETRANSFORM LEARNINGRATE LEARNRATEFACTORS MLFTRANSC MOMENTUM NATURALREADORDER OBJECTIVEFUNCTION[mse,xent] OUTPUTLABELMAP PRINTCONFIG PRINTVERSION RANDOMIZE SCRIPT SEED SOURCEMLF SOURCEMMF SOURCETRANSCDIR SOURCETRANSCEXT TARGETMMF TARGETMODELDIR TARGETMODELEXT TRACE WEIGHTCOST\n" +"\n" +"STARTFRMEXT ENDFRMEXT CMEANDIR CMEANMASK VARSCALEDIR VARSCALEMASK VARSCALEFN TARGETKIND DERIVWINDOWS DELTAWINDOW ACCWINDOW THIRDWINDOW\n" +"\n" +" %s is Copyright (C) 2010-2011 Karel Vesely\n" +" licensed under the APACHE License, version 2.0\n" +" Bug reports, feedback, etc, to: iveselyk@fit.vutbr.cz\n" +"\n", progname, progname, progname); + exit(-1); +} + + + +/////////////////////////////////////////////////////////////////////// +// MAIN FUNCTION +// + + +int main(int argc, char *argv[]) +{ + const char* p_option_string = + " -c n CROSSVALIDATE=TRUE" +// " -d r SOURCEMODELDIR" + " -m r OUTPUTLABELMAP" + " -n r LEARNINGRATE" + " -o r TARGETMODELEXT" + " -p r PARALLELMODE" +// " -r n REGULARISATION=TRUE" //add later +// " -u r UPDATE" //add later, update only certain weights... +// " -x r SOURCEMODELEXT" + " -B n SAVEBINARY=TRUE" + " -D n PRINTCONFIG=TRUE" +// " -G r SOURCETRANSCFMT" //add if more transcription formats + " -H l SOURCEMMF" + " -I r SOURCEMLF" + " -L r SOURCETRANSCDIR" + " -M r TARGETMODELDIR" + " -O r OBJECTIVEFUNCTION" + " -S l SCRIPT" + " -T r TRACE" + " -V n PRINTVERSION=TRUE" + " -X r SOURCETRANSCEXT"; + + + try { + UserInterface ui; + Platform pl; + Timer timer; + + + const char* p_script; + const char* p_output_label_map; + BaseFloat learning_rate; + BaseFloat weightcost; + ObjectiveFunction::ObjFunType obj_fun_id; + CrossEntropy::ConfusionMode xent_conf_mode; + + const char* p_source_mmf_file; + const char* p_input_transform; + + const char* p_targetmmf; //< SNet legacy --TARGETMMF + char p_trg_mmf_file[4096]; + const char* p_trg_mmf_dir; + const char* p_trg_mmf_ext; + + const char* p_source_mlf_file; + const char* p_src_lbl_dir; + const char* p_src_lbl_ext; + + int bunch_size; + int cache_size; + bool randomize; + long int seed; + + + int trace; + bool crossval; + int num_threads; + + + // variables for feature repository + bool swap_features; + int target_kind; + int deriv_order; + int* p_deriv_win_lenghts; + int start_frm_ext; + int end_frm_ext; + char* cmn_path; + char* cmn_file; + const char* cmn_mask; + char* cvn_path; + char* cvn_file; + const char* cvn_mask; + const char* cvg_file; + + + // OPTION PARSING .......................................................... + // use the STK option parsing + if (argc == 1) { usage(argv[0]); return 1; } + int args_parsed = ui.ParseOptions(argc, argv, p_option_string, SNAME); + + + // OPTION RETRIEVAL ........................................................ + // extract the feature parameters + swap_features = !ui.GetBool(SNAME":NATURALREADORDER", TNet::IsBigEndian()); + + target_kind = ui.GetFeatureParams(&deriv_order, &p_deriv_win_lenghts, + &start_frm_ext, &end_frm_ext, &cmn_path, &cmn_file, &cmn_mask, + &cvn_path, &cvn_file, &cvn_mask, &cvg_file, SNAME":", 0); + + + // extract other parameters + p_source_mmf_file = ui.GetStr(SNAME":SOURCEMMF", NULL); + p_input_transform = ui.GetStr(SNAME":FEATURETRANSFORM", NULL); + + p_targetmmf = ui.GetStr(SNAME":TARGETMMF", NULL);//< has higher priority than "dir/file.ext" composition (SNet legacy) + p_trg_mmf_dir = ui.GetStr(SNAME":TARGETMODELDIR", "");//< dir for composition + p_trg_mmf_ext = ui.GetStr(SNAME":TARGETMODELEXT", "");//< ext for composition + + p_script = ui.GetStr(SNAME":SCRIPT", NULL); + p_output_label_map = ui.GetStr(SNAME":OUTPUTLABELMAP", NULL); + learning_rate = ui.GetFlt(SNAME":LEARNINGRATE" , 0.06f); + weightcost = ui.GetFlt(SNAME":WEIGHTCOST" , 0.0); + + obj_fun_id = static_cast<ObjectiveFunction::ObjFunType>( + ui.GetEnum(SNAME":OBJECTIVEFUNCTION", + ObjectiveFunction::CROSS_ENTROPY, //< default + "ent", ObjectiveFunction::CROSS_ENTROPY, + "mse", ObjectiveFunction::MEAN_SQUARE_ERROR + )); + + xent_conf_mode = static_cast<CrossEntropy::ConfusionMode>( + ui.GetEnum(SNAME":CONFUSIONMODE", + CrossEntropy::NO_CONF, //< default + "no", CrossEntropy::NO_CONF, + "max", CrossEntropy::MAX_CONF, + "soft", CrossEntropy::SOFT_CONF, + "dmax", CrossEntropy::DIAG_MAX_CONF, + "dsoft", CrossEntropy::DIAG_SOFT_CONF + )); + + p_source_mlf_file = ui.GetStr(SNAME":SOURCEMLF", NULL); + p_src_lbl_dir = ui.GetStr(SNAME":SOURCETRANSCDIR", NULL); + p_src_lbl_ext = ui.GetStr(SNAME":SOURCETRANSCEXT", "lab"); + + bunch_size = ui.GetInt(SNAME":BUNCHSIZE", 256); + cache_size = ui.GetInt(SNAME":CACHESIZE", 12800); + randomize = ui.GetBool(SNAME":RANDOMIZE", true); + + //cannot get long int + seed = ui.GetInt(SNAME":SEED", 0); + + + //Fill the global variables of the singleton 'Gl' + trace = ui.GetInt(SNAME":TRACE", 0); + num_threads = ui.GetInt(SNAME":THREADS", 1); + crossval = ui.GetBool(SNAME":CROSSVALIDATE", false); + + + // process the parameters + if(ui.GetBool(SNAME":PRINTCONFIG", false)) { + std::cout << std::endl; + ui.PrintConfig(std::cout); + std::cout << std::endl; + } + if(ui.GetBool(SNAME":PRINTVERSION", false)) { + std::cout << std::endl; + std::cout << "======= TNET v"MODULE_VERSION" =======" << std::endl; + std::cout << std::endl; + } + ui.CheckCommandLineParamUse(); + + + // the rest of the parameters are the feature files + for (; args_parsed < argc; args_parsed++) { + pl.feature_.AddFile(argv[args_parsed]); + } + + //************************************************************************** + //************************************************************************** + // OPTION PARSING DONE ..................................................... + + + //initialize the InputProxy + if(NULL == p_script) + Warning("WARNING: The script file is missing [-S]"); + if(NULL == p_source_mlf_file) + Error("Source mlf file file is missing [-I]"); + if(NULL == p_output_label_map) + Error("Output label map is missing [-m]"); + // initialize the feature repository + if(trace&1) TraceLog("Initializing FeatureRepository"); + pl.feature_.Init( + swap_features, start_frm_ext, end_frm_ext, target_kind, + deriv_order, p_deriv_win_lenghts, + cmn_path, cmn_mask, cvn_path, cvn_mask, cvg_file + ); + //open the scp file + pl.feature_.AddFileList(p_script); + + // initialize the label repository + if(trace&1) TraceLog("Initializing LabelRepository"); + pl.label_.Init(p_source_mlf_file,p_output_label_map, p_src_lbl_dir, p_src_lbl_ext); + + // read input transform + if(NULL != p_input_transform) { + if(trace&1) TraceLog(std::string("Reading input transform:")+p_input_transform); + pl.nnet_transf_.ReadNetwork(p_input_transform); + } + + // read network + if(NULL != p_source_mmf_file) { + if(trace&1) TraceLog(std::string("Reading network:")+p_source_mmf_file); + pl.nnet_.ReadNetwork(p_source_mmf_file); + } else { + Error("Source MMF must be specified [-H]"); + } + pl.nnet_.SetLearnRate(learning_rate); + pl.nnet_.SetWeightcost(weightcost); + + //get objective function instance + pl.obj_fun_ = ObjectiveFunction::Factory(obj_fun_id); + //setup the cross entropy + if(obj_fun_id == ObjectiveFunction::CROSS_ENTROPY) { + CrossEntropy* xent = dynamic_cast<CrossEntropy*>(pl.obj_fun_); + //confusion mode + xent->SetConfusionMode(xent_conf_mode); + //pass the outputlabelmap + xent->SetOutputLabelMap(p_output_label_map); + } + + //initialize the cache + pl.bunchsize_ = bunch_size; + pl.cachesize_ = cache_size; + pl.randomize_ = randomize; + // + pl.start_frm_ext_ = start_frm_ext; + pl.end_frm_ext_ = end_frm_ext; + pl.trace_ = trace; + pl.crossval_ = crossval; + + //TODO do someting with seed!!! + pl.seed_ = seed; + //data_proxy.InitCache(cache_size, bunch_size, network, randomize, seed); + + timer.Start(); + std::cout << "===== TNET " + << (crossval?"CROSSVALIDATION":"TRAINING") + << " STARTED =====" << std::endl; + std::cout << "Objective function: " + << pl.obj_fun_->GetName() << std::endl; + if(!crossval) { + std::cout << "Learning rate: " << learning_rate << std::endl; + } + + + /* + * PERFORM ONE ITERATION OF THE TRAINING + */ + pl.RunTrain(num_threads); + /* + * + */ + + + if(trace&1) TraceLog("Training finished"); + + //write the network + if(!crossval) { + if(trace&1) TraceLog("Writing network"); + if (NULL != p_targetmmf) { + pl.nnet_.WriteNetwork(p_targetmmf); + } else { + MakeHtkFileName(p_trg_mmf_file, p_source_mmf_file, p_trg_mmf_dir, p_trg_mmf_ext); + pl.nnet_.WriteNetwork(p_trg_mmf_file); + } + } + + //show report + timer.End(); + + pl.cout_mutex_.Lock(); + + std::cout << "===== TNET FINISHED ( " << timer.Val() << "s ) " + << "[ FPS: " << pl.obj_fun_->GetFrames() / timer.Val() + << " RT: " << 1.0f / (pl.obj_fun_->GetFrames() / timer.Val() / 100.0f) + << " ] =====" << std::endl; + + //report objective function + std::cout << "-- " << (crossval?"CV ":"TR ") + << pl.obj_fun_->Report(); + + pl.cout_mutex_.Unlock(); + + } + catch (std::exception& rExc) { + std::cerr << "Exception thrown" << std::endl; + std::cerr << rExc.what() << std::endl; + return 1; + } + return 0; +} |