package algorithmrepository;


public class ConjugateGradient {

    /** Maximum number of outer conjugate-gradient iterations. */
    private static final int MAX_CG_ITERATIONS = 50;
    /** Maximum number of iterations of each 1D line minimisation. */
    private static final int MAX_LINE_ITERATIONS = 10;
    /** Number of CG iterations after which the search direction is reset to steepest descent. */
    private static final int RESTART_INTERVAL = 10;
    /** Relative CG tolerance: stop once |r|^2 is at most (CG_TOLERANCE^2) * |r_0|^2. */
    private static final double CG_TOLERANCE = 1E-10;
    /** Line-search tolerance: stop once the step length |alpha * d| is at most this value. */
    private static final double LINE_TOLERANCE = 1E-5;
    /** Offset sigma_0 of the initial secant point x + sigma_0 * d (secant line search only). */
    private static final double SECANT_SIGMA_0 = 0.1;

    /**
     * Minimises {@code f} over its parameters using nonlinear conjugate gradients
     * (Fletcher-Reeves, with periodic restarts) and a secant-method line search.
     * Equivalent to {@code Minimise(f, false)}.
     *
     * @param f the function to minimise. Must provide {@link Function#getParams()} and
     *          {@link Function#dfdx()}; the parameter objects returned by {@code getParams()}
     *          are updated in place and hold the located minimum on return.
     */
    public static void Minimise(Function f) {
        Minimise(f, false);
    }

    /**
     * Minimises {@code f} over its parameters using nonlinear conjugate gradients
     * (Fletcher-Reeves, with periodic restarts).
     *
     * <p>Algorithm taken from "An Introduction to the Conjugate Gradient Method Without the
     * Agonizing Pain", Edition 1 1/4, Jonathan Richard Shewchuk, August 4, 1994.
     *
     * @param f                the function to minimise; its {@code getParams()} values are
     *                         updated in place and hold the located minimum on return
     * @param useNewtonRaphson if {@code true}, each line minimisation uses Newton-Raphson, which
     *                         additionally requires {@link Function#d2fdx2()}; otherwise the
     *                         secant method (gradient only) is used
     */
    public static void Minimise(Function f, boolean useNewtonRaphson) {
        DoubleParam[] x = f.getParams();
        int n = x.length;

        double[] r = new double[n];
        double[] dfdx = f.dfdx();
        for (int a = 0; a < n; a++) {
            r[a] = -dfdx[a]; // r = -f'(x), the steepest-descent direction
        }
        double[] d = r.clone(); // d = r

        double deltaNew = dot(r, r, n);
        double convergedBelow = CG_TOLERANCE * CG_TOLERANCE * deltaNew;
        int k = 0; // CG iterations since the last restart

        for (int i = 0; i < MAX_CG_ITERATIONS; i++) {
            // Test convergence BEFORE the line search. If the starting gradient is already zero,
            // carrying on would compute beta = 0/0 and fill d (and then x) with NaN. This check
            // also guarantees deltaOld > 0 when beta is computed below.
            if (deltaNew <= convergedBelow) {
                break;
            }

            if (useNewtonRaphson) {
                NewtonRaphsonMinimise(f, d, MAX_LINE_ITERATIONS, LINE_TOLERANCE);
            } else {
                SecantMinimise(f, d, MAX_LINE_ITERATIONS, LINE_TOLERANCE, SECANT_SIGMA_0);
            }

            dfdx = f.dfdx();
            for (int a = 0; a < n; a++) {
                r[a] = -dfdx[a]; // r = -f'(x) at the new point
            }

            double deltaOld = deltaNew;
            deltaNew = dot(r, r, n);
            double beta = deltaNew / deltaOld; // Fletcher-Reeves
            for (int a = 0; a < n; a++) {
                d[a] = r[a] + beta * d[a];
            }

            // Restart along steepest descent after RESTART_INTERVAL iterations, or when d is no
            // longer a descent direction (r.d <= 0).
            k++;
            if (k == RESTART_INTERVAL || dot(r, d, n) <= 0) {
                System.arraycopy(r, 0, d, 0, n);
                k = 0;
            }
        }
    }

