#include "ranksvmtn.h"
#include<iostream>
#include<list>
#include"../tools/matrixIO.h"

using namespace std;
using namespace Eigen;

// Scale factor on the ranking loss: the objective minimised in this file is
// (w.w + RC*l)/2, and RC also multiplies the loss part of the gradient and of
// the Hessian-vector product. Set by RSVMTN::train to 2*C / (number of rows).
double RC=0;

// Computes the matrix-vector product Dw = D * w, one row at a time:
// Dw(r) = D.getVecDot(r, w) for every row r of D.
// Dw is written element-wise and never resized, so it must already hold
// D.getSize() entries. Rows are independent; OpenMP hands them out
// round-robin (static schedule, chunk size 1) so threads interleave rows.
void cal_Dw(RidList &D,const VectorXd &w, VectorXd &Dw)
{
    const int rows = D.getSize();
    #pragma omp parallel for schedule(static,1)
    for (int r = 0; r < rows; ++r)
    {
        Dw(r) = D.getVecDot(r, w);
    }
}

// Accumulates the transposed product D^T * w into Dw.
// NOTE(review): assumes D.addVecw(i, c, v) adds c times row i of D into v —
// confirm against RidList.
// Dw is ADDED to, not overwritten: callers pass a zero vector of at least
// D.getfSize() entries. Each OpenMP thread sums its share of rows into a
// private buffer; the buffers are then merged into Dw one thread at a time.
void cal_Dtw(RidList &D,const VectorXd &w, VectorXd &Dw)
{
    const int rows = D.getSize();
    const int cols = D.getfSize();
    #pragma omp parallel shared(D,Dw)
    {
        // Thread-local partial sum avoids concurrent writes to Dw.
        VectorXd partial = VectorXd::Zero(D.getfSize());
        #pragma omp for nowait
        for (int r = 0; r < rows; ++r)
            D.addVecw(r, w(r), partial);
        // Serialise the merge of each thread's partial sum.
        #pragma omp critical
        Dw.head(cols) += partial;
    }
}

// Hessian-vector product used by the CG solver and the line search:
//     Hs = s + RC * D^T (alpha .* (D s) - gamma)
// Within each query block of q consecutive entries of `rank` (rank is
// expected to be sorted per block, as cal_alpha_beta leaves it):
//   - for an item with corr > 0, gamma is the sum of (D s) over the items
//     with corr <= 0 that come AFTER it in the block;
//   - for an item with corr < 0, gamma is the sum of (D s) over the items
//     with corr >= 0 that come BEFORE it in the block.
// Items with corr == 0 get gamma = 0.
// NOTE(review): assumes D.getSize() is a multiple of D.getqSize() (every query
// has the same number of rows); otherwise rank[i+j] reads past the end.
// Always returns 0.
int cal_Hs(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alpha,const VectorXd s,VectorXd &Hs)
{
    int n = D.getSize();
    int q = D.getqSize();
    VectorXd Ds(n);
    cal_Dw(D,s,Ds);
    // Zero-initialise: entries whose corr is exactly 0 are never assigned
    // below and were previously left uninitialised.
    VectorXd gamma = VectorXd::Zero(n);
    for (int i=0;i<n;i+=q)
    {
        double g=0;
        // Walk the block backwards: positives collect the (D s) sum of the
        // non-positive items ranked after them.
        for (int j = q-1;j>=0;--j)
        {
            const int k = rank[i+j];
            if (corr[k]>0)
                gamma[k]=g;
            else
                g+=Ds[k];
        }
        g=0;
        // Walk the block forwards: negatives collect the (D s) sum of the
        // non-negative items ranked before them.
        for (int j = 0;j<q;++j)
        {
            const int k = rank[i+j];
            if (corr[k]<0)
                gamma[k]=g;
            else
                g+=Ds[k];
        }
    }
    VectorXd tmp = alpha.cwiseProduct(Ds) - gamma;
    VectorXd res = VectorXd::Zero(D.getfSize());
    cal_Dtw(D,tmp,res);
    Hs = s + RC*res;
    return 0;
}

