summaryrefslogtreecommitdiff
path: root/src/TRbmCu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/TRbmCu.cc')
-rw-r--r--src/TRbmCu.cc396
1 files changed, 396 insertions, 0 deletions
diff --git a/src/TRbmCu.cc b/src/TRbmCu.cc
new file mode 100644
index 0000000..b2d5ea8
--- /dev/null
+++ b/src/TRbmCu.cc
@@ -0,0 +1,396 @@
+
+/***************************************************************************
+ * copyright : (C) 2011 by Karel Vesely,UPGM,FIT,VUT,Brno *
+ * email : iveselyk@fit.vutbr.cz *
+ ***************************************************************************
+ * *
+ * This program is free software; you can redistribute it and/or modify *
+ * it under the terms of the APACHE License as published by the *
+ * Apache Software Foundation; either version 2.0 of the License, *
+ * or (at your option) any later version. *
+ * *
+ ***************************************************************************/
+
+#define SVN_DATE "$Date: 2011-12-08 11:59:03 +0100 (Thu, 08 Dec 2011) $"
+#define SVN_AUTHOR "$Author: iveselyk $"
+#define SVN_REVISION "$Revision: 94 $"
+#define SVN_ID "$Id: TRbmCu.cc 94 2011-12-08 10:59:03Z iveselyk $"
+
+#define MODULE_VERSION "1.0.0 "__TIME__" "__DATE__" "SVN_ID
+
+
+
+
+
+/*** TNetLib includes */
+#include "Error.h"
+#include "Timer.h"
+#include "Features.h"
+#include "Common.h"
+#include "UserInterface.h"
+#include "Timer.h"
+
+/*** TNet includes */
+#include "cuNetwork.h"
+#include "cuRbm.h"
+#include "cuCache.h"
+#include "cuObjectiveFunction.h"
+#include "curand.h"
+
+/*** STL includes */
+#include <iostream>
+#include <sstream>
+#include <numeric>
+
+
+
+
+//////////////////////////////////////////////////////////////////////
+// DEFINES
+//
+
+#define SNAME "TRBM"
+
+using namespace TNet;
+
+void usage(const char* progname)
+{
+ const char *tchrptr;
+ if ((tchrptr = strrchr(progname, '\\')) != NULL) progname = tchrptr+1;
+ if ((tchrptr = strrchr(progname, '/')) != NULL) progname = tchrptr+1;
+ fprintf(stderr,
+"\n%s version " MODULE_VERSION "\n"
+"\nUSAGE: %s [options] DataFiles...\n\n"
+" Option Default\n\n"
+" -n f Set learning rate to f 0.06\n"
+" -A Print command line arguments Off\n"
+" -C cf Set config file to cf Default\n"
+" -D Display configuration variables Off\n"
+" -H mmf Load NN macro file \n"
+" -S file Set script file None\n"
+" -T N Set trace flags to N 0\n"
+" -V Print version information Off\n"
+"\n"
+"FEATURETRANSFORM LEARNINGRATE MOMENTUM NATURALREADORDER PRINTCONFIG PRINTVERSION SCRIPT SOURCEMMF TARGETMMF TRACE WEIGHTCOST\n"
+"\n"
+"STARTFRMEXT ENDFRMEXT CMEANDIR CMEANMASK VARSCALEDIR VARSCALEMASK VARSCALEFN TARGETKIND DERIVWINDOWS DELTAWINDOW ACCWINDOW THIRDWINDOW\n"
+"\n"
+" %s is Copyright (C) 2010-2011 Karel Vesely\n"
+" licensed under the APACHE License, version 2.0\n"
+" Bug reports, feedback, etc, to: iveselyk@fit.vutbr.cz\n"
+"\n", progname, progname, progname);
+ exit(-1);
+}
+
+
+
+///////////////////////////////////////////////////////////////////////
+// MAIN FUNCTION
+//
+
+
+int main(int argc, char *argv[]) try
+{
+ const char* p_option_string =
+ " -n r LEARNINGRATE"
+ " -D n PRINTCONFIG=TRUE"
+ " -H l SOURCEMMF"
+ " -S l SCRIPT"
+ " -T r TRACE"
+ " -V n PRINTVERSION=TRUE"
+ ;
+
+
+ UserInterface ui;
+ FeatureRepository feature_repo;
+ CuNetwork network;
+ CuNetwork transform_network;
+ CuMeanSquareError mse;
+ Timer timer;
+ Timer timer_frontend;
+ double time_frontend = 0.0;
+
+
+ const char* p_script;
+ BaseFloat learning_rate;
+ BaseFloat momentum;
+ BaseFloat weightcost;
+
+ const char* p_source_mmf_file;
+ const char* p_input_transform;
+
+ const char* p_targetmmf;
+
+ int bunch_size;
+ int cache_size;
+ bool randomize;
+ long int seed;
+
+ int trace;
+
+ // variables for feature repository
+ bool swap_features;
+ int target_kind;
+ int deriv_order;
+ int* p_deriv_win_lenghts;
+ int start_frm_ext;
+ int end_frm_ext;
+ char* cmn_path;
+ char* cmn_file;
+ const char* cmn_mask;
+ char* cvn_path;
+ char* cvn_file;
+ const char* cvn_mask;
+ const char* cvg_file;
+
+
+ // OPTION PARSING ..........................................................
+ // use the STK option parsing
+ if (argc == 1) { usage(argv[0]); return 1; }
+ int args_parsed = ui.ParseOptions(argc, argv, p_option_string, SNAME);
+
+
+ // OPTION RETRIEVAL ........................................................
+ // extract the feature parameters
+ swap_features = !ui.GetBool(SNAME":NATURALREADORDER", TNet::IsBigEndian());
+
+ target_kind = ui.GetFeatureParams(&deriv_order, &p_deriv_win_lenghts,
+ &start_frm_ext, &end_frm_ext, &cmn_path, &cmn_file, &cmn_mask,
+ &cvn_path, &cvn_file, &cvn_mask, &cvg_file, SNAME":", 0);
+
+
+ // extract other parameters
+ p_source_mmf_file = ui.GetStr(SNAME":SOURCEMMF", NULL);
+ p_input_transform = ui.GetStr(SNAME":FEATURETRANSFORM", NULL);
+
+ p_targetmmf = ui.GetStr(SNAME":TARGETMMF", NULL);
+
+ p_script = ui.GetStr(SNAME":SCRIPT", NULL);
+ learning_rate = ui.GetFlt(SNAME":LEARNINGRATE" , 0.10f);
+ momentum = ui.GetFlt(SNAME":MOMENTUM" , 0.50f);
+ weightcost = ui.GetFlt(SNAME":WEIGHTCOST" , 0.0002f);
+
+
+ bunch_size = ui.GetInt(SNAME":BUNCHSIZE", 256);
+ cache_size = ui.GetInt(SNAME":CACHESIZE", 12800);
+ randomize = ui.GetBool(SNAME":RANDOMIZE", true);
+
+ //cannot get long int
+ seed = ui.GetInt(SNAME":SEED", 0);
+
+ trace = ui.GetInt(SNAME":TRACE", 0);
+ if(trace&4) { CuDevice::Instantiate().Verbose(true); }
+
+
+
+
+ // process the parameters
+ if(ui.GetBool(SNAME":PRINTCONFIG", false)) {
+ std::cout << std::endl;
+ ui.PrintConfig(std::cout);
+ std::cout << std::endl;
+ }
+ if(ui.GetBool(SNAME":PRINTVERSION", false)) {
+ std::cout << std::endl;
+ std::cout << "======= TRbmCu v"MODULE_VERSION" xvesel39 =======" << std::endl;
+ std::cout << std::endl;
+ }
+ ui.CheckCommandLineParamUse();
+
+
+ // the rest of the parameters are the feature files
+ for (; args_parsed < argc; args_parsed++) {
+ feature_repo.AddFile(argv[args_parsed]);
+ }
+
+ //**************************************************************************
+ //**************************************************************************
+ // OPTION PARSING DONE .....................................................
+
+
+ //read the input transform network
+ if(NULL != p_input_transform) {
+ if(trace&1) TraceLog(std::string("Reading input transform network: ")+p_input_transform);
+ transform_network.ReadNetwork(p_input_transform);
+ }
+
+
+ //read the neural network
+ if(NULL != p_source_mmf_file) {
+ if(trace&1) TraceLog(std::string("Reading network: ")+p_source_mmf_file);
+ network.ReadNetwork(p_source_mmf_file);
+ } else {
+ Error("Source MMF must be specified [-H]");
+ }
+ //extract the RBM from the network
+ if(network.Layers() != 1) {
+ Error(std::string("Number of layers must be 1")+p_source_mmf_file);
+ }
+ if(network.Layer(0).GetType() != CuComponent::RBM && network.Layer(0).GetType() != CuComponent::RBM_SPARSE) {
+ Error(std::string("Layer must be RBM")+p_source_mmf_file);
+ }
+ CuRbmBase& rbm = dynamic_cast<CuRbmBase&>(network.Layer(0));
+
+ // initialize the feature repository
+ feature_repo.Init(
+ swap_features, start_frm_ext, end_frm_ext, target_kind,
+ deriv_order, p_deriv_win_lenghts,
+ cmn_path, cmn_mask, cvn_path, cvn_mask, cvg_file
+ );
+ if(NULL != p_script) {
+ feature_repo.AddFileList(p_script);
+ } else {
+ Warning("WARNING: The script file is missing [-S]");
+ }
+ feature_repo.Trace(trace);
+
+ //set the learnrate, momentum, weightcost
+ rbm.LearnRate(learning_rate);
+ rbm.Momentum(momentum);
+ rbm.Weightcost(weightcost);
+
+ //seed the random number generator
+ if(seed == 0) {
+ struct timeval tv;
+ if (gettimeofday(&tv, 0) == -1) {
+ assert(0 && "gettimeofday does not work.");
+ exit(-1);
+ }
+ seed = (int)(tv.tv_sec) + (int)tv.tv_usec;
+ }
+ srand48(seed);
+
+ //initialize the matrix random number generator
+ CuRand<BaseFloat> cu_rand(bunch_size,rbm.GetNOutputs());
+
+
+
+ //**********************************************************************
+ //**********************************************************************
+ // INITIALIZATION DONE .................................................
+ //
+ // Start training
+ timer.Start();
+ std::cout << "===== TRbmCu TRAINING STARTED =====" << std::endl;
+ std::cout << "learning rate: " << learning_rate
+ << " momentum: " << momentum
+ << " weightcost: " << weightcost
+ << std::endl;
+ std::cout << "Using seed: " << seed << "\n";
+
+
+ CuCache cache;
+ cache.Init(cache_size,bunch_size);
+ cache.Trace(trace);
+ feature_repo.Rewind();
+
+ //**********************************************************************
+ //**********************************************************************
+ // MAIN LOOP
+ //
+ CuMatrix<BaseFloat> pos_vis, pos_hid, neg_vis, neg_hid;
+ CuMatrix<BaseFloat> dummy_labs, dummy_err;
+ while(!feature_repo.EndOfList()) {
+ timer_frontend.Start();
+ //fill cache
+ while(!cache.Full() && !feature_repo.EndOfList()) {
+ Matrix<BaseFloat> feats_host;
+ CuMatrix<BaseFloat> feats_original;
+ CuMatrix<BaseFloat> feats_expanded;
+
+ //read feats, perfrom feature transform
+ feature_repo.ReadFullMatrix(feats_host);
+ feats_original.CopyFrom(feats_host);
+ transform_network.Propagate(feats_original,feats_expanded);
+
+ //trim the start/end context
+ int rows = feats_expanded.Rows()-start_frm_ext-end_frm_ext;
+ CuMatrix<BaseFloat> feats_trim(rows,feats_expanded.Cols());
+ feats_trim.CopyRows(rows,start_frm_ext,feats_expanded,0);
+
+ //fake the labels!!!
+ CuMatrix<BaseFloat> labs_cu(feats_trim.Rows(),1);
+
+ //add to cache
+ cache.AddData(feats_trim,labs_cu);
+
+ feature_repo.MoveNext();
+ }
+ timer_frontend.End(); time_frontend += timer_frontend.Val();
+
+ if(randomize) {
+ //randomize the cache
+ cache.Randomize();
+ }
+
+ while(!cache.Empty()) {
+ //get training data
+ cache.GetBunch(pos_vis,dummy_labs);
+
+ //forward pass
+ rbm.Propagate(pos_vis,pos_hid);
+
+ //change the hidden values so we can generate negative example
+ if(rbm.HidType() == CuRbmBase::BERNOULLI) {
+ cu_rand.BinarizeProbs(pos_hid,neg_hid);
+ } else {
+ neg_hid.CopyFrom(pos_hid);
+ cu_rand.AddGaussNoise(neg_hid);
+ }
+
+ //reconstruct pass
+ rbm.Reconstruct(neg_hid,neg_vis);
+
+ //forward pass
+ rbm.Propagate(neg_vis, neg_hid);
+
+ //update the weioghts
+ rbm.RbmUpdate(pos_vis, pos_hid, neg_vis, neg_hid);
+
+ //evalueate mean square error
+ mse.Evaluate(neg_vis,pos_vis,dummy_err);
+
+ if(trace&2) std::cout << "." << std::flush;
+ }
+ //check the NaN/inf
+ pos_hid.CheckData();
+ }
+
+
+
+ //**********************************************************************
+ //**********************************************************************
+ // TRAINING FINISHED .................................................
+ //
+ // Let's store the network, report the log
+
+ if(trace&1) TraceLog("Training finished");
+
+ //write the network
+ if (NULL != p_targetmmf) {
+ if(trace&1) TraceLog(std::string("Writing network: ")+p_targetmmf);
+ network.WriteNetwork(p_targetmmf);
+ } else {
+ Error("missing argument --TARGETMMF");
+ }
+
+ timer.End();
+ std::cout << "===== TRbmCu FINISHED ( " << timer.Val() << "s ) "
+ << "[FPS:" << mse.GetFrames() / timer.Val()
+ << ",RT:" << 1.0f / (mse.GetFrames() / timer.Val() / 100.0f)
+ << "] =====" << std::endl;
+
+ //report objective function (accuracy, frame counts...)
+ std::cout << mse.Report();
+
+ if(trace &4) {
+ std::cout << "\n== PROFILE ==\nT-fe: " << time_frontend << std::endl;
+ }
+
+ return 0; ///finish OK
+
+} catch (std::exception& rExc) {
+ std::cerr << "Exception thrown" << std::endl;
+ std::cerr << rExc.what() << std::endl;
+ return 1;
+}