package imseProc.proc.imgFit;

import jafama.FastMath;
import imseProc.core.Img;
import oneLiners.OneLiners;
import seed.digeom.FunctionND;
import seed.digeom.IDomain;
import seed.digeom.RectangularDomain;
import seed.optimization.ConjugateGradientDirectionFR;
import seed.optimization.CoordinateDescentDirection;
import seed.optimization.GoldenSection;
import seed.optimization.HookeAndJeeves;
import seed.optimization.LineSearchOptimizer;
import seed.optimization.MaxIterCondition;
import seed.optimization.NewtonsMethod1D;
import algorithmrepository.CubicInterpolation2D;

/**
 * Fits a 2D cubic interpolating surface to an image, using only the pixels selected by a boolean mask.
 *
 * <p>The surface is a {@link CubicInterpolation2D} over the unit square with an {@code nPX x nPY}
 * grid of evenly spaced knots (see {@link #setSize}). Image pixel (iX, iY) is compared with the
 * surface at (iX / imgWidth, iY / imgHeight). The knot values are the free parameters, flattened
 * as {@code p[y*nPX + x]}. They are optimized with Hooke and Jeeves pattern search against the
 * cost returned by {@link #eval}.
 *
 * <p>Usage: {@link #setSize} (required before {@link #fit}), {@link #setMask} (required before a
 * fit with {@code doFit == true}), {@link #fit}, then {@link #evalFittedImage} or {@link #getKnotVals}.
 */
public class CubicMaskedImageFitter extends FunctionND {

	/** Only every SAMPLE_STRIDE-th pixel in each direction contributes to the cost, to keep eval() cheap. */
	private static final int SAMPLE_STRIDE = 3;
	/** Width of the knot prior in eval(), as a multiple of the per-pixel sigma. */
	private static final double PRIOR_SIGMA_FACTOR = 20;
	/** Number of Hooke and Jeeves refine() steps per fit (there is no convergence test). */
	private static final int FIT_ITERATIONS = 50;

	private CubicInterpolation2D interp;
	/** Pixel selection mask, indexed [row][column]; true means the pixel takes part in the fit. */
	private boolean mask[][];
	/** Bounds applied to every knot value (see getDomain()). */
	private double minP, maxP;
	/** Dimensions of the most recently fitted image. */
	private int imgWidth, imgHeight;
	/** Knot grid size; -1 until setSize() is called. */
	private int nPX = -1, nPY = -1;
	/** Image being fitted; only non-null while fit() is running. */
	private Img targetImage;
	/** Per-pixel tolerance of the data term in eval(). */
	private double sigma;
	/** Mean of the sampled valid target pixels, centre of the knot prior; NaN when there are none. */
	private double priorMean = Double.NaN;

	/** Creates an unconfigured fitter; call setSize() before fit(). */
	public CubicMaskedImageFitter() { }

	/**
	 * Sets the knot grid size. When the size changes, a new interpolator is built with knots evenly
	 * spaced on [0,1] in each direction, and every knot value is reset to 1. Calling it again with
	 * the current size keeps the existing knot values.
	 *
	 * @param nPX number of knots along x (image width direction)
	 * @param nPY number of knots along y (image height direction)
	 * @throws IllegalArgumentException if either count is less than 1
	 */
	public void setSize(int nPX, int nPY) {
		if(nPX < 1 || nPY < 1)
			throw new IllegalArgumentException("CubicMaskedImageFitter: knot grid must be at least 1x1, got " + nPX + "x" + nPY);

		if(nPX != this.nPX || nPY != this.nPY){
			this.nPX = nPX;
			this.nPY = nPY;
			double x[] = OneLiners.linSpace(0, 1, nPX);
			double y[] = OneLiners.linSpace(0, 1, nPY);
			interp = new CubicInterpolation2D(x, y, new double[nPX][nPY]);
			setP(OneLiners.fillArray(1, nPX*nPY));
		}
	}

	/**
	 * Returns the interpolator's knot values, indexed [x][y].
	 * NOTE(review): this is whatever CubicInterpolation2D.getfp() returns. It is presumably the
	 * interpolator's own array, so callers should not modify it; confirm.
	 */
	public double[][] getKnotVals(){ return interp.getfp(); }