// Solves H x = b with plain (unpreconditioned) conjugate gradient, where H is
// the operator applied by cal_Hs(D, rank, corr, alph, ., .).
// On entry x is the initial guess (callers warm-start it); on exit it holds
// the solution estimate. Iteration stops when the squared residual norm
// drops below cg_prec, or after cg_maxiter iterations.
// Always returns 0.
int cg_solve(RidList &D,const vector<int> &rank,const VectorXd &corr,const VectorXd &alph,const VectorXd &b, VectorXd &x)
{
    VectorXd Hx;
    cal_Hs(D,rank,corr,alph,x,Hx);
    VectorXd res = b - Hx;
    // Non preconditioned version
    VectorXd p = res;
    double r_1 = res.dot(res);
    // Already converged: without this check p == 0 on the first iteration,
    // p.dot(Hp) == 0 and alpha = 0/0 = NaN would corrupt x.
    if (r_1 < cg_prec || r_1 == 0)
        return 0;
    VectorXd Hp;
    int iter=0;
    while (true)
    {
        cal_Hs(D,rank,corr,alph,p,Hp);
        const double alpha = r_1/p.dot(Hp);
        x += alpha*p;
        res -= alpha*Hp;
        ++iter;
        const double r_2 = r_1;
        r_1 = res.dot(res);
        LOG(INFO) << "CG iter "<<iter<<", r:"<<r_1;
        // Terminate condition; r_1 == 0 also guards the next 0/0 division.
        if (r_1<cg_prec || r_1 == 0)
            break;
        if (iter >= cg_maxiter)
        {
            LOG(INFO) << "CG forced termination by maxiter";
            break;
        }
        const double beta = r_1 / r_2;
        p = res + beta * p;
    }
    return 0;
}

// Sorts the index range rank[l..r] (inclusive) by ascending ref(rank[k]).
// Only `rank` is permuted; `ref` is only read. Hoare-style quicksort with the
// key of the middle element as pivot; not stable (equal keys may reorder).
// The left part is sorted recursively; the right part is handled by looping.
void ranksort(int l,int r,vector<int> &rank,VectorXd &ref)
{
    for (;;)
    {
        const double pivot = ref(rank[(l+r)>>1]);
        int lo = l;
        int hi = r;
        while (lo <= hi)
        {
            while (ref[rank[lo]] < pivot) ++lo;
            while (pivot < ref[rank[hi]]) --hi;
            if (lo > hi)
                break;
            const int held = rank[lo];
            rank[lo] = rank[hi];
            rank[hi] = held;
            ++lo;
            --hi;
        }
        if (hi > l)
            ranksort(l,hi,rank,ref);
        if (lo >= r)
            return;
        // Equivalent to the tail call ranksort(lo, r, rank, ref).
        l = lo;
    }
}

// Computes the per-row pair statistics of the squared-hinge ranking loss.
//
// yt = dw - corr. Then, inside each query block of q consecutive rows, `rank`
// is sorted by ascending yt and, for every row k:
//   corr[k] < 0: alpha[k] = number of rows with corr >= 0 ranked before k,
//                beta[k]  = sum of yt over those rows;
//   corr[k] > 0: alpha[k] = number of rows with corr <= 0 ranked after k,
//                beta[k]  = sum of yt over those rows.
// Rows with corr == 0 keep alpha = beta = 0.
//
// @param dw    D * w, one entry per row
// @param corr  per-row target offsets
// @param D     data; only its query size q is used here
// @param rank  [out] must already hold dw.rows() entries; receives the
//              per-block sorted row indices
// @param yt, alpha, beta  [out] resized to dw.rows()
// @return always 0
// NOTE(review): assumes dw.rows() is a multiple of q; otherwise the last
// block reads past the end of `rank`.
int cal_alpha_beta(const VectorXd &dw,const VectorXd &corr,RidList &D,vector<int> &rank,VectorXd &yt,VectorXd &alpha,VectorXd &beta)
{
    long n = dw.rows();
    int q = D.getqSize();
    yt = dw - corr;
    alpha=VectorXd::Zero(n);
    beta=VectorXd::Zero(n);
    for (int i=0;i<n;++i) rank[i]=i;
    for (int i=0;i<n;i+=q)
    {
        int ed=i+q-1;
        ranksort(i,ed,rank,yt);
        double a=0,b=0;
        // find A, cal B
        for (int j=i;j<=ed;++j)
            if (corr[rank[j]]<0)
            {
                alpha[rank[j]]=a;
                beta[rank[j]]=b;
            }
            else
            {
                a+=1;
                b+=yt[rank[j]];
            }
        a=b=0;
        // find B, cal A
        for (int j=ed;j>=i;--j)
            if (corr[rank[j]]>0)
            {
                alpha[rank[j]]=a;
                beta[rank[j]]=b;
            }
            else
            {
                a+=1;
                b+=yt[rank[j]];
            }
    }
    // Previously missing: flowing off the end of a non-void function is UB.
    return 0;
}

