package algorithmrepository;

import jafama.FastMath;


/**
 * A class for doing 2D cubic (bicubic) interpolation on uniform grids where
 * the grid values are obtained on demand through callbacks. The cubic
 * interpolator gives continuous function values and first derivatives
 * across grid points.
 *
 * Function values, first derivatives and the second order mixed derivative
 * can all be interpolated.
 *
 * The interpolator can be initialised either with callbacks for the function
 * and its derivatives, or with a callback for the function only, in which
 * case the necessary derivatives at the grid points are calculated from
 * finite differences of the function callback on the grid.
 *
 * The polynomial coefficients of a grid cell are computed the first time a
 * point inside that cell is evaluated and are cached afterwards, so the
 * callbacks are only invoked for cells that are actually used.
 *
 * The static methods of the class can be used directly, but the
 * standard way of using it is through:
 *
 * <pre>
 *  // x, y are double[] grid coordinates, f is a Function2D
 *  CubicInterpolation2DCallback ip = new CubicInterpolation2DCallback(x, y, f);
 *  ip.eval(3,4);
 *  ip.eval(new double[] {1,2,3 }, new double[] {2,4,8 });
 *  etc.
 *
 *  // Derivatives (dfdx,dfdy,d2fdxdy) can be interpolated through:
 *  double[] der = new double[3];
 *  ip.eval(3,4,der);
 *  System.out.println("dfdx(3,4)="+der[0]);
 *  System.out.println("dfdy(3,4)="+der[1]);
 *  System.out.println("d2fdxdy(3,4)="+der[2]);
 * </pre>
 *
 * Instances keep per-cell scratch arrays and an unsynchronised coefficient
 * cache as fields, so a single instance must not be evaluated from several
 * threads at once.
 *
 *  See http://www.algorithmrepository.org
 *
 */
public class CubicInterpolation2DCallback {
    /** Grid coordinates in the x- and y-direction, as given to the constructor. */
    double[] xp, yp;
    /** Callbacks for the function and (optionally, may be null) its derivatives. */
    Function2D fp, dfdx, dfdy, d2fdxdy;
    /** Number of grid points in each direction. */
    int numx, numy;
    /** First and last x-coordinate of the grid. */
    double minx, maxx;
    /** First and last y-coordinate of the grid. */
    double miny, maxy;
    /** Grid spacing, taken from the first two coordinates in each direction. */
    double dx, dy;
    /** Cache of 4x4 coefficient matrices per cell; an entry is null until the cell is first used. */
    double[][][][] coeffs;
    /** Scratch arrays holding the four corner values of the cell currently being set up. */
    double[] f_cell = new double[4], dfdx_cell=new double[4], dfdy_cell=new double[4], d2fdxdy_cell=new double[4];

    /**
     * A scalar function of two variables, used to supply function values
     * (or derivative values) at grid points on demand.
     */
    public interface Function2D {
        /**
         * @param x The x-coordinate
         * @param y The y-coordinate
         * @return The value at (x,y)
         */
        double eval(double x, double y);
    }


    /**
     * Initialises bicubic interpolation for a uniform grid.
     *
     * Necessary derivatives are calculated using finite differences
     * of the function callback on the grid.
     *
     * @param xp The x-coordinates of the grid points (at least two).
     * @param yp The y-coordinates of the grid points (at least two).
     * @param fp A callback function that will be called if a new grid cell needs to
     * be precalculated by the interpolator.
     * @throws IllegalArgumentException if xp or yp has fewer than two points.
     */
    public CubicInterpolation2DCallback(double[] xp, double[] yp, Function2D fp) {
        this(xp,yp,fp,null,null,null);
    }


