summaryrefslogtreecommitdiff
path: root/train.cpp
blob: bae88f3c7f017bd19f8469adc7de11f977d23d07 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#include <cstddef>
#include <exception>
#include <fstream>
#include <iostream>
#include <list>
#include <memory>
#include <string>
#include <vector>

#include <Eigen/Dense>
#include <boost/program_options.hpp>

#include "tools/easylogging++.h"
#include "model/ranksvmtn.h"
#include "tools/fileDataProvider.h"
#include "model/rankaccu.h"

// easylogging++ initialization macro (expected to be provided by
// tools/easylogging++.h).
INITIALIZE_EASYLOGGINGPP

using namespace Eigen;
using namespace std;
namespace po = boost::program_options;

// Parsed command-line options: filled in by main() and read by train() and
// predict().
po::variables_map vm;

// Common signature of the per-mode entry points (train, predict) that main()
// selects between and then calls with the data provider.
typedef int (*mainFunc)(DataProvider &dp);

// Train a ranking model on every sample supplied by the data provider.
//
// Reads the global option map `vm`:
//   "model"  - file name handed to RSVM::loadModel
//   "output" - file name the trained model is saved to
//
// dp: data provider; opened here, read in full, and closed before the model
//     is saved.
// Returns 0 on success, 1 when a required option is missing or the model
// could not be loaded.
int train(DataProvider &dp) {
    // Both options are dereferenced below; as<string>() on an absent option
    // throws boost::bad_any_cast, so reject the call up front instead.
    if (!vm.count("model") || !vm.count("output")) {
        cerr << "train: both --model and --output must be specified" << endl;
        return 1;
    }

    // Owning pointer: the model is released on every exit path, including
    // exceptions thrown while reading data or training.
    std::unique_ptr<RSVM> rsvm(RSVM::loadModel(vm["model"].as<string>()));
    if (!rsvm) {
        cerr << "train: could not load model " << vm["model"].as<string>() << endl;
        return 1;
    }

    dp.open();
    DataList D;

    LOG(INFO)<<"Training started";
    dp.getAllDataSet(D);
    LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
    LOG(INFO)<<"C: "<<C;
    rsvm->train(D);

    // NOTE(review): the predictions are not used afterwards; the call is kept
    // in case RSVM::predict reports something while scoring the training set
    // - confirm before removing.
    vector<double> L;
    rsvm->predict(D,L);

    LOG(INFO)<<"Training finished,saving model";

    dp.close();
    rsvm->saveModel(vm["output"].as<string>().c_str());
    return 0;
}

// Score the data sets supplied by the data provider with a stored model.
//
// Reads the global option map `vm`:
//   "model"    - file name handed to RSVM::loadModel (required)
//   "output"   - if present, scores are written to this file instead of stdout
//   "validate" - if present, rank_accu is run on every data set
//   "cmc"      - if present, CMC statistics are accumulated (while validating)
//                and printed after the last data set
//
// dp: data provider; opened here and read one data set at a time until
//     EOFile() reports the end.
// Returns 0 on success, 1 when the model option is missing, the model could
// not be loaded, or the output file could not be opened.
int predict(DataProvider &dp) {
    // as<string>() on an absent option throws boost::bad_any_cast.
    if (!vm.count("model")) {
        cerr << "predict: --model must be specified" << endl;
        return 1;
    }

    // Owning pointer: released on every exit path, including exceptions.
    std::unique_ptr<RSVM> rsvm(RSVM::loadModel(vm["model"].as<string>()));
    if (!rsvm) {
        cerr << "predict: could not load model " << vm["model"].as<string>() << endl;
        return 1;
    }

    const bool validating = vm.count("validate") > 0;
    const bool auditCmc = vm.count("cmc") > 0;
    const bool toFile = vm.count("output") > 0;

    // Scores go to the requested file, or to stdout when none was given.
    ofstream fout;
    ostream *ot = &cout;
    if (toFile) {
        fout.open(vm["output"].as<string>().c_str());
        if (!fout) {
            // Without this check every later write would be silently dropped.
            cerr << "predict: cannot open output file " << vm["output"].as<string>() << endl;
            return 1;
        }
        ot = &fout;
    }

    dp.open();
    DataList D;
    vector<double> L;
    CMC cmc;
    LOG(INFO)<<"Prediction started";

    while (!dp.EOFile())
    {
        dp.getDataSet(D);
        LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
        rsvm->predict(D,L);

        if (validating)
        {
            rank_accu(D,L);
            if (auditCmc)
                rank_CMC(D,L,cmc);
        }

        // Raw scores are printed unless this is a validate-only run with no
        // output file.
        if (toFile || !validating) {
            for (double score : L)
                *ot << score << '\n';
            // One flush per data set instead of one per line.
            ot->flush();
        }
    }

    LOG(INFO)<<"Finished";
    if (auditCmc)
    {
        LOG(INFO)<< "CMC accounted over " <<cmc.getCount() << " queries";
        *ot << "CMC"<<endl;
        vector<double> cur = cmc.getAcc();
        // Print at most CMC_MAX entries, never reading past the end of the
        // vector that was actually returned.
        for (std::size_t i = 0; i < cur.size() && i < static_cast<std::size_t>(CMC_MAX); ++i)
            *ot << cur[i] << '\n';
        ot->flush();
    }

    // fout is closed by its destructor.
    dp.close();
    return 0;
}

int main(int argc, char **argv) {
    el::Configurations defaultConf;
    defaultConf.setToDefault();
    // Values are always std::string
    defaultConf.setGlobally(el::ConfigurationType::Format, "%datetime %level %msg");

    // Defining program options
    po::options_description desc("Allowed options");
    desc.add_options()
            ("help,h", "produce help message")
            ("train,T", "training model")
            ("validate,V", "validate model")
            ("predict,P", "use model for prediction")
            ("cmc,C", "enable cmc auditing")
            ("debug,d", "show debug messages")
            ("model,m", po::value<string>(), "set input model file")
            ("output,o", po::value<string>(), "set output model/prediction file")
            ("feature,i", po::value<string>(), "set input feature file")
            ("c,c",po::value<double>(),"trades margin size against training error");

    // Parsing program options
    po::store(po::parse_command_line(argc, argv, desc), vm);
    po::notify(vm);

    // Print help if necessary
    if (vm.count("help") || !(vm.count("train") || vm.count("validate") || vm.count("predict"))) {
        cout << desc;
        return 0;
    }

    if (!vm.count("debug"))
        defaultConf.setGlobally(el::ConfigurationType::Enabled, "false");
    // default logger uses default configurations
    el::Loggers::reconfigureLogger("default", defaultConf);

    mainFunc mainf;
    if (vm.count("train")) {
        if (vm.count("c")) {
            C=vm["c"].as<double>();
        }
        mainf = &train;
    }
    else if (vm.count("validate")||vm.count("predict")) {
        mainf = &predict;
    }
    else return 0;
    DataProvider* dp;
    if (vm["feature"].as<string>().find(".rid") == string::npos)
        dp = new FileDP(vm["feature"].as<string>());
    else
        dp = new RidFileDP(vm["feature"].as<string>());
    mainf(*dp);
    delete dp;
    return 0;
}