1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
|
#include <fstream>
#include <iostream>
#include <list>
#include <memory>
#include <string>
#include <vector>
#include <Eigen/Dense>
#include <boost/program_options.hpp>
#include "tools/easylogging++.h"
#include "model/ranksvmtn.h"
#include "tools/fileDataProvider.h"
#include "model/rankaccu.h"
INITIALIZE_EASYLOGGINGPP
using namespace Eigen;
using namespace std;
namespace po = boost::program_options;
po::variables_map vm;
typedef int (*mainFunc)(DataProvider &dp);
/**
 * Train a ranking SVM on the complete data set exposed by @p dp and save it.
 *
 * Options read from the global map `vm`:
 *  - "model"  : file handed to RSVM::loadModel to obtain the model object
 *               that is then trained (required).
 *  - "output" : file the trained model is written to (required).
 *
 * @param dp  data source; opened here and closed before the model is saved.
 * @return 0 on success, 1 if a required option is missing or the model
 *         could not be loaded.
 */
int train(DataProvider &dp) {
    // Validate required options before doing any work: vm[...].as<string>()
    // on an absent option throws boost::bad_any_cast. For "output" that used
    // to happen only after the entire training run had completed.
    if (!vm.count("model") || !vm.count("output")) {
        cerr << "train: both --model and --output must be given" << endl;
        return 1;
    }
    // Own the model so it is released even if training or saving throws.
    std::unique_ptr<RSVM> rsvm(RSVM::loadModel(vm["model"].as<string>()));
    if (!rsvm) {
        cerr << "train: failed to load model " << vm["model"].as<string>() << endl;
        return 1;
    }
    dp.open();
    DataList D;
    LOG(INFO)<<"Training started";
    dp.getAllDataSet(D);
    LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
    rsvm->train(D);
    // The scores are discarded. The call is kept as in the original flow.
    // NOTE(review): presumably run for its side effects (e.g. logging
    // training-set scores) -- confirm before removing.
    vector<double> L;
    rsvm->predict(D,L);
    LOG(INFO)<<"Training finished,saving model";
    dp.close();
    rsvm->saveModel(vm["output"].as<string>().c_str());
    return 0;
}
/**
 * Score every data set delivered by @p dp with a previously trained model.
 *
 * Options read from the global map `vm`:
 *  - "model"    : model file to load (required).
 *  - "output"   : if given, results are written to this file, otherwise to stdout.
 *  - "validate" : run rank_accu on each data set's predictions.
 *  - "cmc"      : together with "validate", accumulate rank_CMC statistics and
 *                 append a "CMC" table at the end. Without "validate" nothing is
 *                 accumulated, so no table is printed.
 * Per-entry scores (one per line) are written unless "validate" is given
 * without "output".
 *
 * @param dp  data source; opened here and closed before returning.
 * @return 0 on success, 1 if the model option is missing, the model could
 *         not be loaded, or the output file could not be opened.
 */
int predict(DataProvider &dp) {
    if (!vm.count("model")) {
        cerr << "predict: --model must be given" << endl;
        return 1;
    }
    const bool validate = vm.count("validate") > 0;
    const bool hasOutput = vm.count("output") > 0;
    // rank_CMC only runs while validating. Printing the table otherwise would
    // report statistics over zero queries.
    const bool reportCmc = validate && vm.count("cmc") > 0;

    // Own the model so it is released on every exit path.
    std::unique_ptr<RSVM> rsvm(RSVM::loadModel(vm["model"].as<string>().c_str()));
    if (!rsvm) {
        cerr << "predict: failed to load model " << vm["model"].as<string>() << endl;
        return 1;
    }

    ofstream fout;
    if (hasOutput) {
        fout.open(vm["output"].as<string>().c_str());
        // Previously an unopenable file silently swallowed all results.
        if (!fout) {
            cerr << "predict: cannot open output file " << vm["output"].as<string>() << endl;
            return 1;
        }
    }
    ostream &ot = hasOutput ? static_cast<ostream &>(fout) : cout;

    dp.open();
    DataList D;
    vector<double> L;
    CMC cmc;
    LOG(INFO)<<"Prediction started";
    while (!dp.EOFile())
    {
        dp.getDataSet(D);
        LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
        rsvm->predict(D,L);
        if (validate)
        {
            rank_accu(D,L);
            if (reportCmc)
                rank_CMC(D,L,cmc);
        }
        if (hasOutput || !validate)
            for (double score : L)
                ot << score << '\n';   // '\n' avoids a flush per line
    }
    LOG(INFO)<<"Finished";
    if (reportCmc)
    {
        LOG(INFO)<< "CMC accounted over " <<cmc.getCount() << " queries";
        ot << "CMC" << '\n';
        const vector<double> cur = cmc.getAcc();
        // Guard against getAcc() returning fewer than CMC_MAX entries.
        for (std::size_t i = 0; i < cur.size() && i < static_cast<std::size_t>(CMC_MAX); ++i)
            ot << cur[i] << '\n';
    }
    ot.flush();
    if (hasOutput)
        fout.close();
    dp.close();
    return 0;
}
int main(int argc, char **argv) {
el::Configurations defaultConf;
defaultConf.setToDefault();
// Values are always std::string
defaultConf.setGlobally(el::ConfigurationType::Format, "%datetime %level %msg");
// Defining program options
po::options_description desc("Allowed options");
desc.add_options()
("help,h", "produce help message")
("train,T", "training model")
("validate,V", "validate model")
("predict,P", "use model for prediction")
("cmc,C", "enable cmc auditing")
("debug,d", "show debug messages")
("model,m", po::value<string>(), "set input model file")
("output,o", po::value<string>(), "set output model/prediction file")
("feature,i", po::value<string>(), "set input feature file");
// Parsing program options
po::store(po::parse_command_line(argc, argv, desc), vm);
po::notify(vm);
// Print help if necessary
if (vm.count("help") || !(vm.count("train") || vm.count("validate") || vm.count("predict"))) {
cout << desc;
return 0;
}
if (!vm.count("debug"))
defaultConf.setGlobally(el::ConfigurationType::Enabled, "false");
// default logger uses default configurations
el::Loggers::reconfigureLogger("default", defaultConf);
mainFunc mainf;
if (vm.count("train")) {
mainf = &train;
}
else if (vm.count("validate")||vm.count("predict")) {
mainf = &predict;
}
else return 0;
DataProvider* dp;
if (vm["feature"].as<string>().find(".rid") == string::npos)
dp = new FileDP(vm["feature"].as<string>());
else
dp = new RidFileDP(vm["feature"].as<string>());
mainf(*dp);
delete dp;
return 0;
}
|