package imseProc.proc.transform;

import jafama.FastMath;

import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;

import mds.AugMDSFetcher;
import oneLiners.OneLiners;

import org.eclipse.swt.widgets.Composite;

import algorithmrepository.Algorithms;
import seed.minerva.aug.mse.AugMSESystem;
import seed.minerva.optics.Util;
import signals.aug.AUGSignal;
import descriptors.aug.AUGSignalDesc;
import imseProc.core.ByteBufferImage;
import imseProc.core.IMSEProc;
import imseProc.core.ImagePipeController;
import imseProc.core.Img;
import imseProc.core.ImgProcPipe;
import imseProc.core.ImgSource;
import imseProc.proc.transform.FeatureTransform.Point;

/** Image transform processor.
 * Translate an image from some arbitrary image coordinates to
 * R,Z of nearest intersection of a pixel's LOS to a given beam axis. 
 * 
 * The SWT controller should be used to select points on the image which can
 * be given 3D points of the vessel. Along with a single 3D observation point
 * on the mirror, and the beam definitions, this should be used to map those
 * points to that line's intersection on the beam plane. 
 * 
 * 
 * @author oliford
 *
 */
public class TransformProcessor extends ImgProcPipe {
		
	/** SWT controller most recently created by createPipeController() */
	private TransformSWTController controller;
	
	/** Image X,Y <--> lon/lat transform. Created lazily by calcCoordTransform() or loadMapPoints(). */
	private FeatureTransformCubic xform;
	/** Fitter created by calcCoordTransform() when the cubic stage is enabled */
	private TransformFitter transformFitter;
			
	/** Beam selection: indices 0-3 for Q1-Q4, -1 for lat/lon, -2 for auto */
	public static final int BEAMSEL_AUTO = -2;
	public static final int BEAMSEL_LATLON = -1;
	public int beamSelection = BEAMSEL_AUTO;
	
	/** Image A,B coordinates (corners are not necessarily the same as the transform corners).
	 * A,B are R,Z when mapping onto a beam, or longitude,latitude when beamSelection is BEAMSEL_LATLON.
	 * dA,dB are the output pixel sizes set in calcABAxes(). */
	private double A0, B0, A1, B1, dA, dB;
	
	/** Knot counts passed to xform.initCubic(); the cubic stage is skipped when cubicNX < 1 */
	private int cubicNX = 2, cubicNY = 2;
	
	/** Image X,Y of every point not marked EN_IGNORE, back-converted from its lon/lat
	 * through the transform. Rebuilt on every calcCoordTransform(). */
	private HashMap<FeatureTransform.Point, double[]> backConvertedXY;
	
	/** The average view position.
	 * This is the bit that the transform can't know so has to be entered
	 * as a point in the points table.
	 * 
	 * It can be calculated from the fitted ray-tracer. (see minerva-optics-imse/seed.minerva.apps.imse/TransformMatch) */
	private double viewPos[];
	
	/** XY coords on input image of a given pixel on the output image [beamIdx][i_][i_].
	 * Indexed as [beamIdx][oY][oX]; doCalc() scales the values by the input width/height. */
	private double outputImageX[][][]; 	
	private double outputImageY[][][]; 	
	
	/** Creates a processor with no source attached; output images are of type ByteBufferImage. */
	public TransformProcessor() {
		super(ByteBufferImage.class);
	}
	
	/**
	 * Creates a processor attached to the given source; output images are of type ByteBufferImage.
	 * 
	 * @param source 		image source handed to the pipe base class
	 * @param selectedIndex selected source index handed to the pipe base class
	 */
	public TransformProcessor(ImgSource source, int selectedIndex) {
		super(ByteBufferImage.class, source, selectedIndex);
	}
	
	/** Output image outIdx is built only from input image outIdx (always 1:1). */
	@Override
	protected int[] sourceIndices(int outIdx) {
		final int oneToOne[] = { outIdx };
		return oneToOne;
	}
	
