package algorithmrepository;

import jafama.FastMath;

/**
 * A class for doing 1D piecewise-cubic interpolation on uniform and
 * non-uniform grids. The grid does not have to be given in increasing
 * order; it is sorted on construction if necessary.
 *
 * Can also interpolate the continuous first derivative.
 *
 * The interpolator can be initialised either with the function
 * values at the grid points, or the function values and known
 * derivatives at the grid points. In the former case, finite
 * differences will be calculated on the grid. In both cases
 * the class can do smooth interpolation of both the function and
 * the first derivative.
 *
 * Usage:
 * CubicInterpolation1D ip = new CubicInterpolation1D(xp,fp);
 * ip.eval(4.3);
 *
 * double[] x = new double[] { 1, 2.3, 3.1, 4 };
 * double[] der = new double[x.length];
 * double[] f = ip.eval(x,der);
 * System.out.println("Function and derivative at x=2.3: "+f[1]+" "+der[1]);
 *
 * ip = new CubicInterpolation1D(xp,fp,dfdx); // dfdx known - higher accuracy
 *
 * The per-cell polynomial coefficients are computed lazily on first use
 * and cached in the instance; no synchronisation is done by this class.
 *
 * From http://www.algorithmrepository.org
 *
 */
public class CubicInterpolation1D implements Interpolation1D {
    /**
     * Squared relative tolerance used to decide whether the knots are
     * equally spaced: the squared knot spacings may differ from the first
     * squared spacing by at most sqrt(tolFrac2) of their own size.
     */
    private final static double tolFrac2 = 1e-10;

    /** Knot positions (increasing), function values and optional derivatives at the knots. */
    double[] xp, fp, dfdx;
    /** Lazily filled cache of polynomial coefficients, one entry per grid cell. */
    double[][] coeffs;
    /** Number of knots. */
    int nump;
    /** First and last knot position. */
    double minx, maxx;
    /** True if a direct index computation can be used instead of a binary search. */
    boolean equallySpaced;
    private double extrapValue = Double.NaN;
    private int extrapolationMode = 0;

    /** @return The extrapolation mode currently in use (one of the EXTRAPOLATE_* constants). */
    public int getExtrapolationMode() {
        return extrapolationMode;
    }

    /** @return The value returned outside the grid in EXTRAPOLATE_CONSTANT_VALUE mode. */
    public double getExtrapolationValue() {
        return extrapValue;
    }

    /**
     * Sets how points outside the grid are handled.
     *
     * @param extrapolationMode One of the EXTRAPOLATE_* constants.
     * @param extrapolationValue The value returned outside the grid when the
     * mode is EXTRAPOLATE_CONSTANT_VALUE; not used by the other modes.
     */
    public void setExtrapolation(int extrapolationMode, double extrapolationValue) {
        this.extrapolationMode = extrapolationMode;
        this.extrapValue = extrapolationValue;
    }

    /**
     * Initialises a 1D cubic interpolation from a given
     * grid and corresponding function values, with linear extrapolation
     * outside the grid.
     *
     * The derivatives will be calculated from the grid
     * using central differences for the inner points
     * and forward/backward differences for the edge
     * points.
     *
     * @param xp An array of x-coordinates.
     * @param fp The corresponding function values.
     */
    public CubicInterpolation1D(double[] xp, double[] fp) {
        this(xp, fp, null, EXTRAPOLATE_LINEAR, Double.NaN);
    }

    /**
     * Initialises a 1D cubic interpolation from a given grid, function
     * values and known derivatives, with linear extrapolation outside
     * the grid.
     *
     * @param xp An array of x-coordinates.
     * @param fp The corresponding function values.
     * @param dfdx The corresponding derivatives, or null to have them
     * calculated from finite differences over the given grid.
     */
    public CubicInterpolation1D(double[] xp, double[] fp, double[] dfdx) {
        this(xp, fp, dfdx, EXTRAPOLATE_LINEAR, Double.NaN);
    }

