package imseProc.graph.shapeFit;

import imseProc.graph.Series;
import jafama.FastMath;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import algorithmrepository.Algorithms;
import oneLiners.OneLiners;
import seed.digeom.FunctionND;
import seed.digeom.IDomain;
import seed.digeom.IFunction;
import seed.digeom.RectangularDomain;
import seed.optimization.BracketingByParameterSpace;
import seed.optimization.ConjugateGradientDirectionFR;
import seed.optimization.CoordinateDescentDirection;
import seed.optimization.GoldenSection;
import seed.optimization.HookeAndJeeves;
import seed.optimization.LineSearchOptimizer;
import seed.optimization.MaxIterCondition;
import seed.optimization.NewtonsMethod1D;
import edu.emory.mathcs.jtransforms.fft.DoubleFFT_1D;

/** We're looking for something roughly triangular between 0 and +45,
 so start by doing a Schmitt-trigger-type search for the ups and downs,
 then use those as windows for min/max finding, and use the extrema
 to find the period and offset of the triangle. */
public class FuncFitter {
	/** Human-readable names of the function types, indexed by the FUNC_* constants below. */
	public static final String funcTypeNames[] = new String[]{ "None", "Triangle", "Sawtooth", "Sine", "|Sine|", "Quadratic", "Sine+Gradient" };
	
	public static final int FUNC_NONE = 0;
	public static final int FUNC_TRIANGLE = 1;
	public static final int FUNC_SAWTOOTH = 2;
	public static final int FUNC_SINE = 3;
	public static final int FUNC_ABSSINE = 4;
	public static final int FUNC_QUADRATIC = 5;
	public static final int FUNC_SINESHIFT = 6;
	
	/** Number of Hooke and Jeeves refinement steps run by {@link #fitProc}. */
	private static final int N_FIT_ITERS = 200;
	
	/** Selected function type, one of the FUNC_* constants. */
	private int funcType;
	/** If true, {@link #doFit} refits even when an existing fit of the same length is present. */
	private boolean autoFit = false;
	/** If true, {@link #doFit} discards any supplied initial parameters and re-estimates them. */
	private boolean estimateInitP = false;
	
	/** X values and data of the most recently fitted series. */
	private double x[];
	private double data[];
	/** Fitted curve evaluated at {@link #x}, or null when there is no valid fit. */
	private double finalFit[];
	
	/** Per-parameter enable flags passed to CostF; resized (or created) to match getParamNames(). */
	private boolean paramEnable[];
	/** Initial parameters; supplied via setInitParams() or estimated in doFit(). */
	private double initP[];
	/** Parameters found by the last completed fit. */
	private double finalP[];
	
	/** Text summary of the last fit, see {@link #getInfoTxt}. */
	private String infoTxt;
	
	/** Auxiliary series for display ("init", "final", "diff", and any debug traces). */
	private ArrayList<Series> seriesList = new ArrayList<Series>();
	