	/** Runs the base class pre-calculation and then, in automatic beam mode only,
	 * works out which beam to use for each frame. */
	@Override
	protected void preCalc(boolean settingsHadChanged) {
		super.preCalc(settingsHadChanged);
		
		final boolean perFrameBeam = (beamSelection == BEAMSEL_AUTO);
		if(perFrameBeam){
			calcBeamIndices();
		}
	}
	
	
	/** Image allocation.
	 * 
	 * The transform is fitted here, because the output image size is only known once
	 * the output axes have been derived from it. If any step of that fails with a
	 * RuntimeException, the error is printed and zero output images are requested.
	 * 
	 * @param nImagesIn number of input images; the same number of outputs is requested on success
	 * @return true if the output image array was replaced by a different one, false if it was re-used
	 */
	protected boolean checkOutputSet(int nImagesIn){
		int outCount = nImagesIn;
		
		try{
			calcCoordTransform(inWidth, inHeight);
			if(!xform.isValid()){
				throw new RuntimeException("Transform invalid");
			}
			viewPos = xform.getViewPosition();
			calcABAxes();			//sets outWidth, outHeight
			calcOutputPixelXY();	//needs the axes and viewPos
		}catch(RuntimeException fitError){
			System.err.println("Error during transform init/fit: ");
			fitError.printStackTrace();
			outCount = 0;
		}
		
		//allocate or re-use
		ByteBufferImage allocated[] = ByteBufferImage.checkBulkAllocation(
				this, (ByteBufferImage[])imagesOut, 
				outWidth, outHeight, ByteBufferImage.DEPTH_DOUBLE, 
				outCount, ByteOrder.LITTLE_ENDIAN);
		
		final boolean replaced = (allocated != imagesOut);
		if(replaced){
			imagesOut = allocated;
		}
		return replaced;
	}
	
	/** Resamples one input image onto the output grid by bilinear interpolation.
	 * 
	 * For every output pixel the precomputed lookup tables give the position on the
	 * input image (scaled here by the input width/height). Positions outside
	 * [0, inWidth-2) x [0, inHeight-2) produce NaN in the output.
	 * 
	 * @param imageOutG 	output image, must be a ByteBufferImage
	 * @param sourceSet 	source images; only element 0 is used
	 * @param settingsHadChanged not used here
	 * @return always true
	 */
	@Override
	protected boolean doCalc(Img imageOutG, Img[] sourceSet, boolean settingsHadChanged) throws InterruptedException {
		Img src = sourceSet[0];
		ByteBufferImage dst = (ByteBufferImage)imageOutG;
		
		//which plane of the lookup tables applies to this frame
		int plane;
		switch(beamSelection){
			case BEAMSEL_LATLON:
				plane = 0;
				break;
			case BEAMSEL_AUTO:
				plane = (Integer)getImageMetaData("beamIndex", src.getSourceIndex());
				break;
			default:
				plane = beamSelection;
				break;
		}
		
		src.startReading();
		try{
			for(int row=0; row < outHeight; row++) {
				for(int col=0; col < outWidth; col++) {
					
					double srcX = outputImageX[plane][row][col] * inWidth;
					double srcY = outputImageY[plane][row][col] * inHeight;
					
					//written as 'outside' tests so that NaN coordinates fall through, as before
					boolean outside = srcX < 0 || srcX >= (inWidth-2) ||
										srcY < 0 || srcY >= (inHeight-2);
					if(outside){
						dst.setPixelValue(col, row, Double.NaN);
						continue;
					}
					
					//integer pixel and fractional offset within it
					int pixX = (int)srcX;
					int pixY = (int)srcY;
					double fracX = srcX - pixX;
					double fracY = srcY - pixY;
					
					if(fracX < 0 || fracY < 0 || fracX >= 1 || fracY >= 1){
						System.err.println("wtf??");
					}
					
					double v00 = src.getPixelValue(pixX,   pixY);
					double v10 = src.getPixelValue(pixX+1, pixY);
					double v01 = src.getPixelValue(pixX,   pixY+1);
					double v11 = src.getPixelValue(pixX+1, pixY+1);
					
					double interpolated = (1 - fracX) * (1 - fracY) * v00 + 
											fracX * (1 - fracY) * v10 + 
											(1 - fracX) * fracY * v01 + 
											fracX * fracY * v11;
					
					dst.setPixelValue(col, row, interpolated);
				}
			}
		}finally{
			src.endReading();
		}
		
		return true;
	}
	
