diff options
Diffstat (limited to 'train.cpp')
-rw-r--r-- | train.cpp | 25 |
1 files changed, 23 insertions, 2 deletions
@@ -115,6 +115,17 @@ int predict(DataProvider &dp) { return 0; } +void getmask(string fname,vector<double> &msk) +{ + ifstream fin; + int fsize; + fin.open(fname.c_str()); + fin>>fsize; + for (int i=0;i<fsize;++i) + fin>>msk[i]; + fin.close(); +} + int main(int argc, char **argv) { el::Configurations defaultConf; defaultConf.setToDefault(); @@ -133,6 +144,7 @@ int main(int argc, char **argv) { ("single,s", "one from a pair") ("pair,p","get pair result") ("fscore,f","get F-score") + ("mask,M", po::value<string>(), "set feature mask") ("model,m", po::value<string>(), "set input model file") ("output,o", po::value<string>(), "set output model/prediction file") ("feature,i", po::value<string>(), "set input feature file") @@ -177,9 +189,18 @@ int main(int argc, char **argv) { else return 0; DataProvider* dp; if (vm["feature"].as<string>().find(".rid") == string::npos) - dp = new FileDP(vm["feature"].as<string>()); + LOG(FATAL)<<"Format not supported"; else - dp = new RidFileDP(vm["feature"].as<string>()); + { + RidFileDP* tmpdp = new RidFileDP(vm["feature"].as<string>()); + if (vm.count("mask")) + { + vector<double> msk; + getmask(vm["mask"].as<string>(),msk); + tmpdp->datmask(msk); + } + dp = tmpdp; + } mainf(*dp); delete dp; return 0; |