	/**
	 * Fits the selected function type to the given series.
	 * <p>
	 * If a fit already exists for a series of the same length and auto-fit is off,
	 * this returns immediately. Otherwise the initial parameters (caller supplied, or
	 * a rough guess from the data), the parameter bounds and the enable mask are
	 * prepared, the optimiser is run and the info text is refreshed.
	 *
	 * @param seriesIn series providing the {@code x} and {@code data} arrays to fit
	 */
	public void doFit(Series seriesIn) {
		// Compare the incoming series length with the previously fitted one.
		// (This used to read this.x twice, so the lengths always matched.)
		int oldLen = (this.x == null) ? -1 : this.x.length;
		int newLen = (seriesIn.x == null) ? -1 : seriesIn.x.length;
		if(finalFit != null && newLen == oldLen && !autoFit)
			return;
		
		this.x = seriesIn.x;
		this.data = seriesIn.data;
		this.finalFit = null; //invalidate
		
		seriesList.clear();
		
		// Nothing to fit; the periodic initial guess below would also index x[] out of bounds
		if(x == null || data == null || x.length == 0 || data.length == 0){
			setInfoText();
			return;
		}
		
		double rangeX[] = OneLiners.getRange(x); // X bounds must come from x, not data
		double rangeD[] = OneLiners.getRange(data);
		
		if(estimateInitP)
			initP = null;
		
		double min[], max[];
		switch(funcType){
			case FUNC_TRIANGLE:
			case FUNC_SAWTOOTH:
				// schmidtProc() could supply the initial guess for these, but it is disabled:
				// fall through to the generic periodic guess and bounds.
			case FUNC_SINE:
			case FUNC_SINESHIFT:
			case FUNC_ABSSINE:
				if(initP == null){
					// fftProc() would give a better estimate, but it SIGSEGVs the OpenJDK via SWT later on.
					initP = new double[]{
							x[x.length / 2], //wavelength
							0, //phase
							(rangeD[1] - rangeD[0]) / 2, //amplitude
							(rangeD[1] + rangeD[0]) / 2, //offset
						};
				}
				
				min = new double[]{
						0, //wavelength
						-Math.PI/5, //phase
						-3 * (rangeD[1] - rangeD[0]), //amplitude
						-5*rangeD[1]/2 + 7*rangeD[0]/2, //offset: data min - 2.5 * data range
				};
				
				max = new double[]{
						2*(rangeX[1]-rangeX[0]), //wavelength: up to twice the x span
						Math.PI/5, //phase
						3 * (rangeD[1] - rangeD[0]), //amplitude
						7*rangeD[1]/2 + -5*rangeD[0]/2, //offset: data max + 2.5 * data range
				};
				
				if(funcType == FUNC_SINESHIFT){
					// Gradient parameters: initP is zero-padded by the resize below.
					// Bounds are [-1, +1] (the upper bound was previously also -1, an empty range).
					min = Arrays.copyOf(min, 6);
					min[4] = -1;
					min[5] = -1;
					max = Arrays.copyOf(max, 6);
					max[4] = 1;
					max[5] = 1;
				}
				break;
				
			case FUNC_QUADRATIC:
				if(initP == null)
					initP = new double[]{
							(rangeX[1] + rangeX[0]) / 2, //x0
							(rangeD[1] + rangeD[0]) / 2, //O(1)
							0, //O(x)
							0.0000 //O(x^2)
						};
				
				min = new double[]{ rangeX[0], -10, -10, -10 };
				max = new double[]{ rangeX[1],  10,  10,  10 };
				break;
				
			default:
				setInfoText();
				return;
		}
		
		// Make sure all per-parameter arrays have one entry per parameter of this function type
		int nParams = getParamNames().length;
		if(initP.length != nParams)
			initP = Arrays.copyOf(initP, nParams);
		if(min.length != nParams)
			min = Arrays.copyOf(min, nParams);
		if(max.length != nParams)
			max = Arrays.copyOf(max, nParams);
		if(paramEnable == null){
			// NOTE(review): assumes true = "fit this parameter"; with no mask set, fit everything
			paramEnable = new boolean[nParams];
			Arrays.fill(paramEnable, true);
		}else if(paramEnable.length != nParams){
			paramEnable = Arrays.copyOf(paramEnable, nParams);
		}
		
		fitProc(initP, min, max);
		
		setInfoText();
	}
	