	/** (Re)builds the image X,Y <--> lon/lat transform from the current points.
	 * 
	 * Computes the linear transform, optionally fits the cubic stage (when cubicNX >= 1),
	 * stores the resulting matrices/knots in the connected source's series metadata and
	 * fills backConvertedXY with the back-converted image position of every point that is
	 * not marked EN_IGNORE. A round-trip comparison is printed to stdout for each of those.
	 * 
	 * @param imgWidth 	not used
	 * @param imgHeight not used
	 */
	private void calcCoordTransform(int imgWidth, int imgHeight) {
		
		if(xform == null){
			xform = new FeatureTransformCubic();
		}
		
		xform.calcPointsLatLon();
		xform.calcLinear();
		
		//keep the transform with the source metadata, so anything saved from it also records the transform used
		connectedSource.setSeriesMetaData("Transform/linearXYtoLL", xform.getXYtoLLMat());
		connectedSource.setSeriesMetaData("Transform/linearLLtoXY", xform.getLLtoXYMat());
		
		final boolean cubicEnabled = (cubicNX >= 1);
		if(cubicEnabled){
			xform.initCubic(cubicNX, cubicNY);
			
			transformFitter = new TransformFitter(xform);
			transformFitter.fit();
			
			connectedSource.setSeriesMetaData("Transform/cubicKnotsX", xform.getKnotValsX());
			connectedSource.setSeriesMetaData("Transform/cubicKnotsY", xform.getKnotValsY());
		}
		
		backConvertedXY = new HashMap<FeatureTransform.Point, double[]>();
		for(FeatureTransform.Point pt : xform.getPoints()){
			if(pt.enable == FeatureTransform.EN_IGNORE){
				continue;
			}
			
			double xy[] = xform.latlonToXY(pt.lon, pt.lat);			//full transform, lon/lat --> image
			double roundTrip[] = xform.xyToLatLon(xy[0], xy[1]);	//and back again
			double linearXY[] = xform.latlonToXYLinear(pt.lon, pt.lat);
			
			System.out.println(pt.name + ": Lon: " + pt.lon + " --> x="+xy[0]+" --> lon="+roundTrip[0] + ", linX = " + linearXY[0]);
			System.out.println("Lat: " + pt.lat + " --> y="+xy[1]+" --> lat="+roundTrip[1] + ", linY = " + linearXY[1]);
			
			backConvertedXY.put(pt, xy);
		}
		
		// we may have changed the R,Z data
		//saveMapPoints();
		
	}
	