    /**
     * Initialises bicubic interpolation for a uniform grid when
     * the derivatives of the function are known.
     *
     * If dfdx is null all three derivative callbacks are ignored and finite
     * differences of fp are used instead.
     *
     * @param xp The x-coordinates of the grid points (at least two).
     * @param yp The y-coordinates of the grid points (at least two).
     * @param fp A callback function that will be called if a new grid cell needs to
     * be precalculated by the interpolator.
     * @param dfdx dfdx function, or null to use finite differences.
     * @param dfdy  dfdy function; required when dfdx is given.
     * @param d2fdxdy  d2fdxdy function; required when dfdx is given.
     * @throws IllegalArgumentException if xp or yp has fewer than two points, or if
     * dfdx is given without both dfdy and d2fdxdy.
     */
    public CubicInterpolation2DCallback(double[] xp, double[] yp, Function2D fp, Function2D dfdx, Function2D dfdy, Function2D d2fdxdy) {
        // The spacing is read from the first two coordinates, so two points are the minimum.
        if (xp.length < 2 || yp.length < 2) {
            throw new IllegalArgumentException("Grid needs at least two points in each direction, got "
                    + xp.length + " x-points and " + yp.length + " y-points");
        }
        // eval() switches to the derivative callbacks as soon as dfdx is non-null,
        // so a partial set would only fail later with a NullPointerException.
        if (dfdx != null && (dfdy == null || d2fdxdy == null)) {
            throw new IllegalArgumentException("dfdy and d2fdxdy must be given when dfdx is given");
        }

        this.xp = xp;
        this.yp = yp;
        this.fp = fp;
        this.dfdx = dfdx;
        this.dfdy = dfdy;
        this.d2fdxdy = d2fdxdy;

        this.numx = xp.length;
        this.numy = yp.length;

        this.minx = xp[0];
        this.maxx = xp[numx-1];
        this.miny = yp[0];
        this.maxy = yp[numy-1];

        this.dx = xp[1]-xp[0];
        this.dy = yp[1]-yp[0];

        // One (lazily filled) coefficient matrix per grid cell.
        this.coeffs = new double[numx-1][numy-1][][];
    }


    /**
     * Single bicubic interpolation. The bicubic interpolation
     * uses for each point in a 2x2 grid: the function value,
     * the partial derivative of the function in the x-direction,
     * the partial derivative in the y-direction, the mixed second order
     * derivative d2f/dxdy.
     *
     * @param x1 left x-coordinate
     * @param x2 right x-coordinate
     * @param y1 lower y-coordinate
     * @param y2 upper y-coordinate
     * @param fp The four function values corresponding to the given
     * grid points. The order is anticlockwise: lower left, lower right,
     * upper right, upper left.
     * @param dfdx The four first derivatives in the x-direction corresponding
     * to the given grid points. The order is anticlockwise: lower left, lower right,
     * upper right, upper left.
     * @param dfdy The four first derivatives in the y-direction corresponding
     * to the given grid points. The order is anticlockwise: lower left, lower right,
     * upper right, upper left.
     * @param d2fdxdy The four mixed derivatives corresponding
     * to the given grid points. The order is anticlockwise: lower left, lower right,
     * upper right, upper left.
     * @param x The x-coordinate of the point that should be interpolated.
     * @param y The y-coordinate of the point that should be interpolated.
     * @param der An array of derivatives of length 3 that has to be allocated before the
     * call if used:
     * der[0] = dfdx
     * der[1] = dfdy
     * der[2] = d2fdxdy
     *
     * @return The interpolated value of the function at the given point.
     */
    public static double eval(double x1, double x2, double y1, double y2, double[] fp, double[] dfdx, double[] dfdy, double[] d2fdxdy, double x, double y, double[] der) {
        double cellWidth = x2-x1;
        double cellHeight = y2-y1;

        // Normalised coordinates of the point within the cell.
        double u = (x-x1)/cellWidth;
        double v = (y-y1)/cellHeight;

        double[][] c = cubicInterpolationCoeffs2D(cellWidth, cellHeight, fp, dfdx, dfdy, d2fdxdy);
        return CubicInterpolation2D.interp(u,v,cellWidth,cellHeight,c,der);
    }


    /**
     *
     * Does a bicubic interpolation on uniform grid given in the constructor.
     *
     * Extrapolation: Doesn't do extrapolation, returns -Inf if outside grid.
     *
     * @param x The x-coordinates where the function should be interpolated
     * @param y The y-coordinates where the function should be interpolated
     *
     * @return The interpolated values of the function.
     */
    public double[] eval(double[] x, double[] y) {
        return eval(x,y,null);
    }


