package imseProc.proc.transform;

import otherSupport.RandomManager;
import binaryMatrixFile.BinaryMatrixFile;
import jafama.FastMath;
import imseProc.core.Img;
import oneLiners.OneLiners;
import seed.digeom.Function;
import seed.digeom.FunctionND;
import seed.digeom.IDomain;
import seed.digeom.RectangularDomain;
import seed.optimization.ConjugateGradientDirectionFR;
import seed.optimization.CoordinateDescentDirection;
import seed.optimization.CostFunctionDeltaCondition;
import seed.optimization.GoldenSection;
import seed.optimization.HookeAndJeeves;
import seed.optimization.IStoppingCondition;
import seed.optimization.LineSearchOptimizer;
import seed.optimization.MaxIterCondition;
import seed.optimization.NewtonsMethod1D;
import seed.optimization.StoppingCondition;
import seed.optimization.StoppingOr;
import algorithmrepository.Algorithms;
import algorithmrepository.CubicInterpolation2D;

/** Fits the cubic knot values of a transform between image XY (normalised 0-1) and some
 * other coordinate AB (lon/lat) to the transform's fit points.
 *
 * As a {@link FunctionND}, the parameter vector holds the X and Y knot values interleaved:
 * {@code p[2*(iLat*nLon + iLon)] = knotX[iLon][iLat]} and
 * {@code p[2*(iLat*nLon + iLon) + 1] = knotY[iLon][iLat]}.
 * Note that {@link #eval(double[])} writes its argument into the transform's knots.
 */
public class TransformFitter extends FunctionND {
	/** Range of knot X,Y values in interp when fitting, these can get quite high/low in the extrapolated areas */
	private static final double MIN_P = -10, MAX_P = +10;

	/** Number of Hooke-Jeeves refinement steps performed by a fit */
	private static final int N_FIT_ITERATIONS = 50;

	/** Number of samples per axis of the lon/lat grid written by {@link #dumpAll()} */
	private static final int DUMP_GRID_SIZE = 500;

	private final FeatureTransformCubic xform;

	/** sigma of fitting (the XY residuals are divided by this in the cost) */
	private double sigma = 0.001;

	/** Knot grid dimensions, captured from xform at construction */
	private final int nLon, nLat;

	public TransformFitter(FeatureTransformCubic xform){
		this.xform = xform;
		this.nLon = xform.getNKnotsLon();
		this.nLat = xform.getNKnotsLat();
	}

	/**
	 * Fits the cubic knot values of the transform to its fit points.
	 *
	 * The cubic knots are first initialised from the linear transform
	 * ({@code initCubicToLinear()}). If the knot grid is non-empty and the first knot
	 * value is not NaN, the knots are then optimised with Hooke-Jeeves. In either case the
	 * resulting knot values are written back into the transform.
	 */
	public void fit() {
		xform.initCubicToLinear();
		double p[] = getP();

		// NOTE(review): presumably a NaN knot means initCubicToLinear() could not
		// initialise the knots, in which case there is nothing sensible to fit from.
		if(p.length > 0 && !Double.isNaN(p[0]))
			p = doFit(p);

		setP(p);
	}

	/**
	 * Debug dump. Samples the linear (and, if valid, the cubic) lat/lon -> XY mapping on a
	 * DUMP_GRID_SIZE x DUMP_GRID_SIZE grid spanning the transform's lon/lat range. Writes the
	 * sampled X and Y, plus the cubic knot values, as binary matrix files under /tmp.
	 */
	public void dumpAll(){
		final int n = DUMP_GRID_SIZE;
		double minA = xform.getMinLon();
		double maxA = xform.getMaxLon(); // was getMinLon(): grid collapsed to a single lon
		double minB = xform.getMinLat();
		double maxB = xform.getMaxLat(); // was getMinLat(): grid collapsed to a single lat

		double mg[][][] = Algorithms.meshgrid(OneLiners.linSpace(minA, maxA, n), OneLiners.linSpace(minB, maxB, n));
		double aaL[] = OneLiners.flatten(mg[0]);
		double bbL[] = OneLiners.flatten(mg[1]);

		double linXY[][] = xform.latlonToXYLinear(aaL, bbL);

		BinaryMatrixFile.mustWrite("/tmp/Xl.bin", OneLiners.unflatten(linXY[0],n,n), false);
		BinaryMatrixFile.mustWrite("/tmp/Yl.bin", OneLiners.unflatten(linXY[1],n,n), false);

		if(xform.isCubicValid()){
			double xyL[][] = xform.latlonToXY(aaL, bbL);
			BinaryMatrixFile.mustWrite("/tmp/X.bin", OneLiners.unflatten(xyL[0],n,n), false);
			BinaryMatrixFile.mustWrite("/tmp/Y.bin", OneLiners.unflatten(xyL[1],n,n), false);
			BinaryMatrixFile.mustWrite("/tmp/Xk.bin", xform.getKnotValsX(), false);
			BinaryMatrixFile.mustWrite("/tmp/Yk.bin", xform.getKnotValsY(), false);
		}
	}