	/** Builds the lookup tables outputImageX/outputImageY which give, for every pixel of the
	 * output image, the corresponding position on the input image (as returned by xform.latlonToXY()).
	 * 
	 * Uses A0, B1, dA, dB, outWidth and outHeight, so calcABAxes() must have been run first, and 
	 * (for the beam modes) viewPos. Output row 0 corresponds to B1, with B decreasing down the image.
	 * 
	 * In BEAMSEL_LATLON mode the output axes are lon/lat directly and a single table [0] is built.
	 * Otherwise the output axes are R,Z and one table per selected beam is built (all four in
	 * BEAMSEL_AUTO mode); entries for beams that are not selected are left null.
	 * In the beam modes, the toroidal angle found for each pixel and the view position are
	 * also written to this processor's series metadata.
	 */
	private void calcOutputPixelXY(){
		
		if(beamSelection == BEAMSEL_LATLON){
			outputImageX = new double[1][outHeight][outWidth];
			outputImageY = new double[1][outHeight][outWidth];
			
			for(int oY=0; oY < outHeight; oY++){
				for(int oX=0; oX < outWidth; oX++){
					//lon, lat of this output pixel
					double A = A0 + oX * dA;
					double B = B1 - oY * dB;
					
			        double fXY[] = xform.latlonToXY(A, B);					
					outputImageX[0][oY][oX] = fXY[0];
					outputImageY[0][oY][oX] = fXY[1];
				}
			}
			return;
		}
		
		outputImageX = new double[4][][];
		outputImageY = new double[4][][];
		double outputPhiClosest[][][] = new double[4][][];
		
		//direction of the middle of the transform's lon/lat range, used only to get a starting phi
		double avgLon = (xform.getMinLon() + xform.getMaxLon()) / 2; 
		double avgLat = (xform.getMinLat() + xform.getMaxLat()) / 2;
	
		double avgLosVec[] = new double[]{
				FastMath.cos(avgLat) * FastMath.cos(avgLon),
				FastMath.cos(avgLat) * FastMath.sin(avgLon),
				FastMath.sin(avgLat),
		};
		
		for(int iB=0; iB < 4; iB++){
			//only calculate for necessary beams
			if(beamSelection == BEAMSEL_AUTO || beamSelection == iB){
				outputImageX[iB] = new double[outHeight][outWidth];
				outputImageY[iB] = new double[outHeight][outWidth];
				outputPhiClosest[iB] = new double[outHeight][outWidth];
				
				//start point and unit vector of this beam's axis
				double nS[] = AugMSESystem.nbiStartAll[iB];
				double nU[] = AugMSESystem.nbiUnitAll[iB];		
				
				//s is used below as the distance along the first line given (here the beam axis),
				//so p is the point on the beam axis nearest the average line of sight
				double s = Algorithms.pointOnLineNearestAnotherLine(nS, nU, viewPos, avgLosVec);
				
				double p[] = new double[]{
						nS[0] + s * nU[0],
						nS[1] + s * nU[1],
						nS[2] + s * nU[2],
				};
				
				//initial guess of the toroidal angle. NB: phi is not reset inside the pixel loops,
				//so each pixel's iteration starts from the phi found for the previous pixel.
				double phi = FastMath.atan2(p[1], p[0]);
		     
				for(int oY=0; oY < outHeight; oY++){
					for(int oX=0; oX < outWidth; oX++){
						//output image coordinates 
						double R = A0 + oX * dA;
						double Z = B1 - oY * dB;
						
						//iterate until we find the phi on this (R,Z) ring that is closest to the beam
						//If 100 iterations pass without reaching d < 1e-4, the last estimate is used without any warning.
				        for(int k=0; k < 100; k++){
				        	//3D position on the (R,Z) ring at the current phi
				        	double curPos[] = new double[]{
				        			R * FastMath.cos(phi),
				        			R * FastMath.sin(phi),
				        			Z
				        	};
				        	
				        	//line of sight from the view position through that point
				        	double u[] = Util.reNorm(Util.minus(curPos, viewPos));
				        	
				        	//point on that line of sight nearest to the beam axis
				        	s = Algorithms.pointOnLineNearestAnotherLine(viewPos, u, nS, nU);
				        	
				        	p = Util.plus(viewPos, Util.mul(u, s));
				        	
				        	//take the phi of that point as the next estimate
				        	double r = FastMath.sqrt(p[0]*p[0] + p[1]*p[1]);
				        	phi = FastMath.atan2(p[1], p[0]);
				        	
				        	//distance in the R,Z plane between the target (R,Z) and the point found
				        	double d = FastMath.sqrt(FastMath.pow2(R - r) + FastMath.pow2(Z - p[2]));
				        			
				        	//System.out.println(k+ "\t" + phi*180/Math.PI +"\t " + d);
				        	
				        	if(d < 1e-4)
				        		break;
				        }
				        
				        //direction from the view position to the point found, as lon (A) / lat (B) 
				        p = Util.reNorm(Util.minus(p, viewPos));
				        double A = FastMath.atan2(p[1], p[0]);
				        double B = FastMath.asin(p[2]);
				        
				        outputPhiClosest[iB][oY][oX] = phi;
							
				        double fXY[] = xform.latlonToXY(A, B);					
						outputImageX[iB][oY][oX] = fXY[0];
						outputImageY[iB][oY][oX] = fXY[1];
						
					}
				}
			}
		}
		
		//the phi and the view pos is then enough information for anyone else to reconstruct the full 3D geometry
		//NOTE(review): these go to this processor's metadata, whereas the other Transform/... entries
		//are written to connectedSource - confirm that is intended
		setSeriesMetaData("Transform/outputPhiClosest", outputPhiClosest);
		setSeriesMetaData("Transform/viewPosition", viewPos);
		
	}
	
