diff options
author | Joe Zhao <ztuowen@gmail.com> | 2015-06-16 11:34:46 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2015-06-16 11:34:46 +0800 |
commit | 44018ad44d7d0d8196f16402bd1fa6c1c10de8ad (patch) | |
tree | b81955eabcaae9d22fee1bd937e7ed4b65a43cdc | |
parent | e80d3cbbdc61c28fffbd75530888aa56f6ac15b1 (diff) | |
download | ranksvm-44018ad44d7d0d8196f16402bd1fa6c1c10de8ad.tar.gz ranksvm-44018ad44d7d0d8196f16402bd1fa6c1c10de8ad.tar.bz2 ranksvm-44018ad44d7d0d8196f16402bd1fa6c1c10de8ad.zip |
fscore
-rw-r--r-- | split.cpp | 2 | ||||
-rw-r--r-- | tools/fileDataProvider.cpp | 17 | ||||
-rw-r--r-- | tools/fileDataProvider.h | 10 | ||||
-rw-r--r-- | train.cpp | 25 |
4 files changed, 47 insertions, 7 deletions
@@ -64,7 +64,7 @@ int main(int argc, char **argv) dp.close(); return 0; } - + RidFileDP::seed(); RidFileDP dp(vm["input"].as<string>().c_str()); vector<DataEntry*> a; vector<DataEntry*> b; diff --git a/tools/fileDataProvider.cpp b/tools/fileDataProvider.cpp index 9be1132..2b52dc7 100644 --- a/tools/fileDataProvider.cpp +++ b/tools/fileDataProvider.cpp @@ -42,6 +42,11 @@ void RidFileDP::readEntries() { d.clear(); fin >> fsize; LOG(INFO) << "Feature size:" << fsize; + if (!maskinit) + { + for (int i=0;i<fsize;++i) + mask.push_back(1); + } d.setfSize(fsize); while (!fin.eof()) { e = new DataEntry; @@ -52,8 +57,10 @@ void RidFileDP::readEntries() { } e->feature.resize(fsize); e->rank=-1; + double tin; for (int i = 0; i < fsize; ++i) { - fin >> e->feature(i); + fin >> tin; + e->feature(i) = tin*mask[i]; } d.addEntry(e); } @@ -124,6 +131,10 @@ int RidFileDP::getpSize() { return p.size(); }; +void RidFileDP::seed() { + gen.seed(time(NULL)); +} + void RidFileDP::shuffle(vector<DataEntry*> &dat) { DataEntry* e; @@ -131,6 +142,7 @@ void RidFileDP::shuffle(vector<DataEntry*> &dat) for (int i=0;i<sz;++i) { int pos = (int)(gen()%(sz-i)); + cout<<pos<<endl; e=dat[pos]; dat[pos] = dat[sz-i-1]; dat[sz-i-1] = e; @@ -139,7 +151,6 @@ void RidFileDP::shuffle(vector<DataEntry*> &dat) void RidFileDP::take(int n,vector<DataEntry*> &a,vector<DataEntry*> &b) { - gen.seed(time(NULL)); DataEntry *e; if (!read) readEntries(); @@ -148,9 +159,9 @@ void RidFileDP::take(int n,vector<DataEntry*> &a,vector<DataEntry*> &b) a.clear(); b.clear(); std::vector<DataEntry*> &dat = d.getData(); - shuffle(tmp); for (int i=0;i<dat.size();++i) tmp.push_back(dat[i]); + shuffle(tmp); int pos = 0; string qid; for (int i=0;i<n;++i) diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index 972a4c5..0ab1948 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -29,13 +29,20 @@ class RidFileDP:public DataProvider private: std::string fname; std::ifstream fin; + std::vector<double> mask; DataList d; bool read; + bool maskinit; int pos; int qid; public: - RidFileDP(std::string fn=""):fname(fn){read=false;}; + RidFileDP(std::string fn=""):fname(fn),read(false),maskinit(false){}; void readEntries(); + void datmask(std::vector<double> &m){ + mask.resize(m.size()); + for (int i=0;i<m.size();++i) + mask[i]=m[i]; + maskinit=true;} int getfSize() { if(!read) readEntries(); return d.getfSize();}; int getpSize(); void shuffle(std::vector<DataEntry*> &dat); @@ -52,6 +59,7 @@ public: for (int i=0;i<dat.size();++i) rid.push_back(dat[i]); } + static void seed(); }; #endif
\ No newline at end of file @@ -115,6 +115,17 @@ int predict(DataProvider &dp) { return 0; } +void getmask(string fname,vector<double> &msk) +{ + ifstream fin; + int fsize; + fin.open(fname.c_str()); + fin>>fsize; + for (int i=0;i<fsize;++i) + fin>>msk[i]; + fin.close(); +} + int main(int argc, char **argv) { el::Configurations defaultConf; defaultConf.setToDefault(); @@ -133,6 +144,7 @@ int main(int argc, char **argv) { ("single,s", "one from a pair") ("pair,p","get pair result") ("fscore,f","get F-score") + ("mask,M", po::value<string>(), "set feature mask") ("model,m", po::value<string>(), "set input model file") ("output,o", po::value<string>(), "set output model/prediction file") ("feature,i", po::value<string>(), "set input feature file") @@ -177,9 +189,18 @@ int main(int argc, char **argv) { else return 0; DataProvider* dp; if (vm["feature"].as<string>().find(".rid") == string::npos) - dp = new FileDP(vm["feature"].as<string>()); + LOG(FATAL)<<"Format not supported"; else - dp = new RidFileDP(vm["feature"].as<string>()); + { + RidFileDP* tmpdp = new RidFileDP(vm["feature"].as<string>()); + if (vm.count("mask")) + { + vector<double> msk; + getmask(vm["mask"].as<string>(),msk); + tmpdp->datmask(msk); + } + dp = tmpdp; + } mainf(*dp); delete dp; return 0; |