    /**
     * Initialises a 1D cubic interpolation from a given
     * grid and corresponding function values.
     *
     * If the x-coordinates are not in increasing order, copies of all the
     * given arrays are sorted by x; otherwise the given arrays are used
     * directly (they are not copied).
     *
     * @param xp An array of x-coordinates (at least 2).
     * @param fp The corresponding function values.
     * @param dfdx Optional array of derivatives. If these derivatives
     * are not given (null), derivatives will be calculated from finite
     * differences over the given grid: central differences for the inner
     * points and forward/backward differences for the edge points.
     * @param extrapolationMode One of the EXTRAPOLATE_* constants.
     * @param extrapValue The value returned outside the grid when the mode
     * is EXTRAPOLATE_CONSTANT_VALUE.
     * @throws IllegalArgumentException if fewer than 2 grid points are given.
     */
    public CubicInterpolation1D(double[] xp, double[] fp, double[] dfdx, int extrapolationMode, double extrapValue) {
        if (xp.length < 2) {
            throw new IllegalArgumentException("Cubic interpolation needs at least 2 grid points, got " + xp.length + ".");
        }

        boolean increasing = true;
        for (int i = 1; i < xp.length; ++i) {
            if (xp[i] < xp[i - 1]) {
                increasing = false;
                break;
            }
        }
        if (!increasing) {
            // Work on copies so the caller's arrays are left untouched.
            xp = xp.clone();
            fp = fp.clone();
            int[] index = new int[xp.length];
            // Assumed contract of the project helpers (as already relied on for fp):
            // quicksort sorts xp in place and records the permutation in index,
            // order applies that permutation to another array.
            Algorithms.quicksort(xp, index);
            Algorithms.order(fp, index);
            if (dfdx != null) {
                dfdx = dfdx.clone();
                // The derivatives must follow their knots, exactly like fp;
                // sorting them by their own value would scramble them.
                Algorithms.order(dfdx, index);
            }
        }

        this.xp = xp;
        this.fp = fp;
        this.dfdx = dfdx;
        this.nump = xp.length;
        this.minx = xp[0];
        this.maxx = xp[nump - 1];
        this.coeffs = new double[nump - 1][];

        // Compare every squared spacing with the first one, relative to its own size.
        final double dist2 = (xp[1] - xp[0]) * (xp[1] - xp[0]);
        boolean uniform = true;
        for (int i = 2; i < nump; ++i) {
            double step = xp[i] - xp[i - 1];
            double d2 = step * step;
            double diff = dist2 - d2;
            if (diff * diff > tolFrac2 * d2 * d2) {
                uniform = false;
                break;
            }
        }
        this.equallySpaced = uniform;

        this.extrapValue = extrapValue;
        this.extrapolationMode = extrapolationMode;
    }


    /**
     * Returns true if the points are regarded as equally spaced,
     * false otherwise.
     *
     * If the points are equally spaced a faster lookup is used.
     *
     * @return True if the points at which the function is given are
     * equally spaced, false otherwise.
     */
    public boolean isEquallySpaced() {
        return equallySpaced;
    }


    /**
     * Returns the x coordinates in increasing order. This is the internal
     * array, not a copy.
     *
     * @return The x coordinates.
     */
    public double[] getX() {
        return xp;
    }

    /**
     * Returns the vector of function values, ordered like {@link #getX()}.
     * This is the internal array, not a copy.
     *
     * @return The vector of function values.
     */
    public double[] getF() {
        return fp;
    }

    /**
     * Replaces the function values and discards all cached cell coefficients.
     *
     * The array is stored as given (no copy, no reordering), so the values
     * must correspond to the knots in the order returned by {@link #getX()}.
     * Derivatives given at construction are left unchanged.
     *
     * @param fp The new function values.
     */
    public void setF(double[] fp) {
        this.fp = fp;
        this.coeffs = new double[nump - 1][];
    }

    /**
     * Does cubic interpolation on an array of points. The grid
     * with function values and optional derivatives have been given
     * in the constructor.
     *
     * Can also interpolate the (continuous) derivatives.
     *
     * Extrapolation: points outside the grid are handled according to the
     * extrapolation mode: linear continuation along the end-knot tangent,
     * a fixed value, the end-knot value, or an exception.
     *
     * See http://www.algorithmrepository.org
     *
     * @param x The x-values where f should be evaluated. Do not have to be
     * in any specific order.
     * @param der If allocated on call (to double[numPoints]) the
     * (continuous) first derivative of the function at the
     * interpolation points will be returned. May be null.
     * @return The function values at the given points.
     * @throws IllegalArgumentException if a point lies outside the grid and
     * the extrapolation mode is EXTRAPOLATE_EXCEPTION or not a known mode.
     */
    public double[] eval(double[] x, double[] der) {
        double[] ret = new double[x.length];

        for (int i = 0; i < x.length; ++i) {
            final double xi = x[i];
            final int left = findCell(xi);

            if (left >= nump - 1) { // Extrapolate to the right
                ret[i] = extrapolate(xi, nump - 1, der, i);
            }
            else if (left < 0) { // Extrapolate to the left
                ret[i] = extrapolate(xi, 0, der, i);
            }
            else {
                // Sanity check on the lookup; cannot fail for an increasing grid.
                if (xi < xp[left] || xi > xp[left + 1])
                    throw new RuntimeException("Internal error: binary search failed, are the Xs in order?");

                double[] c = cellCoeffs(left);
                double dx = xp[left + 1] - xp[left];
                double u = (xi - xp[left]) / dx; // position within the cell, normalised to [0,1]
                ret[i] = c[0] + c[1] * u + c[2] * u * u + c[3] * u * u * u;
                if (der != null) {
                    // Chain rule: d/dx = (d/du)/dx
                    der[i] = (c[1] + 2 * c[2] * u + 3 * c[3] * u * u) / dx;
                }
            }
        }

        return ret;
    }