	/** Converts a viewing direction to R,Z on a beam.
	 * 
	 * The direction is turned into a unit vector from viewPos, the point on that line nearest 
	 * to the axis of the given beam is found and its cylindrical radius and height are returned.
	 * 
	 * @param AB 		{ longitude, latitude } of the viewing direction
	 * @param beamIdx 	index into AugMSESystem.nbiStartAll / nbiUnitAll
	 * @return { R, Z } of that point
	 */
	private double[] ABtoRZ(double AB[], int beamIdx){
		final double cosLat = FastMath.cos(AB[1]);
		
		//unit vector of the line of sight
		final double dir[] = {
				cosLat * FastMath.cos(AB[0]),
				cosLat * FastMath.sin(AB[0]),
				FastMath.sin(AB[1]),
		};
		
		//distance along the line of sight to the point nearest the beam axis
		final double dist = Algorithms.pointOnLineNearestAnotherLine(
								viewPos, dir, 
								AugMSESystem.nbiStartAll[beamIdx],
								AugMSESystem.nbiUnitAll[beamIdx]);
		
		final double x = viewPos[0] + dist * dir[0];
		final double y = viewPos[1] + dist * dir[1];
		final double z = viewPos[2] + dist * dir[2];
		
		return new double[]{ FastMath.sqrt(x*x + y*y), z };
	}
	
	
	/* don't use this, it's awful
	 * private double[] RZtoAB(double RZ[]){
		double B[] = AugMSESystem.nbiStartAll[beamSelection];
		double v[] = AugMSESystem.nbiUnitAll[beamSelection];
		
		double t[] = Algorithms.cylinderLineIntersection(
				B, v, 
				new double[]{0, 0, 0}, 
				new double[]{0,0,1}, 
				RZ[0]*RZ[0]);
		
		double t0 = Math.min(t[0], t[1]);
		
		double p[] = new double[]{
				B[0] + t0 * v[0],
				B[1] + t0 * v[1],
				B[2] + t0 * v[2],
		};
		
		double phi = FastMath.atan2(p[1], p[0]);

        double u[] = null;
        
        for(int k=0; k < 100; k++){
        	double curPos[] = new double[]{
        			RZ[0] * FastMath.cos(phi),
        			RZ[0] * FastMath.sin(phi),
        			RZ[1]
        	};
        	
        	u = Util.reNorm(Util.minus(curPos, cameraPos));
        	
        	double s = Algorithms.pointOnLineNearestAnotherLine(cameraPos, u, B, v);
        	
        	p = Util.plus(cameraPos, Util.mul(u, s));
        	
        	double r = FastMath.sqrt(p[0]*p[0] + p[1]*p[1]);
        	phi = FastMath.atan2(p[1], p[0]);
        	
        	double d = FastMath.sqrt(FastMath.pow2(RZ[0] - r) + FastMath.pow2(RZ[1] - p[2]));
        			
        	//System.out.println(k+ "\t" + phi*180/Math.PI +"\t" + d);
        	
        	if(d < 1e-4)
        		break;
        }
        
        p = Util.reNorm(Util.minus(p, cameraPos));
        
        return new double[]{ 
        	FastMath.atan2(p[1], p[0]),
        	FastMath.asin(p[2])
        };
	}*/
	
	/** Works out the output image axes and size, and publishes the axes as series metadata.
	 * 
	 * The pixel size is fixed at 0.001 in both directions. The axis range comes from the 
	 * transform's lon/lat limits: directly in BEAMSEL_LATLON mode, otherwise as the bounding box
	 * in R,Z of its four corners projected onto each selected beam. The size is rounded down to
	 * a multiple of 4 pixels and A1,B1 are then reset to the coordinates of the last column / first row.
	 * 
	 * Sets dA, dB, A0, A1, B0, B1, outWidth and outHeight.
	 */
	private void calcABAxes(){
		//def R,Z of a pixel as R,Z in top left corner of that pixel
		final double pixelSize = 0.001; //desired size of pixel
		dA = pixelSize; //force isometric
		dB = pixelSize;
		
		//use transform corners as a starting point
		final double lonMin = xform.getMinLon();
		final double lonMax = xform.getMaxLon();
		final double latMin = xform.getMinLat();
		final double latMax = xform.getMaxLat();
		
		final boolean latLonMode = (beamSelection == BEAMSEL_LATLON);
		
		if(latLonMode){
			A0 = lonMin; 
			A1 = lonMax;
			B0 = latMin; 
			B1 = latMax;
		}else{
			//one, or possibly all beams. If automatic, do min/max of all
			A0 = Double.POSITIVE_INFINITY; 
			A1 = Double.NEGATIVE_INFINITY;
			B0 = Double.POSITIVE_INFINITY; 
			B1 = Double.NEGATIVE_INFINITY;
			
			final double corners[][] = {
					{ lonMin, latMin },
					{ lonMin, latMax },
					{ lonMax, latMin },
					{ lonMax, latMax },
			};
			
			for(int iB=0; iB < 4; iB++){
				if(beamSelection != BEAMSEL_AUTO && beamSelection != iB){
					continue;
				}
				
				//project all four corners onto this beam first...
				double cornerRZ[][] = new double[corners.length][];
				for(int c=0; c < corners.length; c++){
					cornerRZ[c] = ABtoRZ(corners[c], iB);
				}
				
				//... then grow the bounding box to include them
				for(double rz[] : cornerRZ){
					A0 = Math.min(A0, rz[0]);
					A1 = Math.max(A1, rz[0]);
					B0 = Math.min(B0, rz[1]);
					B1 = Math.max(B1, rz[1]);
				}
			}
		}
		
		//image size ought to be 32-bit aligned, for some reason
		outWidth = 4*(int)(((A1 - A0) / dA + 1) / 4);
		outHeight = 4*(int)(((B1 - B0) / dB + 1) / 4);
		A1 = A0 + (outWidth-1) * dA;
		B1 = B0 + (outHeight-1) * dB;
		
		//make axes and write that to the metadata
		double axisA[] = new double[outWidth];
		for(int col=0; col < axisA.length; col++){
			axisA[col] = A0 + col*dA;
		}
		double axisB[] = new double[outHeight];
		for(int row=0; row < axisB.length; row++){
			axisB[row] = B1 - row*dB; //row 0 is the largest B
		}
		
		if(latLonMode){
			//lon/lat axes are published in degrees
			for(int col=0; col < axisA.length; col++){ 
				axisA[col] *= 180 / Math.PI; 
			}
			for(int row=0; row < axisB.length; row++){ 
				axisB[row] *= 180 / Math.PI; 
			}
			connectedSource.setSeriesMetaData("/Transform/imageOutR", new double[]{ Double.NaN });
			connectedSource.setSeriesMetaData("/Transform/imageOutZ", new double[]{ Double.NaN });
			connectedSource.setSeriesMetaData("/Transform/imageOutLon", axisA);
			connectedSource.setSeriesMetaData("/Transform/imageOutLat", axisB);
		}else{
			connectedSource.setSeriesMetaData("/Transform/imageOutR", axisA);
			connectedSource.setSeriesMetaData("/Transform/imageOutZ", axisB);
			connectedSource.setSeriesMetaData("/Transform/imageOutLon", new double[]{ Double.NaN });
			connectedSource.setSeriesMetaData("/Transform/imageOutLat", new double[]{ Double.NaN });
		}
	}
	