    /**
     *
     * Does a bicubic interpolation on uniform grid given in the constructor.
     *
     * Extrapolation: Doesn't do extrapolation, returns -Inf if outside grid.
     *
     * @param x The x-coordinate where the function should be interpolated
     * @param y The y-coordinate where the function should be interpolated
     *
     * @return The interpolated value of the function.
     */
    public double eval(double x, double y) {
        return eval(x,y,null);
    }

    /**
     *
     * Does a bicubic interpolation on uniform grid given in the constructor.
     *
     * The coefficients of the cell containing the point are computed on first
     * use (invoking the callbacks) and cached for later calls.
     *
     * Extrapolation: Doesn't do extrapolation, returns -Inf if outside grid.
     *
     * @param x The x-coordinate where the function should be interpolated
     * @param y The y-coordinate where the function should be interpolated
     * @param der On return der contains dfdx, dfdy, d2fdxdy for the interpolation point.
     * If this functionality is sought, der has to be allocated by the caller to der[3].
     *
     * @return The interpolated value of the function.
     */
    public double eval(double x, double y, double[] der) {
        int ix = (int) FastMath.floor((numx-1)*(x-minx)/(maxx-minx));
        int iy = (int) FastMath.floor((numy-1)*(y-miny)/(maxy-miny));

        // A point exactly on the last grid line belongs to the last cell. The index
        // is assigned rather than decremented so that a rounding error in the
        // floating point index formula above cannot select the cell before it.
        if (x == maxx) ix = numx-2;
        if (y == maxy) iy = numy-2;

        // Outside the grid on any side: no extrapolation.
        if (ix < 0 || ix >= numx-1 || iy < 0 || iy >= numy-1) {
            return Double.NEGATIVE_INFINITY;
        }

        double[][] c = coeffs[ix][iy];
        if (c == null) { // First visit to this cell: calculate and cache its coefficients
            if (dfdx == null) {
                // No derivative callbacks: finite differences of the function callback.
                gridCellDerivatives(xp, yp, fp, ix, iy, f_cell, dfdx_cell, dfdy_cell, d2fdxdy_cell);
            }
            else {
                double xc = minx+dx*ix;
                double yc = miny+dy*iy;

                cornerValues(fp, xc, yc, f_cell);
                cornerValues(dfdx, xc, yc, dfdx_cell);
                cornerValues(dfdy, xc, yc, dfdy_cell);
                cornerValues(d2fdxdy, xc, yc, d2fdxdy_cell);
            }

            c = cubicInterpolationCoeffs2D(dx,dy,f_cell,dfdx_cell,dfdy_cell,d2fdxdy_cell);
            coeffs[ix][iy] = c;
        }

        // Normalised coordinates of the point within the cell.
        double u = (x-xp[ix])/dx;
        double v = (y-yp[iy])/dy;

        return CubicInterpolation2D.interp(u,v,dx,dy,c,der);
    }


    /**
     * Evaluates a callback at the four corners of the cell whose lower left
     * corner is (xc,yc), in anticlockwise order: lower left, lower right,
     * upper right, upper left.
     */
    private void cornerValues(Function2D g, double xc, double yc, double[] out) {
        out[0] = g.eval(xc, yc);
        out[1] = g.eval(xc+dx, yc);
        out[2] = g.eval(xc+dx, yc+dy);
        out[3] = g.eval(xc, yc+dy);
    }


    /**
     *
     * Does a bicubic interpolation on uniform grid given in the constructor.
     *
     * Extrapolation: Doesn't do extrapolation, returns -Inf if outside grid.
     *
     * @param x The x-coordinates where the function should be interpolated
     * @param y The y-coordinates where the function should be interpolated
     * @param der On return der contains dfdx, dfdy, d2fdxdy for each interpolation point.
     * If this functionality is sought, der has to be allocated by the caller to der[numPoints][3],
     * where numPoints is the length of the x-vector above. May be null.
     *
     * @return The interpolated values of the function.
     */
    public double[] eval(double[] x, double[] y, double[][] der) {
        double[] ret = new double[x.length];

        for(int i=0;i<x.length;++i) {
            double[] pointDer = (der == null) ? null : der[i];
            ret[i] = eval(x[i],y[i],pointDer);
        }

        return ret;
    }

