summaryrefslogtreecommitdiff
path: root/train.cpp
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-06-16 11:34:46 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-06-16 11:34:46 +0800
commit44018ad44d7d0d8196f16402bd1fa6c1c10de8ad (patch)
treeb81955eabcaae9d22fee1bd937e7ed4b65a43cdc /train.cpp
parente80d3cbbdc61c28fffbd75530888aa56f6ac15b1 (diff)
downloadranksvm-44018ad44d7d0d8196f16402bd1fa6c1c10de8ad.tar.gz
ranksvm-44018ad44d7d0d8196f16402bd1fa6c1c10de8ad.tar.bz2
ranksvm-44018ad44d7d0d8196f16402bd1fa6c1c10de8ad.zip
fscore
Diffstat (limited to 'train.cpp')
-rw-r--r--train.cpp25
1 files changed, 23 insertions, 2 deletions
diff --git a/train.cpp b/train.cpp
index b6ac730..04e80ce 100644
--- a/train.cpp
+++ b/train.cpp
@@ -115,6 +115,17 @@ int predict(DataProvider &dp) {
return 0;
}
+void getmask(string fname,vector<double> &msk)
+{
+ ifstream fin;
+ int fsize;
+ fin.open(fname.c_str());
+ fin>>fsize;
+ for (int i=0;i<fsize;++i)
+ fin>>msk[i];
+ fin.close();
+}
+
int main(int argc, char **argv) {
el::Configurations defaultConf;
defaultConf.setToDefault();
@@ -133,6 +144,7 @@ int main(int argc, char **argv) {
("single,s", "one from a pair")
("pair,p","get pair result")
("fscore,f","get F-score")
+ ("mask,M", po::value<string>(), "set feature mask")
("model,m", po::value<string>(), "set input model file")
("output,o", po::value<string>(), "set output model/prediction file")
("feature,i", po::value<string>(), "set input feature file")
@@ -177,9 +189,18 @@ int main(int argc, char **argv) {
else return 0;
DataProvider* dp;
if (vm["feature"].as<string>().find(".rid") == string::npos)
- dp = new FileDP(vm["feature"].as<string>());
+ LOG(FATAL)<<"Format not supported";
else
- dp = new RidFileDP(vm["feature"].as<string>());
+ {
+ RidFileDP* tmpdp = new RidFileDP(vm["feature"].as<string>());
+ if (vm.count("mask"))
+ {
+ vector<double> msk;
+ getmask(vm["mask"].as<string>(),msk);
+ tmpdp->datmask(msk);
+ }
+ dp = tmpdp;
+ }
mainf(*dp);
delete dp;
return 0;