	/** Passes the notification to the base class, then recalculates if auto-calc is on. */
	@Override
	public void notifySourceChanged() {
		super.notifySourceChanged();
		if(autoCalc){
			calc();
		}
	}

	/** Recalculates if auto-calc is on; the index of the changed image is not used. */
	@Override
	public void imageChanged(int idx) {
		if(autoCalc){
			calc();
		}
	}
	
	/** @return the number of output images, or 0 if none have been allocated yet */
	@Override
	public int getNumImages() {
		if(imagesOut == null){
			return 0;
		}
		return imagesOut.length;
	}

	/** Gets one of the output images.
	 * 
	 * @param imgIdx index of the output image
	 * @return the image, or null if the index is out of range or no output images
	 * 			have been allocated yet (consistent with getNumImages() returning 0)
	 */
	@Override
	public Img getImage(int imgIdx) {
		if(imagesOut == null || imgIdx < 0 || imgIdx >= imagesOut.length){
			return null;
		}
		return imagesOut[imgIdx];
	}
	
	/** @return the output image array itself (not a copy); may be null before the first allocation */
	@Override
	public Img[] getImageSet() {
		return imagesOut;
	}

	/** Creates a new processor on the same source and selected source index.
	 * Settings such as the beam selection, cubic knots and transform are not copied. */
	@Override
	public TransformProcessor clone() {
		return new TransformProcessor(
				connectedSource,
				getSelectedSourceIndex());
	}

	/** For the sink side.
	 * 
	 * Only SWT is supported: for Composite.class a TransformSWTController is created 
	 * from args[0] (the Composite) and args[1] (an Integer), remembered and registered.
	 * 
	 * @return the new controller, or null for any other interfacing class
	 */
	@Override
	public ImagePipeController createPipeController(Class interfacingClass, Object args[], boolean asSink) {
		if(interfacingClass != Composite.class){
			return null;
		}
		
		Composite parent = (Composite)args[0];
		Integer style = (Integer)args[1];
		
		controller = new TransformSWTController(parent, style, this, asSink);
		controllers.add(controller);
		return controller;
	}
	
	/** @return A coordinate of the first output column, as set by the last calcABAxes() */
	public double getA0(){
		return this.A0;
	}
	/** @return A coordinate of the last output column, as set by the last calcABAxes() */
	public double getA1(){
		return this.A1;
	}
	/** @return lower limit of the output B axis, as set by the last calcABAxes() */
	public double getB0(){
		return this.B0;
	}
	/** @return B coordinate of output row 0 (the largest B), as set by the last calcABAxes() */
	public double getB1(){
		return this.B1;
	}
		