	/**
	 * Fits the surface to an image and leaves the result in the interpolator.
	 *
	 * <p>Knot values are bounded to [min - range, max + range] of the image, and the per-pixel
	 * sigma is range / 10. If the starting knots are NaN (initToAverage() found no valid pixels),
	 * the optimization is skipped and the NaN knots are applied.
	 * NOTE(review): a constant image gives range == 0 and so sigma == 0, which makes the cost
	 * NaN/infinite. Confirm whether callers can pass such images.
	 *
	 * @param imageIn image to fit; {@code null} is silently ignored
	 * @param init if true, start from local averages of the image; otherwise start from the current knot values
	 * @param doFit if true, run the optimizer; if false, only the starting knots are applied
	 * @throws IllegalStateException if a fit is run but no non-empty mask has been set with setMask()
	 */
	public void fit(Img imageIn, boolean init, boolean doFit) {
		if(imageIn == null)
			return;

		imgWidth = imageIn.getWidth();
		imgHeight = imageIn.getHeight();

		//set max param values 100% of range above/below min/max
		double range = imageIn.getMax() - imageIn.getMin();
		minP = imageIn.getMin() - range;
		maxP = imageIn.getMax() + range;
		sigma = range / 10;

		this.targetImage = imageIn;
		try {
			double p[] = init ? initToAverage() : getP();

			if(!Double.isNaN(p[0]) && doFit){
				if(mask == null || mask.length == 0)
					throw new IllegalStateException("CubicMaskedImageFitter: setMask() must be called with a non-empty mask before fitting");
				// The prior centre does not depend on the knots, so compute it once per fit rather than in every eval().
				priorMean = meanOfValidPixels(SAMPLE_STRIDE);
				p = doFit(p);
			}

			setP(p);
		} finally {
			// eval()/initToAverage() are only valid during fit(); never keep a reference to the image afterwards.
			this.targetImage = null;
		}
	}

	/**
	 * Builds starting knot values from the target image. Each knot gets the mean of the non-NaN
	 * pixels in a box around its position (half a knot spacing either side, plus half a pixel,
	 * clipped to the image). If that box has no valid pixels, the whole-image mean is used instead.
	 *
	 * @return knot values flattened as p[y*nPX + x]; all NaN if the image has no valid pixels
	 */
	private double[] initToAverage(){
		double fullAvg = meanOfValidPixels(1);
		if(Double.isNaN(fullAvg)){
			System.err.println("CubicMaskedImageFit: Entire input image is NaN");
			return OneLiners.fillArray(Double.NaN, nPX * nPY);
		}

		double p[] = new double[nPX * nPY];
		boolean noData = true;
		for(int pY = 0; pY < nPY; pY++) {
			// With a single knot the division by zero yields +/-Infinity, which the clamping turns into the full image.
			int iY0 = Math.max(0, (int)(((double)pY - 0.5) / (nPY-1) * imgHeight - 0.5));
			int iY1 = Math.min(imgHeight - 1, (int)(((double)pY + 0.5) / (nPY-1) * imgHeight + 0.5));

			for(int pX = 0; pX < nPX; pX++) {
				int iX0 = Math.max(0, (int)(((double)pX - 0.5) / (nPX-1) * imgWidth - 0.5));
				int iX1 = Math.min(imgWidth - 1, (int)(((double)pX + 0.5) / (nPX-1) * imgWidth + 0.5));

				double sum = 0;
				int n = 0;
				for(int iY=iY0; iY <= iY1; iY++){
					for(int iX=iX0; iX <= iX1; iX++){
						double val = targetImage.getPixelValue(iX, iY);
						if(!Double.isNaN(val)){
							sum += val;
							n++;
						}
					}
				}

				if(n > 0){
					noData = false;
					p[pY*nPX + pX] = sum / n;
				}else{
					p[pY*nPX + pX] = fullAvg;
				}
			}
		}

		return noData ? OneLiners.fillArray(Double.NaN, p.length) : p;
	}

	/**
	 * Mean of the non-NaN target pixels on a grid that takes every {@code stride}-th pixel in each
	 * direction, starting at (0, 0).
	 *
	 * @return the mean, or NaN if no sampled pixel is valid
	 */
	private double meanOfValidPixels(int stride){
		double sum = 0;
		int n = 0;
		for(int iY=0; iY < imgHeight; iY += stride){
			for(int iX=0; iX < imgWidth; iX += stride){
				double val = targetImage.getPixelValue(iX, iY);
				if(!Double.isNaN(val)){
					sum += val;
					n++;
				}
			}
		}
		return n > 0 ? sum / n : Double.NaN;
	}

