summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2015-06-02 18:10:33 +0800
committerJoe Zhao <ztuowen@gmail.com>2015-06-02 18:10:33 +0800
commitb44da2a2ab11425961014a39994484c92626ce58 (patch)
treec3b3d3fded5553bccd245a16a09e8323635624ff
parentb29d766bb0c3d2e5839164ef6cd316b2e00fba62 (diff)
downloadranksvm-b44da2a2ab11425961014a39994484c92626ce58.tar.gz
ranksvm-b44da2a2ab11425961014a39994484c92626ce58.tar.bz2
ranksvm-b44da2a2ab11425961014a39994484c92626ce58.zip
misc
-rw-r--r--model/rankaccu.cpp19
-rw-r--r--model/rankaccu.h2
-rw-r--r--train.cpp13
3 files changed, 33 insertions, 1 deletions
diff --git a/model/rankaccu.cpp b/model/rankaccu.cpp
index 8404abf..0f55e26 100644
--- a/model/rankaccu.cpp
+++ b/model/rankaccu.cpp
@@ -128,4 +128,23 @@ void rank_CMC(RidList &D,const std::vector<double> pred,CMC & cmc) {
break; // account only for the first match;
}
}
+}
+
+void rank_pair(RidList &D,const vector<double> pred,vector<double> &pair)
+{
+ int n =D.getSize(),q=D.getqSize();
+ pair.clear();
+ for (int i=0;i<n;i+=q)
+ {
+ int corr=0;
+ for (int j=0;j<q;++j)
+ if (D.getL(i+j)>0) {
+ corr = j;
+ break;
+ }
+
+ for (int j=0;j<q;++j)
+ if (j!=corr)
+ pair.push_back(pred[i+corr]-pred[i+j]);
+ }
} \ No newline at end of file
diff --git a/model/rankaccu.h b/model/rankaccu.h
index e8da882..8cac56c 100644
--- a/model/rankaccu.h
+++ b/model/rankaccu.h
@@ -45,4 +45,6 @@ void rank_CMC(RidList &D,const std::vector<double> pred,CMC & cmc);
void rank_accu(RidList &D,const std::vector<double> pred);
+void rank_pair(RidList &D,const std::vector<double> pred,std::vector<double> &pair);
+
#endif //RANKSVM_RANKACCU_H
diff --git a/train.cpp b/train.cpp
index fda06cb..dd1b1e6 100644
--- a/train.cpp
+++ b/train.cpp
@@ -72,9 +72,19 @@ int predict(DataProvider &dp) {
rank_CMC(D,L,cmc);
}
- if (vm.count("output") && vm.count("predict"))
+ if (vm.count("predict"))
+ {
+ if (vm.count("pair"))
+ {
+ vector<double> pair;
+ rank_pair(D,L,pair);
+ for (int i=0;i<pair.size();++i)
+ *ot<<pair[i]<<endl;
+ }
+ else
for (int i=0; i<L.size();++i)
*ot<<L[i]<<endl;
+ }
LOG(INFO)<<"Finished";
if (vm.count("cmc"))
@@ -110,6 +120,7 @@ int main(int argc, char **argv) {
("cmc,C", "enable cmc auditing")
("debug,d", "show debug messages")
("single,s", "one from a pair")
+ ("pair,p","get pair result")
("model,m", po::value<string>(), "set input model file")
("output,o", po::value<string>(), "set output model/prediction file")
("feature,i", po::value<string>(), "set input feature file")