    /**
     * Moves the parameters of {@code f} towards the minimum of f along direction {@code d}, by
     * applying the secant method to the directional derivative f'(x).d.
     *
     * @param f       function whose parameters are updated in place
     * @param d       search direction (not modified)
     * @param jmax    maximum number of secant iterations
     * @param err     stop once the step length |alpha * d| is at most {@code err}
     * @param sigma_0 offset of the initial secant point x + sigma_0 * d
     * @return the loop index at which iteration stopped
     */
    private static int SecantMinimise(Function f, double[] d, int jmax, double err, double sigma_0) {
        DoubleParam[] x = f.getParams();
        int n = x.length;
        double deltaD = dot(d, d, n);
        double errSquared = err * err;

        // etaPrev = f'(x + sigma_0 * d).d, evaluated by temporarily shifting x. The original
        // values are saved and written back exactly; shifting back by -sigma_0 * d could leave
        // rounding error in x.
        double[] saved = new double[n];
        for (int a = 0; a < n; a++) {
            saved[a] = x[a].get();
            x[a].set(saved[a] + sigma_0 * d[a]);
        }
        double etaPrev = dot(f.dfdx(), d, n);
        for (int a = 0; a < n; a++) {
            x[a].set(saved[a]);
        }

        double alpha = -sigma_0;
        int j;
        for (j = 0; j < jmax; j++) {
            double eta = dot(f.dfdx(), d, n); // eta = f'(x).d
            if (eta == etaPrev) {
                break; // secant step undefined (would divide by zero)
            }
            alpha = alpha * eta / (etaPrev - eta);
            for (int a = 0; a < n; a++) {
                x[a].set(x[a].get() + alpha * d[a]); // x = x + alpha * d
            }
            etaPrev = eta;
            if (alpha * alpha * deltaD <= errSquared) {
                break;
            }
        }
        return j;
    }

    /**
     * Moves the parameters of {@code f} towards the minimum of f along direction {@code d} using
     * Newton-Raphson steps alpha = -[f'(x).d] / [d.(f''(x) d)].
     *
     * @param f    function whose parameters are updated in place; must provide d2fdx2()
     * @param d    search direction (not modified)
     * @param jmax maximum number of Newton-Raphson iterations
     * @param err  stop once the step length |alpha * d| is at most {@code err}
     * @return the loop index at which iteration stopped
     */
    private static int NewtonRaphsonMinimise(Function f, double[] d, int jmax, double err) {
        DoubleParam[] x = f.getParams();
        int n = x.length;
        double deltaD = dot(d, d, n);
        double errSquared = err * err;

        int j;
        for (j = 0; j < jmax; j++) {
            double[] dfdx = f.dfdx();
            double[][] d2fdx2 = f.d2fdx2();
            double nom = dot(dfdx, d, n);
            double denom = 0;
            for (int a = 0; a < n; a++) {
                for (int b = 0; b < n; b++) {
                    denom += d[a] * d2fdx2[a][b] * d[b];
                }
            }
            // A Newton step only heads to a minimum when the curvature along d is positive.
            // denom == 0 would give an infinite/NaN step; denom < 0 would head to a maximum.
            if (!(denom > 0)) {
                break;
            }
            double alpha = -nom / denom;
            for (int a = 0; a < n; a++) {
                x[a].set(x[a].get() + alpha * d[a]); // x = x + alpha * d
            }
            if (alpha * alpha * deltaD <= errSquared) {
                break;
            }
        }
        return j;
    }

    /** Returns the dot product of the first {@code n} entries of {@code u} and {@code v}. */
    private static double dot(double[] u, double[] v, int n) {
        double sum = 0;
        for (int a = 0; a < n; a++) {
            sum += u[a] * v[a];
        }
        return sum;
    }

    /**
     * Mutable holder for a single real-valued parameter, letting the minimiser update a
     * {@link Function}'s variables in place.
     */
    public static class DoubleParam {

        // Package-private (not private) to stay compatible with existing direct access.
        double val;

        /** @return the current value */
        final public double get() { return val; }

        /** @param value the new value */
        final public void set(double value) { val = value; }

        /** @param value the initial value */
        public DoubleParam(double value) {
            val = value;
        }
    }

    /**
     * A function of the parameters returned by {@link #getParams()}. The minimisers in this
     * class change those parameters via {@link DoubleParam#set(double)} and then call
     * {@link #dfdx()} / {@link #d2fdx2()}, so the derivatives must be computed from the
     * parameters' current values.
     */
    public static interface Function {

        /** Replaces the parameter array (not called by ConjugateGradient itself). */
        public void setParams(DoubleParam params[]);

        /** @return the live parameter objects; element i is variable i */
        public DoubleParam [] getParams();

        /** @return the function value (not called by ConjugateGradient itself) */
        public double evaluate();

        /** @return the gradient; element i is the derivative with respect to getParams()[i] */
        public double[] dfdx();

        /** @return the n-by-n second-derivative matrix, indexed like getParams() */
        public double[][] d2fdx2();
    }
}