	/** Switches automatic recalculation on or off and refreshes all controllers.
	 * A calculation is started only when the setting goes from off to on.
	 * 
	 * @param autoCalc the new setting
	 */
	public void setAutoCalc(boolean autoCalc) {
		final boolean switchedOn = autoCalc && !this.autoCalc;
		
		this.autoCalc = autoCalc;
		updateAllControllers();
		
		if(switchedOn){
			calc();
		}
	}
	/** @return whether automatic recalculation is currently on */
	public boolean getAutoCalc() {
		return autoCalc;
	}

	/** @return the current value of calcProgress */
	public int getCalcProgress() {
		return this.calcProgress;
	}
	
	/** Attaches a new source via the base class, then recalculates if auto-calc is on. */
	@Override
	public void setSource(ImgSource source) {
		super.setSource(source);
		if(autoCalc){
			calc();
		}
	}
	
	/** Nothing of our own to release; just delegates to the base class. */
	@Override
	public void destroy() {
		super.destroy();
	}
	
	/** @return the current transform; null until calcCoordTransform() or loadMapPoints() has created it */
	public FeatureTransform getTransform(){
		return xform;
	}
	/** To be called when the transform's points have been edited: flags the settings
	 * as changed, refreshes all controllers and recalculates if auto-calc is on. */
	public void pointsMapModified(){
		settingsChanged = true;
		updateAllControllers();
		if(autoCalc){
			calc();
		}
	}

	/** Saves the transform to the experiment and pulse given by the connected source's metadata.
	 * Does nothing if there is no connected source. A RuntimeException thrown while
	 * saving is reported on stderr and not propagated. */
	public void saveTransform(){
		if(connectedSource != null){
			final String exp = IMSEProc.getMetaExp(connectedSource);
			final int pulseNum = IMSEProc.getMetaPulse(connectedSource);
			
			try{
				xform.saveToSignal(IMSEProc.globalGMDS(), exp, pulseNum);
				
			}catch(RuntimeException saveError){
				System.err.println("Transform: Error saving points to exp " + exp + 
									", pulse " + pulseNum + " because " + saveError.getMessage());
			}
		}
	}

	/** Loads the transform's points from the database.
	 * 
	 * The experiment comes from the connected source's metadata, defaulting to "AUG" if that 
	 * is missing or empty. The points for the requested pulse are tried first; if that fails,
	 * the points stored under pulse 0 are tried instead. If both fail, the error is reported 
	 * on stderr and nothing else happens. On success all controllers are refreshed and the
	 * processor recalculates if auto-calc is on.
	 * 
	 * @param overridePulse pulse to load from if >= 0, otherwise the pulse from the connected 
	 * 			source's metadata is used (or 0 if that cannot be read)
	 */
	public void loadMapPoints(int overridePulse) {
		String experiment = IMSEProc.getMetaExp(connectedSource);
		if(experiment == null || experiment.length() <= 0){
			experiment = "AUG";
		}
		
		int pulse;
		try{
			pulse = overridePulse >= 0 ? overridePulse : IMSEProc.getMetaPulse(connectedSource);
		}catch(RuntimeException e){
			pulse = 0;
		}
		
		if(xform == null)
			xform = new FeatureTransformCubic();
		
		//try reading from this specific pulse first
		try{
			xform.loadPoints(IMSEProc.globalGMDS(), experiment, pulse);
			
		}catch(RuntimeException err){
			if(pulse == 0){
				//that already was the fallback pulse, so there is nothing else to try
				System.err.println("Transform: Couldn't load points for exp " + experiment +
						", pulse 0 because: " + err.getMessage());
				return;
			}
			
			System.err.println("Transform: Couldn't load points for exp " + experiment +
					", pulse " + pulse + " because: " + err.getMessage() + ". Trying pulse 0 instead");
			
			try{
				//fall back to pulse 0 (this previously retried the same pulse that had just failed)
				xform.loadPoints(IMSEProc.globalGMDS(), experiment, 0);
				
			}catch(RuntimeException e2) {
				System.err.println("Transform: Couldn't load points from pulse 0 because: " + e2.getMessage());
				return;
			}
		}
		
		//we don't particularly want to load the actual transform here
		
		updateAllControllers();
		if(autoCalc)
			calc();
		
	}
	
	/** @return the map of points to their back-converted image X,Y built by the last 
	 * 			calcCoordTransform() (the map itself, not a copy); null before the first one */
	public HashMap<Point, double[]> getBackConvertedXY(){
		return this.backConvertedXY;
	}

	/** Flags the settings as changed. Does not start a calculation itself. */
	public void invalidate() {
		this.settingsChanged = true;
	}