    /**
     * Precalculates coefficients necessary for cubic interpolation.
     * This is a convenience method that applies the inverted
     * coefficient matrix for the linear equation system in the 16 free
     * parameters of a 3rd degree polynomial in 2 dimensions, to
     * values of the function and its derivatives:
     *
     *    a*c=g
     *
     *    where g = (f00,f10,f11,f01,fu00,fu10,fu11,fu01,fv00,fv10,fv11,fv01,fuv00,fuv10,fuv11,fuv01)'
     *    fij is function value at the corner point of the cell with u=i, v=j
     *    fuij is dfdu at that corner point (u is normalised x within the cell)
     *    fvij is dfdv at that corner point (v is normalised y within the cell)
     *    fuvij is d2fdudv at that corner point
     *
     *    The coefficients c are then
     *
     *    c = ainv*g
     *
     *    where ainv is the inverse of a.
     *
     * In the return value from this method c is arranged as a 4x4 matrix so that
     * the function values and its derivatives can be calculated as:
     *
     * (u,v are normalised ([0,1]) x,y-coordinates within the cell)
     *
     * f(u,v) = sum_i( sum_j( c_ij*u^(i)*v^(j) ) )
     * dfdu(u,v) = sum_i( sum_j( i*c_ij*u^(i-1)*v^(j) ) )
     * dfdv(u,v) = sum_i( sum_j( j*c_ij*u^(i)*v^(j-1) ) )
     * d2fdudv(u,v) = sum_i( sum_j( i*j*c_ij*u^(i-1)*v^(j-1) ) )
     *
     * Remember: the derivatives passed in are with respect to (x,y) and are
     * scaled to (u,v) inside this method (dfdu = dx*dfdx etc.). To go from
     * derivatives in (u,v) back to derivatives in (x,y) do the following
     * transformation:
     *
     * dfdx = dfdu/dx
     * dfdy = dfdv/dy
     * d2fdxdy = d2fdudv/(dx*dy)
     *
     * @param dx Extent of grid cell in x-direction
     * @param dy Extent of grid cell in y-direction
     * @param fp The four function values corresponding to the given
     * grid points. The order is anticlockwise: lower left, lower right,
     * upper right, upper left.
     * @param dfdx The four first derivatives in the x-direction corresponding
     * to the given grid points. The order is anticlockwise: lower left, lower right,
     * upper right, upper left.
     * @param dfdy The four first derivatives in the y-direction corresponding
     * to the given grid points. The order is anticlockwise: lower left, lower right,
     * upper right, upper left.
     * @param d2fdxdy The four mixed derivatives corresponding
     * to the given grid points. The order is anticlockwise: lower left, lower right,
     * upper right, upper left.
     * @return A 4x4 matrix of values from which interpolated
     * values and interpolated derivatives can be calculated within the given cell by
     * the formulas above.
     *
     */
    public static double[][] cubicInterpolationCoeffs2D(double dx, double dy, double[] fp, double[] dfdx, double[] dfdy, double[] d2fdxdy) {
        double[][] c = new double[4][4];
        double dxdy = dx*dy;

        for(int i=0;i<4;++i) {
            for(int j=0;j<4;++j) {
                // Row of the inverse matrix belonging to coefficient c_ij.
                int[] row = ainv[i*4+j];
                double sum = 0;

                // Columns 0-3: function values.
                for(int k=0;k<4;++k) {
                    sum += row[k]*fp[k];
                }
                // Columns 4-7: x-derivatives, scaled to the normalised coordinate u.
                for(int k=0;k<4;++k) {
                    sum += row[4+k]*dfdx[k]*dx;
                }
                // Columns 8-11: y-derivatives, scaled to the normalised coordinate v.
                for(int k=0;k<4;++k) {
                    sum += row[8+k]*dfdy[k]*dy;
                }
                // Columns 12-15: mixed derivatives, scaled to (u,v).
                for(int k=0;k<4;++k) {
                    sum += row[12+k]*d2fdxdy[k]*dxdy;
                }

                c[i][j] = sum;
            }
        }
        return c;
    }