    /**
     * Finds the grid cell containing x.
     *
     * @return The index i with xp[i] <= x <= xp[i+1], a negative value if x
     * lies left of the grid, or a value >= nump-1 if x lies right of it.
     */
    private int findCell(double x) {
        if (!equallySpaced) {
            return binarySearch(xp, x);
        }

        // Direct estimate. The cast truncates towards zero, e.g. (int)-5.5 is -5
        // and not -6, and the spacing is only equal to within the tolerance, so
        // the estimate may be off by a cell; the walk below corrects that.
        int cell = (int) ((nump - 1) * (x - minx) / (maxx - minx));

        // The right-hand end knot itself (or a point rounded onto it) is still
        // on the grid and belongs to the last cell, as in binarySearch().
        if (cell >= nump - 1 && x <= maxx) {
            cell = nump - 2;
        }

        while (cell >= 0 && cell < nump - 1) {
            if (x < xp[cell]) {
                cell--;
            }
            else if (x > xp[cell + 1]) {
                cell++;
            }
            else {
                break;
            }
        }
        return cell;
    }

    /**
     * Evaluates a point outside the grid according to the extrapolation mode.
     *
     * @param x The point to evaluate.
     * @param knot Index of the end knot nearest to x (0 or nump-1).
     * @param der If not null, receives the derivative at index i.
     * @param i Index into der.
     * @return The extrapolated function value.
     */
    private double extrapolate(double x, int knot, double[] der, int i) {
        double value;
        double slope = 0;

        switch (extrapolationMode) {
        case EXTRAPOLATE_LINEAR:
            slope = knotDerivative(knot);
            value = fp[knot] + (x - xp[knot]) * slope;
            break;
        case EXTRAPOLATE_CONSTANT_VALUE:
            value = extrapValue;
            break;
        case EXTRAPOLATE_CONSTANT_END_KNOT:
            value = fp[knot];
            break;
        case EXTRAPOLATE_EXCEPTION:
        default:
            throw new IllegalArgumentException("Extrapolation with invalid or no extrapolation mode ("+extrapolationMode+").");
        }

        if (der != null) {
            der[i] = slope;
        }
        return value;
    }

    /**
     * Returns the derivative used at knot k: the given derivative if there is
     * one, otherwise a central difference for inner knots and a
     * forward/backward difference for the first/last knot.
     */
    private double knotDerivative(int k) {
        if (dfdx != null) {
            return dfdx[k];
        }
        // Clamping the neighbours to the grid turns the central difference into
        // a one-sided one at the edges, and keeps a 2-point grid in bounds.
        int lo = Math.max(k - 1, 0);
        int hi = Math.min(k + 1, nump - 1);
        return (fp[hi] - fp[lo]) / (xp[hi] - xp[lo]);
    }

    /** Returns the polynomial coefficients of a grid cell, calculating and caching them on first use. */
    private double[] cellCoeffs(int left) {
        double[] c = coeffs[left];
        if (c == null) { // Calculate coefficients for new grid cell
            c = cubicInterpolationCoeffs1D(
                    xp[left], fp[left], knotDerivative(left),
                    xp[left + 1], fp[left + 1], knotDerivative(left + 1));
            coeffs[left] = c;
        }
        return c;
    }


    /**
     * Does cubic interpolation on an array of points. The grid
     * with function values and optional derivatives have been given
     * in the constructor.
     *
     * Extrapolation: points outside the grid are handled according to the
     * extrapolation mode, see {@link #eval(double[], double[])}.
     *
     * See http://www.algorithmrepository.org
     *
     * @param x The x-values where f should be evaluated. Do not have to be
     * in any specific order.
     * @return The function values at the given points.
     */
    public double[] eval(double[] x) {
        return eval(x, null);
    }


    /**
     * Does a single cubic interpolation.
     *
     * Extrapolation: a point outside the grid is handled according to the
     * extrapolation mode, see {@link #eval(double[], double[])}.
     *
     * See http://www.algorithmrepository.org
     *
     * @param x The x-value where f should be evaluated.
     * @return The function value at the given point.
     */
    public double eval(double x) {
        return eval(new double[] { x }, null)[0];
    }


