#include "ranksvmtn.h"
#include<iostream>
#include<list>
#include"../tools/matrixIO.h"

using namespace std;
using namespace Eigen;

// Main terminating criteria for the outer Newton loop
const int maxiter = 60; // max iteration count
const double prec=1e-10; // precision: stop when Newton decrement < prec * objective
// conjugate gradient (inner linear solver)
const double cg_prec=1e-10; // precision: stop when squared residual falls below this
const int cg_maxiter = 10;  // max CG iterations per Newton step
const int ls_maxiter = 10;  // max line-search iterations per Newton step
// line search
const double line_prec=1e-10; // precision: stop when g^2/h falls below this
const double line_turb=1e-15; // perturbation added to g and h to avoid division by zero

// Compute the Hessian-vector product Hs for the truncated-Newton step:
//   Hs = s + C * X^T (alpha .* (X s) - gamma)
// where gamma is accumulated per query block via two sweeps over the
// rank-sorted indices (the order-statistics trick that avoids forming
// the pairwise loss matrix explicitly).
//
// D     : data accessor; examples are grouped in blocks of D.getqSize()
// rank  : per-block sorted permutation of example indices (from cal_alpha_beta)
// corr  : +/-0.5 labels; the sign selects which side of a pair an example is on
// alpha : per-example counts from cal_alpha_beta
// s     : direction to multiply the Hessian with (read-only; was passed by
//         value before, forcing a copy on every CG iteration)
// Hs    : output vector, receives H*s
// Returns 0.
int cal_Hs(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alpha,const VectorXd &s,VectorXd &Hs)
{
    const int n = D.getSize();
    const int q = D.getqSize();
    // Ds = X * s : projection of the direction onto every example.
    VectorXd Ds(n);
    for (int i=0;i<n;++i)
        Ds(i) = D.getVec(i).dot(s);
    // gamma: per-example running sums of Ds over opposing examples within
    // the same query block, taken in sorted order.
    VectorXd gamma(n);
    for (int i=0;i<n;)
    {
        double g=0;
        // Sweep the block from the top of the sorted order downward:
        // positives record the accumulated sum of the negatives seen so far.
        for (int j = q-1;j>=0;--j)
            if (corr[rank[i+j]]>0)
                gamma[rank[i+j]]=g;
            else
                g+=Ds[rank[i+j]];
        g=0;
        i+=q;
        // Sweep the same block in the opposite direction: negatives record
        // the accumulated sum of the positives.
        for (int j = q;j>0;--j)
            if (corr[rank[i-j]]<0)
                gamma[rank[i-j]]=g;
            else
                g+=Ds[rank[i-j]];
    }
    // Hs = s + C * X^T (alpha .* Ds - gamma)
    VectorXd tmp = alpha.cwiseProduct(Ds)-gamma;
    VectorXd res = VectorXd::Zero(s.rows());
    for (int i=0;i<n;++i)
        res += D.getVec(i)*tmp[i];
    Hs = s + C*res;
    return 0;
}

// Conjugate-gradient solver for the Newton system H x = b, where H is
// applied implicitly through cal_Hs. On entry x holds the initial guess;
// on exit it holds the (approximate) solution. Terminates when the squared
// residual drops below cg_prec or after cg_maxiter rounds.
// Returns 0.
int cg_solve(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alph,const VectorXd &b, VectorXd &x)
{
    VectorXd Hx;
    cal_Hs(D,rank,corr,alph,x,Hx);
    VectorXd residual = b - Hx;   // r = b - H x0
    VectorXd direction = residual;
    VectorXd Hp;
    double rr_prev = 0;
    int round = 0;
    for (;;)
    {
        // Non preconditioned version
        const double rr = residual.dot(residual);
        if (round)
            LOG(INFO) << "CG iter "<<round<<", r:"<<rr;
        if (rr<cg_prec) // Terminate condition
            break;
        if (round >= cg_maxiter)
        {
            LOG(INFO) << "CG forced termination by maxiter";
            break;
        }
        if (round)
        {
            // Fletcher-Reeves update of the search direction.
            const double ratio = rr / rr_prev;
            direction = residual + direction * ratio;
        }
        cal_Hs(D,rank,corr,alph,direction,Hp);
        const double stepsize = rr/direction.dot(Hp);
        x += direction * stepsize;
        residual -= Hp * stepsize;
        ++round;
        rr_prev = rr;
    }
    return 0;
}

// In-place quicksort of the index slice rank[l..r] (inclusive) so that
// ref[rank[k]] is non-decreasing over the slice. Only the permutation in
// `rank` is rearranged; `ref` itself is never modified.
void ranksort(int l,int r,vector<int> &rank,VectorXd &ref)
{
    int lo = l;
    int hi = r;
    const double pivot = ref(rank[(l+r)>>1]);
    while (lo <= hi)
    {
        // Advance both cursors past elements already on the correct side.
        while (ref[rank[lo]] < pivot) ++lo;
        while (ref[rank[hi]] > pivot) --hi;
        if (lo <= hi)
        {
            const int tmp = rank[lo];
            rank[lo] = rank[hi];
            rank[hi] = tmp;
            ++lo;
            --hi;
        }
    }
    // Recurse into whichever partitions still hold more than one element.
    if (l < hi)
        ranksort(l,hi,rank,ref);
    if (lo < r)
        ranksort(lo,r,rank,ref);
}