	/**
	 * Runs N_FIT_ITERATIONS Hooke-Jeeves refinement steps starting from initP.
	 *
	 * @param initP initial (interleaved) knot parameter vector; must be non-empty
	 * @return a copy of the optimiser's final position
	 */
	private double[] doFit(double initP[]){
		HookeAndJeeves opt = new HookeAndJeeves(this);
		opt.setObjectiveFunction(this);
		// hand the optimiser its own copy so initP still holds the start point for the report below
		opt.init(initP.clone());

		// parameter traced in the progress log; clamped so tiny knot grids cannot overrun
		int iTrace = Math.min(1*nLon+1, initP.length - 1);
		for(int i=0; i < N_FIT_ITERATIONS; i++){
			opt.refine();

			double p[] = opt.getCurrentPos();
			double cost = opt.getCurrentValue();
			System.out.println("i=" + i + "\tp22 = "+p[iTrace]+"\tcost=" + cost);
		}

		double finalP[] = opt.getCurrentPos().clone();
		// eval() writes its argument into xform, so finalP is evaluated last
		System.out.println("Cost: init = " + eval(initP) + "\tFinal = " + eval(finalP));

		return finalP;
	}

	@Override
	public IDomain getDomain() { return new RectangularDomain(OneLiners.fillArray(MIN_P, nLon*nLat*2), OneLiners.fillArray(MAX_P, nLon*nLat*2)); }

	/** Index in the parameter vector of the X knot value at (iLon, iLat); Y follows at +1. */
	private int paramIndex(int iLon, int iLat){
		return 2*(iLat*nLon + iLon);
	}

	/** Packs the transform's current X/Y knot values into an interleaved parameter vector. */
	private double[] getP(){
		double gPX[][] = xform.getKnotValsX();
		double gPY[][] = xform.getKnotValsY();
		double p[] = new double[nLon*nLat*2];

		for(int iB=0; iB < nLat; iB++){
			for(int iA=0; iA < nLon; iA++){
				int k = paramIndex(iA, iB);
				p[k] = gPX[iA][iB];
				p[k+1] = gPY[iA][iB];
			}
		}

		return p;
	}

	/** Unpacks an interleaved parameter vector into the transform's X/Y knot values. */
	private void setP(double p[]){
		double gPX[][] = xform.getKnotValsX();
		double gPY[][] = xform.getKnotValsY();

		for(int iB=0; iB < nLat; iB++){
			for(int iA=0; iA < nLon; iA++){
				int k = paramIndex(iA, iB);
				gPX[iA][iB] = p[k];
				gPY[iA][iB] = p[k+1];
			}
		}

		xform.setKnotValsX(gPX);
		xform.setKnotValsY(gPY);
	}

	/**
	 * Cost of parameter vector x. Writes x into the transform's knots, then sums
	 * 0.5*((imgX - fitX)/sigma)^2 + 0.5*((imgY - fitY)/sigma)^2 over all points flagged
	 * {@code includeInFit()}. This equals a Gaussian negative log-likelihood with
	 * std-dev sigma, with the normalisation term dropped.
	 */
	@Override
	public double eval(double[] x) {
		setP(x);

		double cost = 0;
		for(FeatureTransform.Point p : xform.getPoints()){
			if(!p.includeInFit())
				continue;

			double fitXY[] = xform.latlonToXY(p.lon, p.lat);
			cost += 0.5 * FastMath.pow2((p.imgX - fitXY[0])/sigma);
			cost += 0.5 * FastMath.pow2((p.imgY - fitXY[1])/sigma);
		}

		return cost;
	}
}
