/*
 * ranksvm: main program
 *  usage: run ./ranksvm -h to see all options
 *  supports:
 *      training
 *      validating
 *      predicting
 *  models:
 *      TN  RankSVM (truncated Newton, conjugate gradient, various opt)
 *      BH  Bhattacharyya distance
 *      HE  Hellinger distance (but outputs a chance/probability instead?!)
 *  output features:
 *      cmc
 *          Cumulative Matching Characteristic
 *      avg
 *          Normalized average rank
 *      predict
 *          relevance value of an image pair
 */


#include <fstream>
#include <iostream>
#include <list>
#include <memory>
#include <string>
#include <vector>

#include <Eigen/Dense>
#include <boost/program_options.hpp>

#include "tools/easylogging++.h"
#include "model/ranksvmtn.h"
#include "tools/fileDataProvider.h"
#include "model/rankaccu.h"

INITIALIZE_EASYLOGGINGPP

using namespace Eigen;
using namespace std;
namespace po = boost::program_options;

po::variables_map vm;

typedef int (*mainFunc)(DataProvider &dp);

/*
 * Train a model on every entry read from dp and save it.
 *  -m/--model  : passed to RSVM::loadModel to obtain the model object to train
 *  -o/--output : file the trained model is written to
 * The optimizer settings logged below (C, maxiter, prec, cg_*, ls_*) are
 * assigned by main() from the command line before this runs.
 * Returns 0 on success, 1 if a required option is missing or no model is loaded.
 */
int train(DataProvider &dp) {
    // Check both required options up front: variables_map::as<>() would
    // otherwise throw boost::bad_any_cast, for -o only after training finished.
    if (!vm.count("model") || !vm.count("output")) {
        cerr << "train: both --model (-m) and --output (-o) are required" << endl;
        return 1;
    }

    // Owns the model on every return path, including exceptions.
    unique_ptr<RSVM> rsvm(RSVM::loadModel(vm["model"].as<string>()));
    if (!rsvm) {
        cerr << "train: failed to load model " << vm["model"].as<string>() << endl;
        return 1;
    }

    dp.open();
    RidList D;

    LOG(INFO)<<"Training started";
    dp.getAllDataSet(D);
    LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
    LOG(INFO)<<"C: "<<C<<" ,iter: "<<maxiter<<" ,prec: "<<prec;
    LOG(INFO)<<"cg_maxiter: "<<cg_maxiter<<" ,cg_prec: "<<cg_prec<<" ,ls_maxiter: "<<ls_maxiter<<" ,ls_prec: "<<ls_prec;

    rsvm->train(D);

    LOG(INFO)<<"Training finished,saving model";

    dp.close();
    rsvm->saveModel(vm["output"].as<string>().c_str());
    return 0;
}

/*
 * Score every entry read from dp with a saved model, then validate and/or
 * emit results.
 *  -m/--model    : model file loaded through RSVM::loadModel (required)
 *  -o/--output   : destination for all results; stdout when absent
 *  -V/--validate : run rank_accu over the scores; with -C/--cmc also collect
 *                  and print CMC / AVG statistics
 *  -P/--predict  : print one score per entry, or the rank_pair results (-p),
 *                  or the per-feature F-scores (-f)
 * Returns 0 on success, 1 if the model is missing/unloadable or the output
 * file cannot be opened.
 */
int predict(DataProvider &dp) {
    if (!vm.count("model")) {
        cerr << "predict: no model file given (-m)" << endl;
        return 1;
    }
    // Owns the model on every return path, including exceptions.
    unique_ptr<RSVM> rsvm(RSVM::loadModel(vm["model"].as<string>().c_str()));
    if (!rsvm) {
        cerr << "predict: failed to load model " << vm["model"].as<string>() << endl;
        return 1;
    }

    // Open the destination before reading any data. A bad path then fails fast
    // instead of every result being written to a failed stream.
    ofstream fout;
    ostream *ot = &cout;
    if (vm.count("output")) {
        fout.open(vm["output"].as<string>().c_str());
        if (!fout) {
            cerr << "predict: cannot open output file " << vm["output"].as<string>() << endl;
            return 1;
        }
        ot = &fout;
    }

    // rank_CMC only runs while validating, so CMC output is only meaningful then.
    const bool validate = vm.count("validate") > 0;
    const bool report_cmc = validate && vm.count("cmc") > 0;

    dp.open();
    RidList D;
    vector<double> L;
    CMC cmc;
    Fscore f;

    LOG(INFO)<<"Prediction started";

    dp.getAllDataSet(D);
    LOG(INFO)<<"Read "<<D.getSize()<<" entries with "<< D.getfSize()<<" features";
    rsvm->predict(D,L);

    if (validate) {
        rank_accu(D,L);
        if (report_cmc)
            rank_CMC(D,L,cmc);
    }

    // These loops can emit one line per entry. They use '\n' instead of endl
    // to avoid a flush per line; the stream is flushed once below.
    if (vm.count("predict")) {
        if (vm.count("pair")) {
            vector<double> pair;
            rank_pair(D,L,pair);
            for (double v : pair)
                *ot << v << '\n';
        } else if (vm.count("fscore")) {
            f.audit(D);
            const vector<double> &fs = f.getFscore();
            // Also bounded by fs.size() so a shorter score vector is never over-read.
            for (int i = 0; i < D.getfSize() && static_cast<size_t>(i) < fs.size(); ++i)
                *ot << fs[i] << '\n';
        } else {
            for (double v : L)
                *ot << v << '\n';
        }
    }

    LOG(INFO)<<"Finished";
    if (report_cmc) {
        LOG(INFO)<< "CMC accounted over " <<cmc.getCount() << " queries";
        *ot << "CMC\n";
        const vector<double> &cur = cmc.getAcc();
        for (int i = 0; i < CMC_MAX && static_cast<size_t>(i) < cur.size(); ++i)
            *ot << cur[i] << '\n';
        *ot << "AVG\n";
        *ot << cmc.getAvg()/D.getqSize() << '\n';
    }

    ot->flush();
    if (fout.is_open())
        fout.close();
    dp.close();
    return 0;
}