	/**
	 * FFT based estimate of wavelength and phase.
	 * <p>
	 * Currently unused (see doFit()). As side effects it stores a 0..45 triangle
	 * wave built from the estimate in finalFit, and adds the rescaled absolute
	 * spectrum to seriesList.
	 *
	 * @return {wavelength, phase, 0 (amplitude, filled later), 0 (offset, filled later)}
	 */
	private double[] fftProc() {
		int n = x.length;
		
		double abs[] = null;
		double meanX = 0, meanPhs = 0;
		int maxIdx = -1;
		
		// Pass 0 uses all the data; pass 1 truncates it to a whole number of estimated periods
		for(int k=0; k < 2; k++){
			double fft[];
			if(k == 0){
				fft = data.clone();
			}else{
				int l = (int)(n / meanX);
				if(l <= 0) // no usable period estimate (n / l would divide by zero)
					break;
				int nL = n / l;
				n = l * nL;
				if(n == 0)
					break;
				fft = Arrays.copyOf(data, n);
			}
			
			// No DC removal or windowing: with a live camera feed the discontinuity moves anyway
			DoubleFFT_1D fft1D = new DoubleFFT_1D(n);
			fft1D.realForward(fft);
			
			// squared magnitude of each bin in the first half; remember the highest peak
			maxIdx = -1;
			abs = new double[data.length];
			for(int i=0; i < n/2; i++){
				abs[i] = (fft[i*2]*fft[i*2] + fft[i*2+1]*fft[i*2+1]);
				if(maxIdx < 0 || abs[i] > abs[maxIdx])
					maxIdx = i;
			}
			
			// roughly find the HWHM of the peak
			int hwhm;
			for(hwhm = 0; hwhm <= Math.min(maxIdx, (n/2)-maxIdx); hwhm++){
				double val = (abs[maxIdx-hwhm] + abs[maxIdx+hwhm])/2;
				if(val < abs[maxIdx]/2)
					break;
			}
			
			// abs-weighted mean position (better estimate of the peak centre) within
			// +/- 3 HWHM of the peak, and the abs^2-weighted mean phase
			meanX = 0;
			double sum = 0;
			double sumImag = 0;
			double sumReal = 0;
			for(int i=Math.max(1, maxIdx-3*hwhm); i < Math.min(n/2, maxIdx+3*hwhm); i++){
				meanX += abs[i] * i;
				sum += abs[i];
				sumReal += fft[i*2] * abs[i]*abs[i];
				sumImag += fft[i*2+1] * abs[i]*abs[i];
			}
			meanX /= sum;
			meanPhs = FastMath.atan2(sumImag, sumReal) + Math.PI/2;
		}
		
		double l = n / meanX; // wavelength in samples
		double offset = (meanPhs + Math.PI/2)*l/2/Math.PI;
		
		// target triangular wave (0..45) with the estimated period and offset
		double tri[] = new double[data.length];
		for(int i=0; i < n; i++)
			tri[i] = 45 * (1.0 - FastMath.abs(((i + offset)*2/l % 2.0) - 1.0));
		
		// rescale the spectrum so its peak equals the data maximum, for display
		double maxAbs = abs[maxIdx];
		double range[] = OneLiners.getRange(data);
		for(int i=0; i < n/2; i++)
			abs[i] *= range[1] / maxAbs;
		
		int l0 = (int)l;
		double pl = l - l0;
		
		finalFit = tri;
		
		seriesList.add(new Series("Absolute Spectrum", x, abs));
		
		return new double[]{
				(1.0 - pl) * x[l0] + pl * x[l0+1], //wavelength: x interpolated at l samples
				meanPhs, //phase
				0, //amp (filled later)
				0, //y0 (filled later)
			};
	}
	