    /**
     * Static method exposing the functionality of cubic interpolation.
     *
     * Does a single interpolation based on 2 grid points, and the derivatives
     * of the function at the grid points. From this a unique 3rd degree
     * polynomial between the two points can be found.
     *
     * No range check is done: for x outside [xa,xb] the same polynomial is
     * simply evaluated there.
     *
     * @param xa Left x-coordinate
     * @param fa Function value at left x-coordinate.
     * @param dfa Derivative of function at left x-coordinate.
     * @param xb Right x-coordinate
     * @param fb Function value at right x-coordinate.
     * @param dfb Derivative of function at right x-coordinate.
     * @param x The value at which the function should be interpolated.
     *
     * @return The interpolated value.
     */
    public static double eval(double xa, double fa, double dfa, double xb, double fb, double dfb, double x) {
        double u = (x - xa) / (xb - xa);
        double[] c = cubicInterpolationCoeffs1D(xa, fa, dfa, xb, fb, dfb);
        return c[0] + c[1] * u + c[2] * u * u + c[3] * u * u * u;
    }




    /**
     * Returns precalculated coefficients for further cubic interpolations in the same
     * grid cell.
     *
     * Cubic interpolation in 1D fits ('u' is [0,1] normalised x within a cell):
     *
     * 	sum(ci*u^i), 0<=i<=3
     *
     * from function values and derivatives at each point.
     *
     * The coefficients can thus be calculated from
     *
     *  	a*c=g
     *
     *  where g = (f0,f1,dfdu0,dfdu1), the function values and derivatives
     *  at the left and right grid point within a given cell.
     *
     *  c is then calculated from
     *
     *  	c=ainv*g
     *
     *  which is what this method does.
     *
     *
     * @param xa Left x-coordinate
     * @param fa Function value at left x-coordinate.
     * @param dfa Derivative of function at left x-coordinate.
     * @param xb Right x-coordinate
     * @param fb Function value at right x-coordinate.
     * @param dfb Derivative of function at right x-coordinate.
     * @return ret The returned value of coefficients from which the interpolated value is
     * calculated with
     *
     * 	f(u) = sum( ci*u^i )
     *
     *  Derivatives can be calculated from
     *
     *  dfdu(u) = sum( i*ci*u^(i-1) )
     *
     *  which is related to dfdx by
     *
     *  dfdx = dfdu/dx, dx being the width of a grid cell.
     *
     */
    public static double[] cubicInterpolationCoeffs1D(double xa, double fa, double dfa, double xb, double fb, double dfb) {
        double length = xb - xa;
        // Right-hand side g; the derivatives are converted from d/dx to d/du.
        double[] g = { fa, fb, dfa * length, dfb * length };

        double[] ret = new double[4];
        for (int row = 0; row < 4; ++row) {
            double sum = 0;
            for (int col = 0; col < 4; ++col) {
                sum += ainv[row][col] * g[col];
            }
            ret[row] = sum;
        }
        return ret;
    }

    /** Inverse of the matrix a in a*c=g, see {@link #cubicInterpolationCoeffs1D}. */
    private static final double[][] ainv = {
        { 1, 0, 0, 0 },
        { 0, 0, 1, 0 },
        {-3, 3, -2, -1 },
        { 2, -2, 1, 1 } };


    /**
     * Does a binary search (bisection) to find the index of the element in the
     * array x for which x[index] <= xv < x[index+1], where x is monotonically
     * increasing. The last element is counted as part of the last interval.
     *
     * Adapted from numerical recipes ch. 3.4.
     *
     * @param x A non-empty array of monotonically increasing values.
     * @param xv The value whose nearest lower index is sought in the array.
     * @return The index that corresponds to x[index] <= xv < x[index+1],
     * x.length-2 if xv == x[x.length-1],
     * -1 if xv < x[0], x.length if xv > x[x.length-1]
     */
    public static int binarySearch(double[] x, double xv) {
        int num = x.length;
        int lower = -1;
        int upper = num;

        // Invariant: x[lower] <= xv < x[upper], with virtual sentinels at -1 and num.
        while (upper - lower > 1) {
            int mid = (upper + lower) >> 1;
            if (xv >= x[mid]) {
                lower = mid;
            }
            else {
                upper = mid;
            }
        }
        if (xv > x[num - 1]) return num;
        if (xv == x[0]) return 0;
        if (xv == x[num - 1]) return num - 2;

        return lower;
    }

}