	/**
	 * Runs FIT_ITERATIONS Hooke and Jeeves steps from initP and returns the final position.
	 * Requires the state set up by fit(). Prints the initial and final cost to stdout. Each eval()
	 * call loads its argument into the interpolator, and the final position is evaluated last.
	 */
	private double[] doFit(double initP[]){
		// Conjugate-gradient / coordinate-descent line-search optimizers from seed.optimization
		// were previously tried here; pattern search is the one in use.
		HookeAndJeeves opt = new HookeAndJeeves(this);
		opt.setObjectiveFunction(this);
		opt.init(initP);

		for(int i=0; i < FIT_ITERATIONS; i++)
			opt.refine();

		double finalP[] = opt.getCurrentPos();
		System.out.println("Cost: init = " + eval(initP) + "\tFinal = " + eval(finalP));
		return finalP;
	}

	/**
	 * Sets the pixel mask, indexed [row][column]; true marks pixels used in the fit. The mask need
	 * not match the image size: each pixel is mapped to the mask cell containing its centre.
	 */
	public void setMask(boolean[][] mask) { this.mask = mask; }

	/** Box domain: every knot value lies in [minP, maxP] as set by the last fit(). */
	@Override
	public IDomain getDomain() { return new RectangularDomain(OneLiners.fillArray(minP, nPX*nPY), OneLiners.fillArray(maxP, nPX*nPY)); }

	/** Evaluates the fitted surface at pixel (iX, iY), using the image size from the most recent fit(). */
	public double evalFittedImage(int iX, int iY) {
		return interp.eval((double)iX / imgWidth, (double)iY / imgHeight);
	}

	/** Copies the interpolator's knot values ([x][y]) into a flat vector p[y*nPX + x]. */
	private double[] getP(){
		double gP[][] = interp.getfp();
		double p[] = new double[nPX*nPY];

		for(int y=0; y < nPY; y++)
			for(int x=0; x < nPX; x++)
				p[y*nPX + x] = gP[x][y];

		return p;
	}

	/** Writes the flat vector p[y*nPX + x] into the interpolator's knot array ([x][y]) and passes it back via setfp(). */
	private void setP(double p[]){
		double gP[][] = interp.getfp();

		for(int y=0; y < nPY; y++)
			for(int x=0; x < nPX; x++)
				gP[x][y] = p[y*nPX + x];

		interp.setfp(gP);
	}

	/**
	 * Cost of the knot values x (lower is better). As a side effect, x is loaded into the interpolator.
	 *
	 * <p>The cost is the sum of two terms:
	 * <ul>
	 * <li>data: 0.5 * ((target - fit) / sigma)^2, over every SAMPLE_STRIDE-th pixel in each
	 * direction that is selected by the mask and is not NaN;</li>
	 * <li>prior: 0.5 * ((x[i] - priorMean) / (PRIOR_SIGMA_FACTOR * sigma))^2 over all knots. This
	 * is a wide penalty pulling knots towards the sampled image mean. It is skipped when
	 * priorMean is NaN (no valid sampled pixels), so the cost does not become NaN.</li>
	 * </ul>
	 * Only meaningful while fit() is running: it reads targetImage, mask, sigma and priorMean.
	 *
	 * @param x knot values flattened as x[y*nPX + px]
	 */
	@Override
	public double eval(double[] x) {
		setP(x);

		int maskH = mask.length;
		int maskW = mask[0].length;
		double cost = 0;

		for(int iY=0; iY < imgHeight; iY += SAMPLE_STRIDE){
			// Map the pixel to the mask cell containing its centre, so masks of another resolution line up.
			int mY = (int)((iY + 0.5) * maskH / imgHeight);
			if(mY >= maskH)
				continue;

			for(int iX=0; iX < imgWidth; iX += SAMPLE_STRIDE){
				int mX = (int)((iX + 0.5) * maskW / imgWidth);
				if(mX >= maskW || !mask[mY][mX])
					continue;

				double targVal = targetImage.getPixelValue(iX, iY);
				if(Double.isNaN(targVal))
					continue;

				double fitVal = interp.eval((double)iX / imgWidth, (double)iY / imgHeight);
				cost += 0.5 * FastMath.pow2((targVal - fitVal) / sigma);
			}
		}

		if(!Double.isNaN(priorMean)){
			double priorSigma = PRIOR_SIGMA_FACTOR * sigma;
			for(int i=0; i < nPX*nPY; i++)
				cost += 0.5 * FastMath.pow2((x[i] - priorMean) / priorSigma);
		}

		return cost;
	}
}