	/**
	 * Schmitt-trigger style estimate of a triangle wave's wavelength and phase.
	 * <p>
	 * Assumes the data swings roughly between 0 and 45. The signal is considered to
	 * have gone up when it exceeds 45 - threshold, and down when it drops below
	 * threshold. The extreme sample between such transitions is recorded. The mean
	 * spacing of the extrema gives half a period, and their mean position modulo
	 * that spacing gives the offset. Adds "tri", "detect" and "diff" debug series
	 * to seriesList. Currently unused (see doFit()).
	 *
	 * @return {wavelength, phase, 0 (amplitude, set later), 0 (offset, set later)}
	 */
	private double[] schmidtProc() {
		int n = x.length;
		
		int threshold = 10;
		
		// initial state: "up" if the first sample is at or above mid-range (22.5)
		boolean up = data[0] >= 22.5;
		
		// each entry is {sample index, 1 for a maximum / 0 for a minimum}
		LinkedList<int[]> extrema = new LinkedList<int[]>();
		
		int minMaxIdx = 0;
		for(int i=0; i < n; i++){
			if(!up){ // is down: track the minimum
				if(minMaxIdx >= 0 && data[i] < data[minMaxIdx])
					minMaxIdx = i;
				
				if(data[i] > (45 - threshold)){ // has it gone up?
					if(minMaxIdx >= 0)
						extrema.add(new int[]{ minMaxIdx, 0 });
					up = true;
					minMaxIdx = i;
				}
			}else{ // is up: track the maximum
				if(minMaxIdx >= 0 && data[i] > data[minMaxIdx])
					minMaxIdx = i;
				
				if(data[i] < threshold){ // has it gone down?
					if(minMaxIdx >= 0)
						extrema.add(new int[]{ minMaxIdx, 1 });
					up = false;
					minMaxIdx = i;
				}
			}
		}
		
		// debug trace: -5 at minima, +5 at maxima
		double detect[] = new double[n];
		
		// mean spacing between consecutive extrema = half the triangle period, in samples
		// NOTE(review): with fewer than 2 extrema this yields NaN or a meaningless value
		double period = 0;
		int lastP[] = null;
		for(int p[] : extrema){
			detect[p[0]] = 10 * p[1] - 5;
			if(lastP != null)
				period += (p[0] - lastP[0]);
			lastP = p;
		}
		period /= extrema.size() - 1;
		
		// mean position of the extrema (beyond the first half-period) modulo the half-period
		double offset = 0;
		int nInOffset = 0;
		for(int p[] : extrema){
			if(p[0] > period){
				offset += p[0] % period;
				nInOffset++;
			}
		}
		offset /= nInOffset;
		
		double tri[] = new double[n];
		double diff[] = new double[n];
		for(int i=0; i < n; i++){
			if(i > offset)
				tri[i] = 45 - 45 * Math.abs(((i - offset) / period) % 2.0 - 1);
			else
				tri[i] = 0;
			
			diff[i] = data[i] - tri[i];
			if(i < period)
				detect[i] += 15;
		}
		detect[(int)offset] += 5;
		
		seriesList.add(new Series("tri", x, tri));
		seriesList.add(new Series("detect", x, detect));
		seriesList.add(new Series("diff", x, diff));
		
		double dXdI = x[1] - x[0]; // assumes uniformly spaced x - TODO confirm
		return new double[]{
				2*period*dXdI, //wavelength
				(2*Math.PI*offset / (2*period)) + Math.PI, // phase
				0, //amp (set later)
				0, //y0  (set later)
			};
	}
	
	/**
	 * Runs a HookeAndJeeves search over the enabled parameters. Stores the fitted
	 * curve in finalFit and the parameters in finalP. Inserts "init", "final" and
	 * "diff" series at the front of seriesList.
	 *
	 * @param initP start values for all parameters (enabled and disabled)
	 * @param minP lower parameter bounds, one per parameter
	 * @param maxP upper parameter bounds, one per parameter
	 */
	private void fitProc(double initP[], double minP[], double maxP[]) {
		if(x.length <= 0 || data.length <= 0)
			return;
		
		CostF costF = new CostF(funcType, x, data, paramEnable, initP, minP, maxP);
		
		// A LineSearchOptimizer (conjugate gradient + golden section) was previously
		// configured here; the HookeAndJeeves search is used instead.
		HookeAndJeeves opt = new HookeAndJeeves(costF);
		opt.setObjectiveFunction(costF);
		opt.init(costF.toSelectedParams(initP)); // optimiser works on the enabled subset only
		
		for(int i=0; i < N_FIT_ITERS; i++)
			opt.refine();
		
		double pAll[] = costF.toAllParams(opt.getCurrentPos());
		this.finalFit = costF.func(pAll);
		
		System.out.println("Cost: init = " + costF.evalAllParams(initP) + "\tFinal = " + costF.evalAllParams(pAll));
		
		//add these at start, shifting others down
		seriesList.add(0, new Series("init", x, costF.func(initP)));
		seriesList.add(1, new Series("final", x, finalFit));
		seriesList.add(2, new Series("diff", x, costF.diff(pAll)));
		
		this.finalP = pAll;
	}
	
