summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-06-16 11:34:46 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-06-16 11:34:46 +0800
commit44018ad44d7d0d8196f16402bd1fa6c1c10de8ad (patch)
treeb81955eabcaae9d22fee1bd937e7ed4b65a43cdc
parente80d3cbbdc61c28fffbd75530888aa56f6ac15b1 (diff)
downloadranksvm-44018ad44d7d0d8196f16402bd1fa6c1c10de8ad.tar.gz
ranksvm-44018ad44d7d0d8196f16402bd1fa6c1c10de8ad.tar.bz2
ranksvm-44018ad44d7d0d8196f16402bd1fa6c1c10de8ad.zip
fscore
-rw-r--r--split.cpp2
-rw-r--r--tools/fileDataProvider.cpp17
-rw-r--r--tools/fileDataProvider.h10
-rw-r--r--train.cpp25
4 files changed, 47 insertions, 7 deletions
diff --git a/split.cpp b/split.cpp
index e774ea9..ec23af2 100644
--- a/split.cpp
+++ b/split.cpp
@@ -64,7 +64,7 @@ int main(int argc, char **argv)
dp.close();
return 0;
}
-
+ RidFileDP::seed();
RidFileDP dp(vm["input"].as<string>().c_str());
vector<DataEntry*> a;
vector<DataEntry*> b;
diff --git a/tools/fileDataProvider.cpp b/tools/fileDataProvider.cpp
index 9be1132..2b52dc7 100644
--- a/tools/fileDataProvider.cpp
+++ b/tools/fileDataProvider.cpp
@@ -42,6 +42,11 @@ void RidFileDP::readEntries() {
d.clear();
fin >> fsize;
LOG(INFO) << "Feature size:" << fsize;
+ if (!maskinit)
+ {
+ for (int i=0;i<fsize;++i)
+ mask.push_back(1);
+ }
d.setfSize(fsize);
while (!fin.eof()) {
e = new DataEntry;
@@ -52,8 +57,10 @@ void RidFileDP::readEntries() {
}
e->feature.resize(fsize);
e->rank=-1;
+ double tin;
for (int i = 0; i < fsize; ++i) {
- fin >> e->feature(i);
+ fin >> tin;
+ e->feature(i) = tin*mask[i];
}
d.addEntry(e);
}
@@ -124,6 +131,10 @@ int RidFileDP::getpSize() {
return p.size();
};
+void RidFileDP::seed() {
+ gen.seed(time(NULL));
+}
+
void RidFileDP::shuffle(vector<DataEntry*> &dat)
{
DataEntry* e;
@@ -131,6 +142,7 @@ void RidFileDP::shuffle(vector<DataEntry*> &dat)
for (int i=0;i<sz;++i)
{
int pos = (int)(gen()%(sz-i));
+ cout<<pos<<endl;
e=dat[pos];
dat[pos] = dat[sz-i-1];
dat[sz-i-1] = e;
@@ -139,7 +151,6 @@ void RidFileDP::shuffle(vector<DataEntry*> &dat)
void RidFileDP::take(int n,vector<DataEntry*> &a,vector<DataEntry*> &b)
{
- gen.seed(time(NULL));
DataEntry *e;
if (!read)
readEntries();
@@ -148,9 +159,9 @@ void RidFileDP::take(int n,vector<DataEntry*> &a,vector<DataEntry*> &b)
a.clear();
b.clear();
std::vector<DataEntry*> &dat = d.getData();
- shuffle(tmp);
for (int i=0;i<dat.size();++i)
tmp.push_back(dat[i]);
+ shuffle(tmp);
int pos = 0;
string qid;
for (int i=0;i<n;++i)
diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h
index 972a4c5..0ab1948 100644
--- a/tools/fileDataProvider.h
+++ b/tools/fileDataProvider.h
@@ -29,13 +29,20 @@ class RidFileDP:public DataProvider
private:
std::string fname;
std::ifstream fin;
+ std::vector<double> mask;
DataList d;
bool read;
+ bool maskinit;
int pos;
int qid;
public:
- RidFileDP(std::string fn=""):fname(fn){read=false;};
+ RidFileDP(std::string fn=""):fname(fn),read(false),maskinit(false){};
void readEntries();
+ void datmask(std::vector<double> &m){
+ mask.resize(m.size());
+ for (int i=0;i<m.size();++i)
+ mask[i]=m[i];
+ maskinit=true;}
int getfSize() { if(!read) readEntries(); return d.getfSize();};
int getpSize();
void shuffle(std::vector<DataEntry*> &dat);
@@ -52,6 +59,7 @@ public:
for (int i=0;i<dat.size();++i)
rid.push_back(dat[i]);
}
+ static void seed();
};
#endif \ No newline at end of file
diff --git a/train.cpp b/train.cpp
index b6ac730..04e80ce 100644
--- a/train.cpp
+++ b/train.cpp
@@ -115,6 +115,17 @@ int predict(DataProvider &dp) {
return 0;
}
+void getmask(string fname,vector<double> &msk)
+{
+ ifstream fin;
+ int fsize;
+ fin.open(fname.c_str());
+ fin>>fsize;
+ for (int i=0;i<fsize;++i)
+ fin>>msk[i];
+ fin.close();
+}
+
int main(int argc, char **argv) {
el::Configurations defaultConf;
defaultConf.setToDefault();
@@ -133,6 +144,7 @@ int main(int argc, char **argv) {
("single,s", "one from a pair")
("pair,p","get pair result")
("fscore,f","get F-score")
+ ("mask,M", po::value<string>(), "set feature mask")
("model,m", po::value<string>(), "set input model file")
("output,o", po::value<string>(), "set output model/prediction file")
("feature,i", po::value<string>(), "set input feature file")
@@ -177,9 +189,18 @@ int main(int argc, char **argv) {
else return 0;
DataProvider* dp;
if (vm["feature"].as<string>().find(".rid") == string::npos)
- dp = new FileDP(vm["feature"].as<string>());
+ LOG(FATAL)<<"Format not supported";
else
- dp = new RidFileDP(vm["feature"].as<string>());
+ {
+ RidFileDP* tmpdp = new RidFileDP(vm["feature"].as<string>());
+ if (vm.count("mask"))
+ {
+ vector<double> msk;
+ getmask(vm["mask"].as<string>(),msk);
+ tmpdp->datmask(msk);
+ }
+ dp = tmpdp;
+ }
mainf(*dp);
delete dp;
return 0;