#include "ranksvmtn.h" #include<iostream> #include<list> #include"../tools/matrixIO.h" using namespace std; using namespace Eigen; const double C=1e-5; // Compensating & scaling // Main terminating criteria const int maxiter = 10; // max iteration count const double prec=1e-10; // precision // conjugate gradient const double cg_prec=1e-10; // precision const int cg_maxiter = 30; // line search const double line_prec=1e-10; // precision const double line_turb=1e-15; // purturbation int cal_Hs(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alpha,const vector<int> &A1,const vector<int> &A2,const VectorXd s,VectorXd &Hs) { Hs = VectorXd::Zero(s.rows()); VectorXd Ds=D*s; VectorXd gamma(D.rows()); for (int i=0;i<A1.size();++i) { double g=0; for (int j = A1[i];j<=A2[i];++j) if (corr[rank[j]]<0) gamma[rank[j]]=g; else g+=Ds[rank[j]]; g=0; for (int j = A2[i];j>=A1[i];--j) if (corr[rank[j]]>0) gamma[rank[j]]=g; else g+=Ds[rank[j]]; } Hs = s + C*(D.transpose()*(alpha.cwiseProduct(Ds) - gamma)); return 0; } int cg_solve(const MatrixXd &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alph,const vector<int> &A1,const vector<int> &A2,const VectorXd &b, VectorXd &x) { double alpha,beta,r_1,r_2; int step=0; VectorXd q; VectorXd Hs; cal_Hs(D,rank,corr,alph,A1,A2,x,Hs); VectorXd res = b - Hs; VectorXd p = res; while (1) { // Non preconditioned version r_1 = res.dot(res); if (r_1<cg_prec) // Terminate condition break; if (step){ beta = r_1 / r_2; p = res + p * beta; } cal_Hs(D,rank,corr,alph,A1,A2,p,q); alpha = r_1/p.dot(q); x=x+p*alpha; res=res-q*alpha; ++step; if (step > cg_maxiter) { LOG(INFO) << "CG force terminated by maxiter"; break; } r_2=r_1; } return 0; } void ranksort(int l,int r,vector<int> &rank,VectorXd &ref) { int i=l,j=r,k; double mid=ref(rank[(l+r)>>1]); while (i<=j) { while (ref[rank[i]]<mid) ++i; while (ref[rank[j]]>mid) --j; if (i<=j) { k=rank[i]; rank[i]=rank[j]; rank[j]=k; ++i; --j; } } if (j>l) ranksort(l,j,rank,ref); if (i<r) ranksort(i,r,rank,ref); } int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,const vector<int> &A1,const vector<int> &A2,vector<int> &rank,VectorXd &yt,VectorXd &alpha,VectorXd &beta) { long n = dw.rows(); yt = dw - corr; alpha=VectorXd::Zero(n); beta=VectorXd::Zero(n); for (int i=0;i<n;++i) rank[i]=i; for (int i=0;i<A1.size();++i) { ranksort(A1[i],A2[i],rank,yt); double a=0,b=0; for (int j=A1[i];j<=A2[i];++j) if (corr[rank[j]]<0) { alpha[rank[j]]=a; beta[rank[j]]=b; } else { a+=1; b+=yt[rank[j]]; } a=b=0; for (int j=A2[i];j>=A1[i];--j) if (corr[rank[j]]>0) { alpha[rank[j]]=a; beta[rank[j]]=b; } else { a+=1; b+=yt[rank[j]]; } } } // line search using newton method int line_search(const VectorXd &w,const MatrixXd &D,const VectorXd &corr,const vector<int> &A1,const vector<int> &A2,const VectorXd &step,double &t) { VectorXd Dd = D*step; VectorXd Xd = VectorXd::Zero(A1.size()); VectorXd alpha,beta,yt; VectorXd grad; VectorXd Hs; vector<int> rank(D.rows()); int iter = 0; for (int i=0;i<A1.size();++i) Xd(i) = Dd(A1[i])-Dd(A2[i]); double g,h; t = 0; while (1) { grad=w+t*step; Dd = D*(w + t*step); cal_alpha_beta(Dd,corr,A1,A2,rank,yt,alpha,beta); grad = grad + C*(D.transpose()*(alpha.cwiseProduct(yt)-beta)); g = grad.dot(step); cal_Hs(D,rank,corr,alpha,A1,A2,step,Hs); h = Hs.dot(step); g=g+line_turb; h = h+line_turb; t=t-g/h; if (g*g/h<line_prec) break; ++iter; if (iter > cg_maxiter) { LOG(INFO) << "line search force terminated by maxiter"; break; } } return 0; } int train_orig(int fsize, MatrixXd &D,const vector<int> 
&A1,const vector<int> &A2,const VectorXd &corr,VectorXd &weight){ int iter = 0; long n=D.rows(); LOG(INFO) << "training with feature size:" << fsize << " Data size:" << n << " Query size:" << A1.size(); VectorXd grad(fsize); VectorXd step(fsize); vector<int> rank(n); double obj,t; VectorXd dw = D*weight; VectorXd yt; VectorXd alpha,beta; while (true) { iter+=1; if (iter> maxiter) { LOG(INFO)<< "Maxiter reached"; break; } dw = D*weight; cal_alpha_beta(dw,corr,A1,A2,rank,yt,alpha,beta); // Generate support vector matrix sv & gradient obj = (weight.dot(weight) + C*(alpha.dot(yt.cwiseProduct(yt))-beta.dot(yt)))/2;// grad = weight + C*(D.transpose()*(alpha.cwiseProduct(yt)-beta)); step = grad*0; // Solve cg_solve(D,rank,corr,alpha,A1,A2,grad,step); // do line search line_search(weight,D,corr,A1,A2,step,t); weight=weight+step*t; // When dec is small enough LOG(INFO)<<"Iter: "<<iter<<" Obj: " <<obj << " Newton decr:"<<step.dot(grad)/2 << " linesearch: "<< -t ; if (step.dot(grad) < prec * obj) break; } return 0; } int RSVMTN::train(DataList &D){ MatrixXd Data(D.getSize(),D.getfSize()); VectorXd corr(D.getSize()); vector<int> A1,A2; int i,j; LOG(INFO)<<"Processing input"; vector<DataEntry*> &dat = D.getData(); for (i=0;i<D.getSize();++i) { corr(i)=(dat[i])->rank>0?0.5:-0.5; for (j = 0; j < D.getfSize(); ++j){ Data(i, j) = dat[i]->feature(j);} } i=j=0; while (i<D.getSize()) { if ((i+1 == D.getSize())|| dat[i]->qid!=dat[i+1]->qid) { A1.push_back(j); A2.push_back(i); j = i+1; } ++i; } train_orig(fsize,Data,A1,A2,corr,model.weight); return 0; }; int RSVMTN::predict(DataList &D, vector<double> &res){ res.clear(); for (int i=0;i<D.getSize();++i) res.push_back(((D.getData()[i])->feature).dot(model.weight)); return 0; };