// For the current per-example scores dw, compute the order statistics the
// RankSVM objective/gradient needs:
//   yt    = dw - corr                       (label-shifted scores)
//   rank  = indices sorted by yt within each query block
//   alpha = per example, the count of opposing-label examples accumulated
//           on its violating side within the block
//   beta  = the sum of those examples' yt values
// Negatives (corr<0) collect counts/sums from positives below them in the
// sorted order; positives collect from negatives above them.
// Precondition: rank.size() >= dw.rows() and dw.rows() == D.getSize().
// Returns 0.
// BUG FIX: the function is declared int but previously had no return
// statement — flowing off the end of a value-returning function is
// undefined behavior in C++.
int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int> &rank,VectorXd &yt,VectorXd &alpha,VectorXd &beta)
{
    long n = dw.rows();
    yt = dw - corr;
    alpha=VectorXd::Zero(n);
    beta=VectorXd::Zero(n);
    for (int i=0;i<n;++i) rank[i]=i;
    for (int i=0;i<D.getSize();i+=D.getqSize())
    {
        int ed=i+D.getqSize()-1;
        // Sort this query block's indices by yt.
        ranksort(i,ed,rank,yt);
        double a=0,b=0;
        // Upward sweep: each negative records the positives seen so far.
        for (int j=i;j<=ed;++j)
            if (corr[rank[j]]<0)
            {
                alpha[rank[j]]=a;
                beta[rank[j]]=b;
            }
            else
            {
                a+=1;
                b+=yt[rank[j]];
            }
        a=b=0;
        // Downward sweep: each positive records the negatives seen so far.
        for (int j=ed;j>=i;--j)
            if (corr[rank[j]]>0)
            {
                alpha[rank[j]]=a;
                beta[rank[j]]=b;
            }
            else
            {
                a+=1;
                b+=yt[rank[j]];
            }
    }
    return 0;
}

// line search using newton method
int line_search(const VectorXd &w,RidList &D,const VectorXd &corr,const VectorXd &step,double &t)
{
    VectorXd Dd(D.getSize());
    for (int i=0;i<D.getSize();++i)
        Dd(i) = D.getVec(i).dot(step);
    VectorXd alpha,beta,yt;
    VectorXd grad;
    VectorXd Hs;
    int n=D.getSize();
    vector<int> rank(D.getSize());
    int iter = 0;

    double g,h;
    t = 0;
    while (1)
    {
        grad=w+t*step;
        for (int i=0;i<D.getSize();++i)
            Dd(i) = D.getVec(i).dot(grad);
        cal_alpha_beta(Dd,corr,D,rank,yt,alpha,beta);
        VectorXd tmp = alpha.cwiseProduct(yt)-beta;
        VectorXd res = 0*grad;
        for (int i=0;i<n;++i)
            res = res + D.getVec(i)*tmp[i];
        grad = grad + C*res;
        g = grad.dot(step);
        cal_Hs(D,rank,corr,alpha,step,Hs);
        h = Hs.dot(step);
        g=g+line_turb;
        h = h+line_turb;
        t=t-g/h;
        if (g*g/h<line_prec)
            break;
        ++iter;
        LOG(INFO) << "line search iter "<<iter<<", prec:"<<g*g/h;
        if (iter >= ls_maxiter)
        {
            LOG(INFO) << "line search forced termination by maxiter";
            break;
        }
    }
    return 0;
}

int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){
    int iter = 0;

    long n=Data.getSize();
    LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Query size:" << Data.getuSize();
    VectorXd grad(fsize);
    VectorXd step(fsize);
    vector<int> rank(n);
    double obj,t;

    VectorXd dw(n);
    VectorXd yt;
    VectorXd alpha,beta;
    while (true)
    {
        iter+=1;
        if (iter> maxiter)
        {
            LOG(INFO)<< "Maxiter reached";
            break;
        }

        for (int i=0;i<n;++i)
            dw(i) = Data.getVec(i).dot(weight);
        cal_alpha_beta(dw,corr,Data,rank,yt,alpha,beta);
        // Generate support vector matrix sv & gradient
        obj = (weight.dot(weight) + C*(alpha.dot(yt.cwiseProduct(yt))-beta.dot(yt)))/2;
        VectorXd tmp = alpha.cwiseProduct(yt)-beta;
        VectorXd res = 0*weight;
        for (int i=0;i<n;++i)
            res = res + Data.getVec(i)*tmp[i];
        grad = weight + C*res;
        step = grad*0;
        // Solve
        cg_solve(Data,rank,corr,alpha,grad,step);
        // do line search
        line_search(weight,Data,corr,step,t);
        weight=weight+step*t;
        // When dec is small enough
        LOG(INFO)<<"Iter: "<<iter<<" Obj: " <<obj << " Newton decr:"<<step.dot(grad)/2 << " linesearch: "<< -t ;
        if (step.dot(grad) < prec * obj)
            break;
    }
    return 0;
}

// Train the model: build the +/-0.5 target vector from D's labels and run
// the truncated-Newton optimizer on model.weight.
// Removed unused locals (A1, A2, j) and narrowed the loop index's scope.
// Returns 0.
int RSVMTN::train(RidList &D){
    VectorXd corr(D.getSize());
    LOG(INFO)<<"Processing input";
    // Map labels to symmetric targets: positive label -> +0.5, else -0.5.
    for (int i=0;i<D.getSize();++i)
        corr(i)=D.getL(i)>0?0.5:-0.5;
    train_orig(fsize,D,corr,model.weight);
    return 0;
};

// Score every example in D with the trained weights: res[i] receives the
// dot product of example i's feature vector with model.weight.
// Returns 0.
int RSVMTN::predict(DataList &D, vector<double> &res){
    res.clear();
    const int count = D.getSize();
    res.reserve(count);
    for (int i = 0; i < count; ++i)
    {
        const double score = ((D.getData()[i])->feature).dot(model.weight);
        res.push_back(score);
    }
    return 0;
};