    /**
     * Inverse of the 16x16 coefficient matrix described in
     * cubicInterpolationCoeffs2D. Row i*4+j yields coefficient c_ij; the
     * columns are ordered as four function values, four x-derivatives, four
     * y-derivatives and four mixed derivatives, each in anticlockwise corner
     * order.
     */
    final static int[][] ainv = {
        { 1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 },
        { 0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0 },
        { -3,0,0,3,0,0,0,0,-2,0,0,-1,0,0,0,0 },
        { 2,0,0,-2,0,0,0,0,1,0,0,1,0,0,0,0 },
        { 0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0 },
        { 0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0 },
        { 0,0,0,0,-3,0,0,3,0,0,0,0,-2,0,0,-1 },
        { 0,0,0,0,2,0,0,-2,0,0,0,0,1,0,0,1 },
        { -3,3,0,0,-2,-1,0,0,0,0,0,0,0,0,0,0 },
        { 0,0,0,0,0,0,0,0,-3,3,0,0,-2,-1,0,0 },
        { 9,-9,9,-9,6,3,-3,-6,6,-6,-3,3,4,2,1,2 },
        { -6,6,-6,6,-4,-2,2,4,-3,3,3,-3,-2,-1,-1,-2 },
        { 2,-2,0,0,1,1,0,0,0,0,0,0,0,0,0,0 },
        { 0,0,0,0,0,0,0,0,2,-2,0,0,1,1,0,0 },
        { -6,6,-6,6,-3,-3,3,3,-4,4,2,-2,-2,-2,-1,-1 },
        { 4,-4,4,-4,2,2,-2,-2,2,-2,-2,2,1,1,1,1 } };



    /**
     * Calculates the function value, dfdx, dfdy and d2fdxdy on each corner
     * point of a grid cell in a uniform grid.
     *
     * Derivatives for internal points use central differences.
     * Derivatives for boundary points use forward and backward differences.
     *
     * The grid is assumed uniform: corner coordinates are derived from the
     * first coordinate and the spacing between the first two coordinates.
     *
     * @param xp The x-coordinates of the grid points
     * @param yp The y-coordinates of the grid points.
     * @param fp Callback giving the function value at a point (x,y).
     * @param ix Lower x-index of grid cell
     * @param iy Lower y-index of grid cell
     * @param fv On return contains the function values for the corner points,
     * in anticlockwise order starting at grid point (ix,iy)
     * @param dfdx On return contains dfdx for the corner points, same order
     * @param dfdy On return contains dfdy for the corner points, same order
     * @param d2fdxdy On return contains d2fdxdy for the corner points, same order
     */
    public static void gridCellDerivatives(double[] xp, double[] yp, Function2D fp, int ix, int iy, double[] fv, double[] dfdx, double[] dfdy, double[] d2fdxdy) {
        double[] der = new double[3];

        // Corner 0: lower left
        gridPointDerivatives(xp,yp,fp,ix,iy,der);
        double dx = xp[1]-xp[0];
        double dy = yp[1]-yp[0];
        double xc = xp[0]+dx*ix;
        double yc = yp[0]+dy*iy;
        fv[0] = fp.eval(xc, yc);
        dfdx[0] = der[0];
        dfdy[0] = der[1];
        d2fdxdy[0] = der[2];

        // Corner 1: lower right
        gridPointDerivatives(xp,yp,fp,ix+1,iy,der);
        fv[1] = fp.eval(xc+dx, yc);
        dfdx[1] = der[0];
        dfdy[1] = der[1];
        d2fdxdy[1] = der[2];

        // Corner 2: upper right
        gridPointDerivatives(xp,yp,fp,ix+1,iy+1,der);
        fv[2] = fp.eval(xc+dx, yc+dy);
        dfdx[2] = der[0];
        dfdy[2] = der[1];
        d2fdxdy[2] = der[2];

        // Corner 3: upper left
        gridPointDerivatives(xp,yp,fp,ix,iy+1,der);
        fv[3] = fp.eval(xc, yc+dy);
        dfdx[3] = der[0];
        dfdy[3] = der[1];
        d2fdxdy[3] = der[2];
    }