	/** Sets which beam the image is mapped onto and recalculates if auto-calc is on.
	 * 
	 * @param beamSelection a beam index, BEAMSEL_LATLON or BEAMSEL_AUTO
	 */
	public void setBeamSelection(int beamSelection) {
		this.beamSelection = beamSelection;
		settingsChanged = true;
		if(autoCalc){
			calc();
		}
	}

	/** Enables or disables the cubic stage of the transform and recalculates if auto-calc is on.
	 * 
	 * @param enable 	if false, both knot counts are set to -1, which disables the cubic stage
	 * @param nKnotX 	number of knots in X, used only if enabled
	 * @param nKnotsY 	number of knots in Y, used only if enabled
	 */
	public void setCubicInterp(boolean enable, int nKnotX, int nKnotsY) {
		cubicNX = enable ? nKnotX : -1;
		cubicNY = enable ? nKnotsY : -1;
		
		settingsChanged = true;
		if(autoCalc){
			calc();
		}
	}

	/** @return number of cubic knots in X; -1 after the cubic stage has been disabled */
	public int getCubicNX() {
		return cubicNX;
	}
	/** @return number of cubic knots in Y; -1 after the cubic stage has been disabled */
	public int getCubicNY() {
		return cubicNY;
	}

	/** Works out the most significant beam on during each frame.
	 * 
	 * The frame times ("/time") and the AUG pulse ("/aug/pulse") are read from the series 
	 * metadata and the beam power signal is fetched for that pulse. For each frame, the power 
	 * sample nearest in time is compared with a 100kW threshold for each beam, and the
	 * preferred beam that is on is chosen (Q3, then Q2, then Q4, then Q1). If none are on, 
	 * Q3 is used.
	 * 
	 * The result is stored as an int array under the series metadata entry "beamIndex". 
	 * If anything fails with a RuntimeException, a warning is printed and Q3 is used for all images.
	 */
	private void calcBeamIndices() {
		final String powerSignal = "/sig/NIS/PNIQ";
		final double beamOnPowerThreshold = 100e3; //100kW
		
		//beams in ascending order of preference: a later beam that is on overrides an earlier one 
		final int ascendingPreference[] = { 
				AugMSESystem.BEAM_Q1, 
				AugMSESystem.BEAM_Q4, 
				AugMSESystem.BEAM_Q2, 
				AugMSESystem.BEAM_Q3,
		};
		
		try{
			//check the metadata before fetching the signal, so that a missing entry is 
			//reported by name rather than as a NullPointerException with message 'null'
			Object timeMeta = getSeriesMetaData("/time");
			if(!(timeMeta instanceof double[])){
				throw new RuntimeException("No frame times ('/time') in metaData");
			}
			double frameTimes[] = (double[])timeMeta;
			
			Object augPulse = getSeriesMetaData("/aug/pulse");
			if(!(augPulse instanceof Integer)){
				throw new RuntimeException("No augPulse in metaData");
			}
			
			AUGSignal pwrSig = (AUGSignal)AugMDSFetcher.defaultInstance().getSig(new AUGSignalDesc("/aug/" + ((Integer)augPulse) + "/" + powerSignal));
			double pwrTime[] = (double[])pwrSig.getTVectorAsType(double.class);
			double pwrAll[][][] = (double[][][])pwrSig.getDataAsType(double.class);
			
			int beamIdx[] = new int[frameTimes.length];
			for(int iF=0; iF < frameTimes.length; iF++){
				int idx = OneLiners.getNearestIndex(pwrTime, frameTimes[iF]);
				
				//we prefer, in order, Q3, Q2, Q4, Q1... otherwise default to Q3 (i.e. if it's off)
				beamIdx[iF] = AugMSESystem.BEAM_Q3;
				for(int beam : ascendingPreference){
					if(pwrAll[0][beam][idx] > beamOnPowerThreshold){
						beamIdx[iF] = beam;
					}
				}
			}
			
			setSeriesMetaData("beamIndex", beamIdx);
			
		}catch(RuntimeException err){
			System.err.println("WARNING: Can't get power signals for NBI because '"+err.getMessage()+"', defaulting to Q3 for all.");
			//use the named constant rather than a literal 2, so the fallback is Q3 as the warning says
			int beamIdx[] = OneLiners.fillIntArray(AugMSESystem.BEAM_Q3, getNumImages());
			setSeriesMetaData("beamIndex", beamIdx);
		}
	}
}
