diff options
| author | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 | 
|---|---|---|
| committer | Joe Zhao <ztuowen@gmail.com> | 2014-04-14 08:14:45 +0800 | 
| commit | cccccbf6cca94a3eaf813b4468453160e91c332b (patch) | |
| tree | 23418cb73a10ae3b0688681a7f0ba9b06424583e /src/TNet.cc | |
| download | tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2 tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip | |
First commit
Diffstat (limited to 'src/TNet.cc')
| -rw-r--r-- | src/TNet.cc | 379 | 
1 files changed, 379 insertions, 0 deletions
| diff --git a/src/TNet.cc b/src/TNet.cc new file mode 100644 index 0000000..7d60e63 --- /dev/null +++ b/src/TNet.cc @@ -0,0 +1,379 @@ + +/*************************************************************************** + *   copyright            : (C) 2011 by Karel Vesely,UPGM,FIT,VUT,Brno     * + *   email                : iveselyk@fit.vutbr.cz                          * + *************************************************************************** + *                                                                         * + *   This program is free software; you can redistribute it and/or modify  * + *   it under the terms of the APACHE License as published by the          * + *   Apache Software Foundation; either version 2.0 of the License,        * + *   or (at your option) any later version.                                * + *                                                                         * + ***************************************************************************/ + +#define SVN_DATE       "$Date: 2011-04-04 19:14:16 +0200 (Mon, 04 Apr 2011) $" +#define SVN_AUTHOR     "$Author: iveselyk $" +#define SVN_REVISION   "$Revision: 46 $" +#define SVN_ID         "$Id: TNet.cc 46 2011-04-04 17:14:16Z iveselyk $" + +#define MODULE_VERSION "1.0.0 "__TIME__" "__DATE__" "SVN_ID   + +/** + * \file TNet.cc + * \brief NNet training entry program Multi-core version + */ + +/*** TNetLib includes */ +#include "Error.h" +#include "Timer.h" +#include "Features.h" +#include "Common.h" +#include "MlfStream.h" +#include "UserInterface.h" +#include "Timer.h" + +/*** TNet includes */ +#include "Nnet.h" +#include "ObjFun.h" +#include "Platform.h" + + +/*** STL includes */ +#include <iostream> +#include <sstream> +#include <numeric> + + +#define SNAME "TNET" + +using namespace TNet; + +void usage(const char* progname)  +{ +  const char *tchrptr; +  if ((tchrptr = strrchr(progname, '\\')) != NULL) progname = tchrptr+1; +  if ((tchrptr = strrchr(progname, '/')) != NULL) progname = tchrptr+1; +  fprintf(stderr, +"\n%s version " MODULE_VERSION "\n" +"\nUSAGE: %s [options] DataFiles...\n\n" +":TODO:\n\n" +" Option                                                     Default\n\n" +" -c         Enable crossvalidation                          off\n" +" -m file    Set label map of NN outputs                     \n" +" -n f       Set learning rate to f                          0.06\n" +" -o ext     Set target model ext                            None\n" +" -A         Print command line arguments                    Off\n"  +" -C cf      Set config file to cf                           Default\n" +" -D         Display configuration variables                 Off\n" +" -H mmf     Load NN macro file                              \n" +" -I mlf     Load master label file mlf                      \n" +" -L dir     Set input label (or net) dir                    Current\n" +" -M dir     Dir to write NN macro files                     Current\n" +" -O fn      Objective function [mse,xent]                   xent\n" +" -S file    Set script file                                 None\n" +" -T N       Set trace flags to N                            0\n"  +" -V         Print version information                       Off\n" +" -X ext     Set input label file ext                        lab\n" +"\n" +"BUNCHSIZE CACHESIZE CONFUSIONMODE[no,max,soft,dmax,dsoft] CROSSVALIDATE FEATURETRANSFORM LEARNINGRATE LEARNRATEFACTORS MLFTRANSC MOMENTUM NATURALREADORDER OBJECTIVEFUNCTION[mse,xent] OUTPUTLABELMAP PRINTCONFIG PRINTVERSION RANDOMIZE SCRIPT SEED SOURCEMLF SOURCEMMF SOURCETRANSCDIR SOURCETRANSCEXT TARGETMMF TARGETMODELDIR TARGETMODELEXT TRACE WEIGHTCOST\n" +"\n" +"STARTFRMEXT ENDFRMEXT CMEANDIR CMEANMASK VARSCALEDIR VARSCALEMASK VARSCALEFN TARGETKIND DERIVWINDOWS DELTAWINDOW ACCWINDOW THIRDWINDOW\n" +"\n" +" %s is Copyright (C) 2010-2011 Karel Vesely\n" +" licensed under the APACHE License, version 2.0\n" +" Bug reports, feedback, etc, to: iveselyk@fit.vutbr.cz\n" +"\n", progname, progname, progname); +  exit(-1); +} + + + +/////////////////////////////////////////////////////////////////////// +// MAIN FUNCTION +// + + +int main(int argc, char *argv[]) +{ +  const char* p_option_string = +    " -c n   CROSSVALIDATE=TRUE" +//  " -d r   SOURCEMODELDIR" +    " -m r   OUTPUTLABELMAP"  +    " -n r   LEARNINGRATE"  +    " -o r   TARGETMODELEXT"  +    " -p r   PARALLELMODE"  +//  " -r n   REGULARISATION=TRUE" //add later +//  " -u r   UPDATE" //add later, update only certain weights... +//  " -x r   SOURCEMODELEXT" +    " -B n   SAVEBINARY=TRUE"  +    " -D n   PRINTCONFIG=TRUE" +//  " -G r   SOURCETRANSCFMT" //add if more transcription formats +    " -H l   SOURCEMMF" +    " -I r   SOURCEMLF" +    " -L r   SOURCETRANSCDIR" +    " -M r   TARGETMODELDIR" +    " -O r   OBJECTIVEFUNCTION"  +    " -S l   SCRIPT" +    " -T r   TRACE" +    " -V n   PRINTVERSION=TRUE" +    " -X r   SOURCETRANSCEXT"; + + +  try { +    UserInterface        ui; +    Platform             pl; +    Timer                timer; + +    +    const char*                       p_script; +    const char*                       p_output_label_map; +    BaseFloat                         learning_rate; +    BaseFloat                         weightcost; +    ObjectiveFunction::ObjFunType     obj_fun_id; +    CrossEntropy::ConfusionMode       xent_conf_mode; + +    const char*                       p_source_mmf_file; +    const char*                       p_input_transform; + +    const char*                       p_targetmmf; //< SNet legacy --TARGETMMF +          char                        p_trg_mmf_file[4096]; +    const char*                       p_trg_mmf_dir; +    const char*                       p_trg_mmf_ext; + +    const char*                       p_source_mlf_file; +    const char*                       p_src_lbl_dir; +    const char*                       p_src_lbl_ext; + +    int                               bunch_size; +    int                               cache_size; +    bool                              randomize; +    long int                          seed; + + +    int                               trace; +    bool                              crossval; +    int                               num_threads; + + +    // variables for feature repository +    bool                              swap_features; +    int                               target_kind; +    int                               deriv_order; +    int*                              p_deriv_win_lenghts; +    int                               start_frm_ext; +    int                               end_frm_ext; +          char*                       cmn_path; +          char*                       cmn_file; +    const char*                       cmn_mask; +          char*                       cvn_path; +          char*                       cvn_file; +    const char*                       cvn_mask; +    const char*                       cvg_file; +  +     +    // OPTION PARSING .......................................................... +    // use the STK option parsing +    if (argc == 1) { usage(argv[0]); return 1; } +    int args_parsed = ui.ParseOptions(argc, argv, p_option_string, SNAME); + + +    // OPTION RETRIEVAL ........................................................ +    // extract the feature parameters +    swap_features = !ui.GetBool(SNAME":NATURALREADORDER", TNet::IsBigEndian()); +     +    target_kind = ui.GetFeatureParams(&deriv_order, &p_deriv_win_lenghts, +         &start_frm_ext, &end_frm_ext, &cmn_path, &cmn_file, &cmn_mask, +         &cvn_path, &cvn_file, &cvn_mask, &cvg_file, SNAME":", 0); + + +    // extract other parameters +    p_source_mmf_file   = ui.GetStr(SNAME":SOURCEMMF",     NULL); +    p_input_transform   = ui.GetStr(SNAME":FEATURETRANSFORM",  NULL); +     +    p_targetmmf         = ui.GetStr(SNAME":TARGETMMF",     NULL);//< has higher priority than "dir/file.ext" composition (SNet legacy) +    p_trg_mmf_dir       = ui.GetStr(SNAME":TARGETMODELDIR",  "");//< dir for composition +    p_trg_mmf_ext       = ui.GetStr(SNAME":TARGETMODELEXT",  "");//< ext for composition + +    p_script            = ui.GetStr(SNAME":SCRIPT",         NULL); +    p_output_label_map  = ui.GetStr(SNAME":OUTPUTLABELMAP", NULL); +    learning_rate       = ui.GetFlt(SNAME":LEARNINGRATE"  , 0.06f); +    weightcost          = ui.GetFlt(SNAME":WEIGHTCOST"    , 0.0); + +    obj_fun_id          = static_cast<ObjectiveFunction::ObjFunType>( +                          ui.GetEnum(SNAME":OBJECTIVEFUNCTION",  +                                     ObjectiveFunction::CROSS_ENTROPY, //< default +                                     "ent", ObjectiveFunction::CROSS_ENTROPY, +                                     "mse", ObjectiveFunction::MEAN_SQUARE_ERROR +                          )); + +    xent_conf_mode      = static_cast<CrossEntropy::ConfusionMode>( +                          ui.GetEnum(SNAME":CONFUSIONMODE", +                                     CrossEntropy::NO_CONF, //< default +                                     "no", CrossEntropy::NO_CONF, +                                     "max", CrossEntropy::MAX_CONF, +                                     "soft", CrossEntropy::SOFT_CONF, +                                     "dmax", CrossEntropy::DIAG_MAX_CONF, +                                     "dsoft", CrossEntropy::DIAG_SOFT_CONF +                          )); + +    p_source_mlf_file   = ui.GetStr(SNAME":SOURCEMLF",       NULL); +    p_src_lbl_dir       = ui.GetStr(SNAME":SOURCETRANSCDIR", NULL); +    p_src_lbl_ext       = ui.GetStr(SNAME":SOURCETRANSCEXT", "lab"); + +    bunch_size          = ui.GetInt(SNAME":BUNCHSIZE", 256); +    cache_size          = ui.GetInt(SNAME":CACHESIZE", 12800); +    randomize           = ui.GetBool(SNAME":RANDOMIZE", true); + +    //cannot get long int +    seed                = ui.GetInt(SNAME":SEED", 0); + + +    //Fill the global variables of the singleton 'Gl' +    trace               = ui.GetInt(SNAME":TRACE",               0); +    num_threads         = ui.GetInt(SNAME":THREADS",          1); +    crossval            = ui.GetBool(SNAME":CROSSVALIDATE",  false); + + +    // process the parameters +    if(ui.GetBool(SNAME":PRINTCONFIG", false)) { +      std::cout << std::endl; +      ui.PrintConfig(std::cout); +      std::cout << std::endl; +    } +    if(ui.GetBool(SNAME":PRINTVERSION", false)) { +      std::cout << std::endl; +      std::cout << "======= TNET v"MODULE_VERSION" =======" << std::endl; +      std::cout << std::endl; +    } +    ui.CheckCommandLineParamUse(); +     + +    // the rest of the parameters are the feature files +    for (; args_parsed < argc; args_parsed++) { +      pl.feature_.AddFile(argv[args_parsed]); +    } + +    //************************************************************************** +    //************************************************************************** +    // OPTION PARSING DONE ..................................................... + + +    //initialize the InputProxy +    if(NULL == p_script) +      Warning("WARNING: The script file is missing [-S]"); +    if(NULL == p_source_mlf_file) +      Error("Source mlf file file is missing [-I]"); +    if(NULL == p_output_label_map) +      Error("Output label map is missing [-m]"); +    // initialize the feature repository +    if(trace&1) TraceLog("Initializing FeatureRepository"); +    pl.feature_.Init( +      swap_features, start_frm_ext, end_frm_ext, target_kind, +      deriv_order, p_deriv_win_lenghts,  +      cmn_path, cmn_mask, cvn_path, cvn_mask, cvg_file +    ); +    //open the scp file +    pl.feature_.AddFileList(p_script); + +    // initialize the label repository +    if(trace&1) TraceLog("Initializing LabelRepository"); +    pl.label_.Init(p_source_mlf_file,p_output_label_map, p_src_lbl_dir, p_src_lbl_ext); + +    // read input transform     +    if(NULL != p_input_transform) { +      if(trace&1) TraceLog(std::string("Reading input transform:")+p_input_transform); +      pl.nnet_transf_.ReadNetwork(p_input_transform); +    } + +    // read network +    if(NULL != p_source_mmf_file) {  +      if(trace&1) TraceLog(std::string("Reading network:")+p_source_mmf_file); +      pl.nnet_.ReadNetwork(p_source_mmf_file); +    } else { +      Error("Source MMF must be specified [-H]"); +    } +    pl.nnet_.SetLearnRate(learning_rate); +    pl.nnet_.SetWeightcost(weightcost); + +    //get objective function instance +    pl.obj_fun_ = ObjectiveFunction::Factory(obj_fun_id); +    //setup the cross entropy +    if(obj_fun_id == ObjectiveFunction::CROSS_ENTROPY) { +      CrossEntropy* xent = dynamic_cast<CrossEntropy*>(pl.obj_fun_); +      //confusion mode +      xent->SetConfusionMode(xent_conf_mode); +      //pass the outputlabelmap +      xent->SetOutputLabelMap(p_output_label_map); +    } + +    //initialize the cache +    pl.bunchsize_ = bunch_size; +    pl.cachesize_ = cache_size; +    pl.randomize_ = randomize; +    // +    pl.start_frm_ext_ = start_frm_ext; +    pl.end_frm_ext_ = end_frm_ext; +    pl.trace_ = trace; +    pl.crossval_ = crossval; + +    //TODO do someting with seed!!! +    pl.seed_ = seed; +    //data_proxy.InitCache(cache_size, bunch_size, network, randomize, seed); +     +    timer.Start(); +    std::cout << "===== TNET "  +              << (crossval?"CROSSVALIDATION":"TRAINING")  +              << " STARTED =====" << std::endl; +    std::cout << "Objective function: "  +              << pl.obj_fun_->GetName() << std::endl; +    if(!crossval) { +      std::cout << "Learning rate: " << learning_rate << std::endl; +    } + + +    /* +     * PERFORM ONE ITERATION OF THE TRAINING +     */ +    pl.RunTrain(num_threads); +    /* +     * +     */ + + +    if(trace&1) TraceLog("Training finished"); + +    //write the network +    if(!crossval) { +      if(trace&1) TraceLog("Writing network"); +      if (NULL != p_targetmmf) { +        pl.nnet_.WriteNetwork(p_targetmmf); +      } else { +        MakeHtkFileName(p_trg_mmf_file, p_source_mmf_file, p_trg_mmf_dir, p_trg_mmf_ext); +        pl.nnet_.WriteNetwork(p_trg_mmf_file); +      } +    } + +    //show report +    timer.End(); + +    pl.cout_mutex_.Lock(); + +    std::cout << "===== TNET FINISHED ( " << timer.Val() << "s ) " +              << "[ FPS: " << pl.obj_fun_->GetFrames() / timer.Val()  +              << " RT: " << 1.0f / (pl.obj_fun_->GetFrames() / timer.Val() / 100.0f) +              << " ] =====" << std::endl; + +    //report objective function +    std::cout << "-- " << (crossval?"CV ":"TR ")  +              << pl.obj_fun_->Report(); + +    pl.cout_mutex_.Unlock(); + +  } +  catch (std::exception& rExc) { +    std::cerr << "Exception thrown" << std::endl; +    std::cerr << rExc.what() << std::endl; +    return 1; +  } +  return 0; +} | 