	/** @return the display series of the last fit, or null when the function type is FUNC_NONE */
	public List<Series> getSeriesList(){
		return (funcType == FUNC_NONE) ? null : seriesList;
	}
	
	/** @return the parameter names of the selected function type (empty for FUNC_NONE / unknown types) */
	public String[] getParamNames(){
		switch(funcType){
			case FUNC_TRIANGLE:
			case FUNC_SAWTOOTH:
			case FUNC_SINE:
			case FUNC_ABSSINE:
				return new String[]{ "Wavelength", "Phase", "Amplitude", "Offset" };
			case FUNC_QUADRATIC:
				return new String[]{ "X0", "Y0", "O(x)", "O(x²)" };
			case FUNC_SINESHIFT:
				return new String[]{ "Wavelength", "Phase", "Amplitude", "Offset(0)", "Amp Gradient", "Offset Gradient" };
			default:
				return new String[]{ };
		}
	}
	
	/**
	 * Rebuilds infoTxt from the function type and the last fit.
	 * For the periodic types this lists, for each period index i (starting at -1),
	 * x = (i - phase/2π + q) * wavelength for q = 0, ¼, ½ and ¾ (labelled 0°, 90°,
	 * 180° and 270°). At most 11 lines are produced.
	 */
	private void setInfoText() {
		if(funcType == FUNC_NONE) {
			infoTxt = "Type: " + funcTypeNames[funcType];
			return;
		}
		
		if(finalFit == null || finalP == null) {
			infoTxt = "No fit yet";
			return;
		}
		
		switch(funcType){
			case FUNC_TRIANGLE:
			case FUNC_SAWTOOTH:
			case FUNC_SINE:
			case FUNC_SINESHIFT:
			case FUNC_ABSSINE:
				StringBuilder strB = new StringBuilder(2048);
				
				if(x == null || x.length == 0){
					infoTxt = strB.toString();
					return;
				}
				
				double maxX = Algorithms.max(x);
				int i=-1;
				do{
					strB.append(i + 
							": 0°=" + (i + -finalP[1]/2/Math.PI + 0.00)*finalP[0] +
							", 90°=" + (i + -finalP[1]/2/Math.PI + 0.25)*finalP[0] +
							", 180°=" + (i + -finalP[1]/2/Math.PI + 0.50)*finalP[0] +
							", 270°=" + (i + -finalP[1]/2/Math.PI + 0.75)*finalP[0] + "\n");
					i++;
				}while(i < 10 && ((i-1) + 0.25) * finalP[0] < maxX);
				
				infoTxt = strB.toString();
				break;
				
			case FUNC_QUADRATIC:
				infoTxt = "";
				break;
		}
	}
	
	/** Selects the function type (FUNC_* constant); changing it invalidates the current fit. */
	public void setFuncType(int funcType){ 
		if(this.funcType != funcType)
			finalFit = null;
		this.funcType = funcType;
	}
	
	/** Discards the current fit so that the next doFit() recomputes it. */
	public void invalidate(){ this.finalFit = null;	}
	public void setAutoFit(boolean autoFit){ this.autoFit = autoFit;	}	
	public void setEstimateInitParams(boolean estimateInitP){ this.estimateInitP = estimateInitP;	}
	
	/** Sets the initial parameters; the array is stored (not copied) and may be replaced by doFit(). */
	public void setInitParams(double initP[]){ this.initP = initP;	}
	public double[] getInitParams(){ return initP; }
	public double[] getFinalParams(){ return finalP; }
	
	/** Sets the per-parameter enable flags; the array is stored (not copied). */
	public void setParamEnable(boolean paramEnable[]){ this.paramEnable = paramEnable;	}
	
	/** @return true if a function type is selected but no valid fit exists */
	public boolean needsUpdate(){ return funcType != FUNC_NONE && finalFit == null; }
	
	public String getInfoTxt() { return infoTxt != null ? infoTxt : "Not inited"; }
}
