diff options
author | Joe Zhao <ztuowen@gmail.com> | 2015-06-11 14:19:16 +0800 |
---|---|---|
committer | Joe Zhao <ztuowen@gmail.com> | 2015-06-11 14:19:16 +0800 |
commit | e80d3cbbdc61c28fffbd75530888aa56f6ac15b1 (patch) | |
tree | 0c6ddf24e3822f3d9755308bc9299262bf82c884 | |
parent | 93b28277476de5d4fbfbf5ac236f9f619d482c46 (diff) | |
download | ranksvm-e80d3cbbdc61c28fffbd75530888aa56f6ac15b1.tar.gz ranksvm-e80d3cbbdc61c28fffbd75530888aa56f6ac15b1.tar.bz2 ranksvm-e80d3cbbdc61c28fffbd75530888aa56f6ac15b1.zip |
fscore
-rw-r--r-- | model/rankaccu.h | 72 | ||||
-rw-r--r-- | train.cpp | 12 |
2 files changed, 84 insertions, 0 deletions
diff --git a/model/rankaccu.h b/model/rankaccu.h index 8cac56c..a5866da 100644 --- a/model/rankaccu.h +++ b/model/rankaccu.h @@ -41,6 +41,78 @@ public: } }; +class Fscore +{ +private: + std::vector<double> pos,neg; + std::vector<double> apos,aneg; + int cpos,cneg; + int f; +public: + void clear() { + cpos=0;cneg=0;pos.clear();neg.clear();apos.clear();aneg.clear(); + } + Fscore(){clear();} + void init(int fsize) { + f=fsize;pos.resize(fsize);neg.resize(fsize);apos.resize(fsize);aneg.resize(fsize); + for (int i=0;i<fsize;++i) apos[i]=aneg[i]=pos[i]=neg[i]=0; + } + void firstPass(RidList &rid,int x){ + Eigen::VectorXd vec = rid.getVec(x); + std::vector<double> *p; + if (rid.getL(x)>0) { + p=' + cpos+=1; + } + else { + p=&aneg; + cneg+=1; + } + for (int i=0;i<f;++i) + (*p)[i]+=vec(i); + } + void calAvg(){ + for (int i=0;i<f;++i) + apos[i]/=cpos; + for (int i=0;i<f;++i) + aneg[i]/=cneg; + } + void secondPass(RidList &rid,int x){ + Eigen::VectorXd vec = rid.getVec(x); + std::vector<double> *p,*a; + if (rid.getL(x)>0) { + p=&pos; + a=' + } + else { + p=&neg; + a=&aneg; + } + for (int i=0;i<f;++i) + (*p)[i]+=(vec(i)-(*a)[i])*(vec(i)-(*a)[i]); + } + std::vector<double> getFscore(){ + std::vector<double> res; + res.reserve(f); + for (int i=0;i<f;++i) + { + double avg; + avg = (cpos*apos[i]+cneg*aneg[i])/(cpos+cneg); + res[i] = (apos[i]-avg)*(apos[i]-avg)+(aneg[i]-avg)*(aneg[i]-avg); + res[i] /= pos[i]/(cpos-1)+neg[i]/(cneg-1); + } + return res; + } + void audit(RidList &rid){ + init(rid.getfSize()); + for (int i=0;i<rid.getSize();++i) + firstPass(rid,i); + calAvg(); + for (int i=0;i<rid.getSize();++i) + secondPass(rid,i); + } +}; + void rank_CMC(RidList &D,const std::vector<double> pred,CMC & cmc); void rank_accu(RidList &D,const std::vector<double> pred); @@ -48,6 +48,8 @@ int predict(DataProvider &dp) { RidList D; vector<double> L; CMC cmc; + Fscore f; + LOG(INFO)<<"Prediction started"; ofstream fout; @@ -82,6 +84,15 @@ int predict(DataProvider &dp) { *ot<<pair[i]<<endl; } else + if (vm.count("fscore")) + { + vector<double> pair; + f.audit(D); + pair=f.getFscore(); + for (int i=0;i<D.getfSize();++i) + *ot<<pair[i]<<endl; + } + else for (int i=0; i<L.size();++i) *ot<<L[i]<<endl; } @@ -121,6 +132,7 @@ int main(int argc, char **argv) { ("debug,d", "show debug messages") ("single,s", "one from a pair") ("pair,p","get pair result") + ("fscore,f","get F-score") ("model,m", po::value<string>(), "set input model file") ("output,o", po::value<string>(), "set output model/prediction file") ("feature,i", po::value<string>(), "set input feature file") |