    /**
     * Calculates dfdx, dfdy, d2fdxdy for the specified grid point on
     * a uniform grid.
     *
     * Derivatives for internal points use central differences.
     * Derivatives for boundary points use forward and backward differences.
     * For the mixed derivative, a point on any edge of the grid uses
     * one-sided differences in both directions (backward in a direction only
     * when the point lies on the last grid line of that direction).
     *
     * @param xp The x-coordinates of the grid points
     * @param yp The y-coordinates of the grid points.
     * @param fp Callback giving the function value at a point (x,y).
     * @param ix x-index of grid point
     * @param iy y-index of grid point
     * @param der On return (must be allocated by caller) contains:
     * der[0] = dfdx
     * der[1] = dfdy
     * der[2] = d2fdxdy
     */
    public static void gridPointDerivatives(double[] xp, double[] yp, Function2D fp, int ix, int iy, double[] der) {
        int numx = xp.length;
        int numy = yp.length;
        double dx = xp[1]-xp[0];
        double dy = yp[1]-yp[0];

        // Coordinates of the grid point, derived from the uniform spacing.
        double xc = xp[0]+ix*dx;
        double yc = yp[0]+iy*dy;

        // dfdx: forward on the first grid line, backward on the last, central otherwise
        double dfdx;
        if (ix == 0) {
            dfdx = (fp.eval(xc+dx, yc) - fp.eval(xc, yc))/dx;
        }
        else if (ix == numx-1) {
            dfdx = (fp.eval(xc, yc) - fp.eval(xc-dx, yc))/dx;
        }
        else {
            dfdx = (fp.eval(xc+dx, yc) - fp.eval(xc-dx, yc))/(2*dx);
        }

        // dfdy: same scheme in the y-direction
        double dfdy;
        if (iy == 0) {
            dfdy = (fp.eval(xc, yc+dy) - fp.eval(xc, yc))/dy;
        }
        else if (iy == numy-1) {
            dfdy = (fp.eval(xc, yc) - fp.eval(xc, yc-dy))/dy;
        }
        else {
            dfdy = (fp.eval(xc, yc+dy) - fp.eval(xc, yc-dy))/(2*dy);
        }

        // d2fdxdy
        double d2fdxdy;
        boolean onEdge = ix == 0 || ix == numx-1 || iy == 0 || iy == numy-1;
        if (onEdge) {
            // One-sided 2x2 stencil; it reaches backwards only where the point
            // sits on the last grid line of the respective direction.
            boolean backwardX = ix != 0 && ix == numx-1;
            boolean backwardY = iy == numy-1;

            double xLow = backwardX ? xc-dx : xc;
            double xHigh = backwardX ? xc : xc+dx;
            double yLow = backwardY ? yc-dy : yc;
            double yHigh = backwardY ? yc : yc+dy;

            d2fdxdy = (fp.eval(xHigh, yHigh) - fp.eval(xLow, yHigh) - fp.eval(xHigh, yLow) + fp.eval(xLow, yLow))/(dx*dy);
        }
        else {
            // Central differences in both directions.
            d2fdxdy = (fp.eval(xc+dx, yc+dy) - fp.eval(xc-dx, yc+dy) - fp.eval(xc+dx, yc-dy) + fp.eval(xc-dx, yc-dy))/(4*dx*dy);
        }

        der[0] = dfdx;
        der[1] = dfdy;
        der[2] = d2fdxdy;
    }


    /**
     * Returns the vector of x-coordinates from which the grid is
     * defined.
     *
     * @return The vector of x-coordinates defining the grid division
     * in the x-direction (the array given to the constructor, not a copy).
     */
    public double[] getxp() {
        return xp;
    }


    /**
     * Returns the vector of y-coordinates from which the grid is
     * defined.
     *
     * @return The vector of y-coordinates defining the grid division
     * in the y-direction (the array given to the constructor, not a copy).
     */
    public double[] getyp() {
        return yp;
    }


    /**
     * Returns the grid values function.
     *
     * @return The callback function.
     */
    public Function2D getfp() {
        return fp;
    }
}
