summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-06-11 14:19:16 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-06-11 14:19:16 +0800
commite80d3cbbdc61c28fffbd75530888aa56f6ac15b1 (patch)
tree0c6ddf24e3822f3d9755308bc9299262bf82c884
parent93b28277476de5d4fbfbf5ac236f9f619d482c46 (diff)
downloadranksvm-e80d3cbbdc61c28fffbd75530888aa56f6ac15b1.tar.gz
ranksvm-e80d3cbbdc61c28fffbd75530888aa56f6ac15b1.tar.bz2
ranksvm-e80d3cbbdc61c28fffbd75530888aa56f6ac15b1.zip
fscore
-rw-r--r--model/rankaccu.h72
-rw-r--r--train.cpp12
2 files changed, 84 insertions, 0 deletions
diff --git a/model/rankaccu.h b/model/rankaccu.h
index 8cac56c..a5866da 100644
--- a/model/rankaccu.h
+++ b/model/rankaccu.h
@@ -41,6 +41,78 @@ public:
}
};
+// Per-feature F-score of a two-class RidList, accumulated in two passes over the samples.
+class Fscore
+{
+private:
+    std::vector<double> pos, neg;   // per-feature sums of squared deviations from the class mean
+    std::vector<double> apos, aneg; // per-feature sums; calAvg() turns them into class means
+    int cpos, cneg;                 // samples seen with getL()>0 / with getL()<=0
+    int f;                          // number of features the accumulators are sized for
+public:
+    // Drop all accumulated state; init() must be called again before any pass.
+    void clear() {
+        f = 0; cpos = 0; cneg = 0;  // f is reset too so it never disagrees with the empty vectors
+        pos.clear(); neg.clear(); apos.clear(); aneg.clear();
+    }
+    Fscore() { clear(); }
+    // Size every accumulator for fsize features and zero it, sample counts included.
+    void init(int fsize) {
+        f = fsize; cpos = 0; cneg = 0;  // counts restart with the sums so a re-run is not skewed
+        pos.assign(fsize, 0.0); neg.assign(fsize, 0.0);
+        apos.assign(fsize, 0.0); aneg.assign(fsize, 0.0);
+    }
+    // Pass 1: add sample x to the feature sums of its class (positive iff getL(x) > 0).
+    void firstPass(RidList &rid, int x) {
+        Eigen::VectorXd vec = rid.getVec(x);
+        const bool positive = rid.getL(x) > 0;
+        std::vector<double> &sum = positive ? apos : aneg;
+        ++(positive ? cpos : cneg);
+        for (int i = 0; i < f; ++i)
+            sum[i] += vec(i);
+    }
+    // Turn the pass-1 sums into class means. No guard against an empty class (division by zero).
+    void calAvg() {
+        for (int i = 0; i < f; ++i) {
+            apos[i] /= cpos;
+            aneg[i] /= cneg;
+        }
+    }
+    // Pass 2: add sample x's squared deviation from its class mean; requires calAvg() first.
+    void secondPass(RidList &rid, int x) {
+        Eigen::VectorXd vec = rid.getVec(x);
+        const bool positive = rid.getL(x) > 0;
+        std::vector<double> &sq = positive ? pos : neg;
+        const std::vector<double> &mean = positive ? apos : aneg;
+        for (int i = 0; i < f; ++i) {
+            const double d = vec(i) - mean[i];
+            sq[i] += d * d;
+        }
+    }
+    // Returns one score per feature, valid after audit():
+    //   ((apos-avg)^2 + (aneg-avg)^2) / (pos/(cpos-1) + neg/(cneg-1)), avg = mean over both classes.
+    std::vector<double> getFscore() {
+        // Sized, not reserve()d: reserve() leaves size()==0, which made res[i] out of bounds.
+        std::vector<double> res(f);
+        for (int i = 0; i < f; ++i) {
+            const double avg = (cpos * apos[i] + cneg * aneg[i]) / (cpos + cneg);
+            const double between = (apos[i] - avg) * (apos[i] - avg) + (aneg[i] - avg) * (aneg[i] - avg);
+            // No guard against cpos or cneg being 1 (division by zero in the denominator).
+            res[i] = between / (pos[i] / (cpos - 1) + neg[i] / (cneg - 1));
+        }
+        return res;
+    }
+    // Run both passes over every sample in rid; getFscore() may be called afterwards.
+    void audit(RidList &rid) {
+        init(rid.getfSize());
+        for (int i = 0; i < rid.getSize(); ++i)
+            firstPass(rid, i);
+        calAvg();
+        for (int i = 0; i < rid.getSize(); ++i)
+            secondPass(rid, i);
+    }
+};
+
void rank_CMC(RidList &D,const std::vector<double> pred,CMC & cmc);
void rank_accu(RidList &D,const std::vector<double> pred);
diff --git a/train.cpp b/train.cpp
index dd1b1e6..b6ac730 100644
--- a/train.cpp
+++ b/train.cpp
@@ -48,6 +48,8 @@ int predict(DataProvider &dp) {
RidList D;
vector<double> L;
CMC cmc;
+ Fscore f;
+
LOG(INFO)<<"Prediction started";
ofstream fout;
@@ -82,6 +84,15 @@ int predict(DataProvider &dp) {
*ot<<pair[i]<<endl;
}
else
+ if (vm.count("fscore"))
+ {
+ vector<double> pair;
+ f.audit(D);
+ pair=f.getFscore();
+ for (int i=0;i<D.getfSize();++i)
+ *ot<<pair[i]<<endl;
+ }
+ else
for (int i=0; i<L.size();++i)
*ot<<L[i]<<endl;
}
@@ -121,6 +132,7 @@ int main(int argc, char **argv) {
("debug,d", "show debug messages")
("single,s", "one from a pair")
("pair,p","get pair result")
+ ("fscore,f","get F-score")
("model,m", po::value<string>(), "set input model file")
("output,o", po::value<string>(), "set output model/prediction file")
("feature,i", po::value<string>(), "set input feature file")