/*
 * Read a feature mask from a text file into msk.
 *  file format: an integer count N followed by N whitespace-separated numbers
 *  On success msk holds exactly N values. Values missing from a truncated
 *  file stay 0.
 *  If the file cannot be opened, or the count is missing or negative, msk is
 *  left empty and a message is written to stderr.
 */
void getmask(const std::string &fname, std::vector<double> &msk)
{
    msk.clear();

    std::ifstream fin(fname.c_str());
    if (!fin) {
        std::cerr << "Cannot open mask file: " << fname << std::endl;
        return;
    }

    // Initialised because a failed extraction may leave the target untouched.
    int fsize = 0;
    if (!(fin >> fsize) || fsize < 0) {
        std::cerr << "Invalid mask size in file: " << fname << std::endl;
        return;
    }

    msk.assign(static_cast<std::size_t>(fsize), 0.0);
    for (int i = 0; i < fsize; ++i) {
        if (!(fin >> msk[i])) {
            std::cerr << "Mask file truncated: " << fname << std::endl;
            break;
        }
    }
}

int main(int argc, char **argv) {
    el::Configurations defaultConf;
    defaultConf.setToDefault();
    // Values are always std::string
    defaultConf.setGlobally(el::ConfigurationType::Format, "%datetime %level %msg");

    // Defining program options
    po::options_description desc("Allowed options");
    desc.add_options()
            ("help,h", "produce help message")
            ("train,T", "training model")
            ("validate,V", "validate model")
            ("predict,P", "use model for prediction")
            ("cmc,C", "enable cmc auditing")
            ("debug,d", "show debug messages")
            ("single,s", "one from a pair")
            ("pair,p","get pair result")
            ("fscore,f","get F-score")
            ("mask,M", po::value<string>(), "set feature mask")
            ("model,m", po::value<string>(), "set input model file")
            ("output,o", po::value<string>(), "set output model/prediction file")
            ("feature,i", po::value<string>(), "set input feature file")
            ("c,c",po::value<double>(),"trades margin size against training error")
            ("iter",po::value<int>(),"iter main")
            ("prec",po::value<double>(),"prec main")
            ("cg_iter",po::value<int>(),"iter conjugate gradient")
            ("cg_prec",po::value<double>(),"prec conjugate gradient")
            ("ls_iter",po::value<int>(),"iter line search")
            ("ls_prec",po::value<double>(),"prec line search");

    // Parsing program options
    po::store(po::parse_command_line(argc, argv, desc), vm);
    po::notify(vm);

    // Print help if necessary
    if (vm.count("help") || !(vm.count("train") || vm.count("validate") || vm.count("predict"))) {
        cout << desc;
        return 0;
    }

    if (!vm.count("debug"))
        defaultConf.setGlobally(el::ConfigurationType::Enabled, "false");
    // default logger uses default configurations
    el::Loggers::reconfigureLogger("default", defaultConf);

    mainFunc mainf;
    RidList::single=vm.count("single")>0;
    if (vm.count("train")) {
        if (vm.count("c")) { C=vm["c"].as<double>(); }
        if (vm.count("iter")) { maxiter=vm["iter"].as<int>(); }
        if (vm.count("prec")) { prec=vm["prec"].as<double>(); }
        if (vm.count("cg_iter")) { cg_maxiter=vm["cg_iter"].as<int>(); }
        if (vm.count("cg_prec")) { cg_prec=vm["cg_prec"].as<double>(); }
        if (vm.count("ls_iter")) { ls_maxiter=vm["ls_iter"].as<int>(); }
        if (vm.count("ls_prec")) { ls_prec=vm["ls_prec"].as<double>(); }
        mainf = &train;
    }
    else if (vm.count("validate")||vm.count("predict")) {
        mainf = &predict;
    }
    else return 0;
    DataProvider* dp;
    if (vm["feature"].as<string>().find(".rid") == string::npos)
        LOG(FATAL)<<"Format no longer supported";
    else
    {
        RidFileDP* tmpdp = new RidFileDP(vm["feature"].as<string>());
        if (vm.count("mask"))
        {
            vector<double> msk;
            getmask(vm["mask"].as<string>(),msk);
            tmpdp->datmask(msk);
        }
        dp = tmpdp;
    }
    mainf(*dp);
    delete dp;
    return 0;
}