// line search using newton method
int line_search(const VectorXd &w,RidList &D,const VectorXd &corr,const VectorXd &step,double &t)
{
    int n=D.getSize();
    VectorXd Dd(n);
    cal_Dw(D,step,Dd);
    VectorXd alpha,beta,yt;
    VectorXd grad;
    VectorXd Hs;

    vector<int> rank(n);
    int iter = 0;

    double g,h;
    t = 0;
    while (1)
    {
        grad=w+t*step;
        cal_Dw(D,grad,Dd);
        cal_alpha_beta(Dd,corr,D,rank,yt,alpha,beta);
        VectorXd tmp = alpha.cwiseProduct(yt)-beta;
        VectorXd res = VectorXd::Zero(D.getfSize());
        cal_Dtw(D,tmp,res);
        grad = grad + RC*res;
        g = grad.dot(step);
        cal_Hs(D,rank,corr,alpha,step,Hs);
        h = Hs.dot(step);
        g=g+ls_turb;
        h = h+ls_turb;
        t=t-g/h;
        ++iter;
        LOG(INFO) << "line search iter "<<iter<<", prec:"<<g*g/h;
        if (g*g/h<ls_prec)
            break;
        if (iter >= ls_maxiter)
        {
            LOG(INFO) << "line search forced termination by maxiter";
            break;
        }
    }
    return 0;
}

int train_orig(int fsize, RidList &Data,const VectorXd &corr,VectorXd &weight){
    int iter = 0;

    long n=Data.getSize();
    LOG(INFO) << "training with feature size:" << fsize << " Data size:" << Data.getSize() << " Query size:" << Data.getqSize();
    VectorXd grad(fsize);
    VectorXd step(fsize);
    vector<int> rank(n);
    double obj,t,l;

    VectorXd dw(n);
    VectorXd yt;
    VectorXd alpha,beta;
    step = VectorXd::Zero(fsize);
    while (true)
    {
        cal_Dw(Data,weight,dw);
        cal_alpha_beta(dw,corr,Data,rank,yt,alpha,beta);
        // Generate support vector matrix sv & gradient
        l=alpha.dot(yt.cwiseProduct(yt))-beta.dot(yt);
        obj = (weight.dot(weight) + RC*l)/2;
        VectorXd tmp = alpha.cwiseProduct(yt)-beta;
        VectorXd res = VectorXd::Zero(fsize);
        cal_Dtw(Data,tmp,res);
        grad = weight + RC*res;
        // Solve
        cg_solve(Data,rank,corr,alpha,grad,step);
        // do line search
        line_search(weight,Data,corr,step,t);
        weight=weight+step*t;
        // When dec is small enough
        double nprec = step.dot(grad)/obj;
        ++iter;
        LOG(INFO)<<"Iter: "<<iter<<" Obj: " <<obj<< " l: "<< l << " Ndec/Obj:"<<nprec << " linesearch: "<< -t ;
        if (iter>= maxiter)
        {
            LOG(INFO)<< "Maxiter reached";
            break;
        }
        if (nprec < prec)
            break;
    }
    return 0;
}

// Trains model.weight on D.
// Each row's label D.getL(k) is mapped to a target offset of +0.5 (label > 0)
// or -0.5 (otherwise). The global loss scale is set to RC = 2*C/n, and the
// Newton loop is run on this->fsize features. Always returns 0.
int RSVMTN::train(RidList &D){
    const int n = D.getSize();
    LOG(INFO)<<"Processing input";
    VectorXd corr(n);
    for (int k = 0; k < n; ++k)
    {
        const bool positive = D.getL(k) > 0;
        corr(k) = positive ? 0.5 : -0.5;
    }
    RC = 2.0*C/n;
    train_orig(fsize,D,corr,model.weight);
    return 0;
}

// Scores every row of D with the trained weights.
// res is replaced by one score per row, in row order:
// res[k] = D.getVecDot(k, model.weight). Always returns 0.
int RSVMTN::predict(RidList &D, vector<double> &res){
    const int rows = D.getSize();
    res.clear();
    res.reserve(rows);
    for (int k = 0; k < rows; ++k)
        res.push_back(D.getVecDot(k, model.weight));
    return 0;
}