diff options
| author | Joe Zhao <ztuowen@gmail.com> | 2015-05-13 13:35:03 +0800 | 
|---|---|---|
| committer | Joe Zhao <ztuowen@gmail.com> | 2015-05-13 13:35:03 +0800 | 
| commit | 20587ac550cfcb2d7b3d6ec16e46ba1a8d0af869 (patch) | |
| tree | 8da41db1cef2bcedadeb5769832d95c45ffb7f13 | |
| parent | 62b6b42e27a4972397e94fdbb03e74ac3f5f1244 (diff) | |
| download | ranksvm-20587ac550cfcb2d7b3d6ec16e46ba1a8d0af869.tar.gz ranksvm-20587ac550cfcb2d7b3d6ec16e46ba1a8d0af869.tar.bz2 ranksvm-20587ac550cfcb2d7b3d6ec16e46ba1a8d0af869.zip | |
added split
| -rw-r--r-- | CMakeLists.txt | 9 | ||||
| -rw-r--r-- | split.cpp | 76 | ||||
| -rw-r--r-- | tools/fileDataProvider.cpp | 173 | ||||
| -rw-r--r-- | tools/fileDataProvider.h | 91 | ||||
| -rw-r--r-- | train.cpp (renamed from main.cpp) | 12 | 
5 files changed, 273 insertions, 88 deletions
| diff --git a/CMakeLists.txt b/CMakeLists.txt index 6920572..180456c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,9 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")  FIND_PACKAGE( Boost COMPONENTS program_options REQUIRED )  INCLUDE_DIRECTORIES( ${Boost_INCLUDE_DIR}) -set(SOURCE_FILES main.cpp ./model/ranksvm.cpp ./model/ranksvmtn.cpp ./model/rankaccu.cpp) -add_executable(ranksvm ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/easylogging++.h tools/matrixIO.h tools/fileDataProvider.h) -TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} )
\ No newline at end of file +set(SOURCE_FILES model/ranksvm.cpp model/ranksvmtn.cpp model/rankaccu.cpp tools/fileDataProvider.cpp) +add_executable(ranksvm train.cpp ${SOURCE_FILES} model/rankaccu.h model/ranksvm.h model/ranksvmtn.h tools/dataProvider.h tools/matrixIO.h tools/fileDataProvider.h) +add_executable(split split.cpp ${SOURCE_FILES}) +add_dependencies(ranksvm split) +TARGET_LINK_LIBRARIES( ranksvm ${Boost_LIBRARIES} ) +TARGET_LINK_LIBRARIES( split ${Boost_LIBRARIES})
\ No newline at end of file diff --git a/split.cpp b/split.cpp new file mode 100644 index 0000000..be80545 --- /dev/null +++ b/split.cpp @@ -0,0 +1,76 @@ +// +// Created by joe on 5/13/15. +// + +#include <iostream> +#include <boost/program_options.hpp> +#include "tools/dataProvider.h" +#include "tools/fileDataProvider.h" +#include <vector> +#include <fstream> + +INITIALIZE_EASYLOGGINGPP + +using namespace std; +namespace po = boost::program_options; + +po::variables_map vm; + +int outputRid(vector<DataEntry*> a,int fsize,string fname) +{ +    ofstream fout(fname.c_str()); +    fout<<fsize<<endl; +    for (int i=0;i<a.size();++i) +    { +        fout<< a[i]->qid; +        for (int j=0;j<fsize;++j) +            fout<<" "<< a[i]->feature(j); +        fout<<endl; +    } +    fout<<0; +    fout.close(); +} + +int main(int argc, char **argv) +{ +    el::Configurations defaultConf; +    defaultConf.setToDefault(); +    // Values are always std::string +    defaultConf.set(el::Level::Global,el::ConfigurationType::Enabled, "false"); +    // default logger uses default configurations +    el::Loggers::reconfigureLogger("default", defaultConf); +    po::options_description desc("Allowed options"); +    desc.add_options() +            ("help,h", "produce help message") +            ("query,Q", "Query person count") +            ("count,c", po::value<int>(), "take number") +            ("take,a", po::value<string>(), "set output rid file 1(taken)") +            ("left,b", po::value<string>(), "set output rid file 2(left)") +            ("input,i", po::value<string>(), "set input Rid file"); + +    po::store(po::parse_command_line(argc, argv, desc), vm); +    po::notify(vm); +    // Print help if necessary +    if (vm.count("help")) { +        cout << desc; +        return 0; +    } + +    if (vm.count("query")){ +        RidFileDP dp(vm["input"].as<string>().c_str()); +        dp.open(); +        cout<<dp.getpSize()<<endl; +        dp.close(); +        return 0; +    } + +    RidFileDP dp(vm["input"].as<string>().c_str()); +    vector<DataEntry*> a; +    vector<DataEntry*> b; +    dp.open(); +    dp.take(vm["count"].as<int>(),a,b); +    outputRid(a,dp.getfSize(),vm["take"].as<string>()); +    outputRid(b,dp.getfSize(),vm["left"].as<string>()); +    dp.close(); +    return 0; +}
\ No newline at end of file diff --git a/tools/fileDataProvider.cpp b/tools/fileDataProvider.cpp new file mode 100644 index 0000000..e9b7f3d --- /dev/null +++ b/tools/fileDataProvider.cpp @@ -0,0 +1,173 @@ +// +// Created by joe on 5/13/15. +// + +#include "fileDataProvider.h" +#include <random> +#include <ctime> + +using namespace std; + +mt19937 gen; + +int FileDP::getDataSet(DataList &out){ +    DataEntry* e; +    out.clear(); +    int fsize; +    fin>>fsize; +    LOG(INFO)<<"Feature size:"<<fsize; +    out.setfSize(fsize); +    while (!fin.eof()) { +        e = new DataEntry; +        fin>>e->rank; +        if (e->rank == 0) +        { +            delete e; +            break; +        } +        fin>>e->qid; +        e->feature.resize(fsize); +        for (int i=0;i<fsize;++i) { +            fin>>e->feature(i); +        } +        out.addEntry(e); +    } +    eof=true; +    return 0; +} + +void RidFileDP::readEntries() { +    DataEntry *e; +    int fsize; +    d.clear(); +    fin >> fsize; +    LOG(INFO) << "Feature size:" << fsize; +    d.setfSize(fsize); +    while (!fin.eof()) { +        e = new DataEntry; +        fin >> e->qid; +        if (e->qid == "0") { +            delete e; +            break; +        } +        e->feature.resize(fsize); +        e->rank=-1; +        for (int i = 0; i < fsize; ++i) { +            fin >> e->feature(i); +        } +        d.addEntry(e); +    } +    pos = 0; +    qid = 1; +    read = true; +} + +int RidFileDP::getDataSet(DataList &out){ +    DataEntry *e; +    int fsize; +    if (!read) +        readEntries(); +    out.clear(); +    fsize = d.getfSize(); +    out.setfSize(fsize); +    std::vector<DataEntry*> & dat = d.getData(); +    for (int i=0;i<d.getSize();++i) +        if (i!=pos) +        { +            if (dat[i]->qid == dat[pos]->qid) +            { +                e = new DataEntry; +                e->rank=1; +                dat[i]->rank=qid; +            } +            else +            { +                e = new DataEntry; +                e->rank=-1; +            } +            e->feature.resize(d.getfSize()); +            e->qid=dat[pos]->qid; +            for (int j = 0; j < fsize; ++j) { +                e->feature(j) = fabs(dat[i]->feature(j) -dat[pos]->feature(j)); +            } +            out.addEntry(e); +        } +    dat[pos]->qid=std::to_string(qid); +    ++qid; +    dat[pos]->rank=qid; +    while (pos<dat.size() && dat[pos]->rank!=-1) +        ++pos; +    if (pos==d.getSize()) +        eof = true; +    return 0; +} + +int RidFileDP::getpSize() { +    std::vector<string> p; +    if (!read) +        readEntries(); +    std::vector<DataEntry*> &dat = d.getData(); +    for (int i=0;i<dat.size();++i) +    { +        bool ext=false; +        for (int j=0;j<p.size();++j) +            if (p[j] == dat[i]->qid ) +            { +                ext=true; +                break; +            } +        if (!ext) +            p.push_back(dat[i]->qid); +    } +    return p.size(); +}; + +void scrambler(vector<DataEntry*> &dat) +{ +    DataEntry* e; +    int sz=(int)dat.size(); +    for (int i=0;i<sz;++i) +    { +        int pos = (int)(gen()%(sz-i)); +        e=dat[pos]; +        dat[pos] = dat[sz-i-1]; +        dat[sz-i-1] = e; +    } +} + +void RidFileDP::take(int n,vector<DataEntry*> &a,vector<DataEntry*> &b) +{ +    gen.seed(time(NULL)); +    DataEntry *e; +    if (!read) +        readEntries(); +    vector<DataEntry*> tmp; +    tmp.reserve(d.getSize()); +    a.clear(); +    b.clear(); +    std::vector<DataEntry*> &dat = d.getData(); +    scrambler(tmp); +    for (int i=0;i<dat.size();++i) +        tmp.push_back(dat[i]); +    int pos = 0; +    string qid; +    for (int i=0;i<n;++i) +    { +        while (tmp[pos]==NULL) +            ++pos; +        qid = tmp[pos]->qid; +        a.push_back(tmp[pos]); +        tmp[pos]=NULL; +        for (int j = pos+1; j< tmp.size();++j) +            if (tmp[j]!=NULL &&tmp[j]->qid==qid) +            { +                a.push_back(tmp[j]); +                tmp[j]=NULL; +            } +    } +    for (int i=0;i<tmp.size();++i) +        if (tmp[i]!=NULL) +            b.push_back(tmp[i]); +    scrambler(a); +    scrambler(b); +}
\ No newline at end of file diff --git a/tools/fileDataProvider.h b/tools/fileDataProvider.h index f54a38e..7bea92d 100644 --- a/tools/fileDataProvider.h +++ b/tools/fileDataProvider.h @@ -16,31 +16,7 @@ private:      std::ifstream fin;  public:      FileDP(std::string fn=""):fname(fn){}; -    virtual int getDataSet(DataList &out){ -        DataEntry* e; -        out.clear(); -        int fsize; -        fin>>fsize; -        LOG(INFO)<<"Feature size:"<<fsize; -        out.setfSize(fsize); -        while (!fin.eof()) { -            e = new DataEntry; -            fin>>e->rank; -            if (e->rank == 0) -            { -                delete e; -                break; -            } -            fin>>e->qid; -            e->feature.resize(fsize); -            for (int i=0;i<fsize;++i) { -                fin>>e->feature(i); -            } -            out.addEntry(e); -        } -        eof=true; -        return 0; -    } +    virtual int getDataSet(DataList &out);      virtual int open(){fin.open(fname); eof=false;return 0;};      virtual int close(){fin.close();return 0;};  }; @@ -58,68 +34,13 @@ private:      int qid;  public:      RidFileDP(std::string fn=""):fname(fn){read=false;}; -    virtual int getDataSet(DataList &out){ -        DataEntry *e; -        int fsize; -        if (!read) { -            d.clear(); -            fin >> fsize; -            LOG(INFO) << "Feature size:" << fsize; -            d.setfSize(fsize); -            while (!fin.eof()) { -                e = new DataEntry; -                fin >> e->qid; -                if (e->qid == "0") { -                    delete e; -                    break; -                } -                e->feature.resize(fsize); -                e->rank=-1; -                for (int i = 0; i < fsize; ++i) { -                    fin >> e->feature(i); -                } -                d.addEntry(e); -            } -            pos = 0; -            qid = 1; -            read = true; -        } -        out.clear(); -        fsize = d.getfSize(); -        out.setfSize(fsize); -        std::vector<DataEntry*> & dat = d.getData(); -        for (int i=0;i<d.getSize();++i) -            if (i!=pos) -            { -                if (dat[i]->qid == dat[pos]->qid) -                { -                    e = new DataEntry; -                    e->rank=1; -                    dat[i]->rank=qid; -                } -                else -                { -                    e = new DataEntry; -                    e->rank=-1; -                } -                e->feature.resize(d.getfSize()); -                e->qid=dat[pos]->qid; -                for (int j = 0; j < fsize; ++j) { -                    e->feature(j) = fabs(dat[i]->feature(j) -dat[pos]->feature(j)); -                } -                out.addEntry(e); -            } -        dat[pos]->qid=std::to_string(qid); -        ++qid; -        dat[pos]->rank=qid; -        while (pos<dat.size() && dat[pos]->rank!=-1) -            ++pos; -        if (pos==d.getSize()) -            eof = true; -        return 0; -    } +    void readEntries(); +    int getfSize() { if(!read) readEntries(); return d.getfSize();}; +    int getpSize(); +    virtual int getDataSet(DataList &out);      virtual int open(){fin.open(fname); eof=false;return 0;};      virtual int close(){fin.close(); d.clear();return 0;}; +    void take(int n,std::vector<DataEntry*> &a,std::vector<DataEntry*> &b);  };  #endif
\ No newline at end of file @@ -95,6 +95,11 @@ int predict(DataProvider &dp) {  }  int main(int argc, char **argv) { +    el::Configurations defaultConf; +    defaultConf.setToDefault(); +    // Values are always std::string +    defaultConf.setGlobally(el::ConfigurationType::Format, "%datetime %level %msg"); +      // Defining program options      po::options_description desc("Allowed options");      desc.add_options() @@ -103,6 +108,7 @@ int main(int argc, char **argv) {              ("validate,V", "validate model")              ("predict,P", "use model for prediction")              ("cmc,C", "enable cmc auditing") +            ("debug,d", "show debug messages")              ("model,m", po::value<string>(), "set input model file")              ("output,o", po::value<string>(), "set output model/prediction file")              ("feature,i", po::value<string>(), "set input feature file"); @@ -116,6 +122,12 @@ int main(int argc, char **argv) {          cout << desc;          return 0;      } + +    if (!vm.count("debug")) +        defaultConf.setGlobally(el::ConfigurationType::Enabled, "false"); +    // default logger uses default configurations +    el::Loggers::reconfigureLogger("default", defaultConf); +      mainFunc mainf;      if (vm.count("train")) {          mainf = &train; | 
