package seed.minerva.apps.imse;

import imseProc.proc.transform.FeatureTransform;
import imseProc.proc.transform.FeatureTransformCubic;
import jafama.FastMath;

import java.awt.color.CMMException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import mds.GMDSFetcher;
import descriptors.gmds.GMDSSignalDesc;

import oneLiners.OneLiners;
import otherSupport.ColorMaps;
import otherSupport.SettingsManager;
import otherSupport.StatusOutput;
import seed.digeom.FunctionND;
import seed.minerva.MinervaOpticsSettings;
import seed.minerva.aug.mse.AugMSESystem;
import seed.minerva.aug.mse.AugMseFaroData;
import seed.minerva.imse.IMSEOptics135_50;
import seed.minerva.imse.IMSEOptics135_50_rescaled;
import seed.minerva.optics.Util;
import seed.minerva.optics.collection.HitPositionAverage;
import seed.minerva.optics.collection.HitsCollector;
import seed.minerva.optics.collection.IntensityInfo;
import seed.minerva.optics.drawing.SVGCylindricalProjection;
import seed.minerva.optics.drawing.SVGRayDrawing;
import seed.minerva.optics.drawing.VRMLDrawer;
import seed.minerva.optics.interfaces.NullInterface;
import seed.minerva.optics.interfaces.Reflector;
import seed.minerva.optics.optics.Box;
import seed.minerva.optics.optimisation.OptimiseOptic;
import seed.minerva.optics.surfaces.Iris;
import seed.minerva.optics.surfaces.Plane;
import seed.minerva.optics.surfaces.Square;
import seed.minerva.optics.tracer.Tracer;
import seed.minerva.optics.types.Element;
import seed.minerva.optics.types.Intersection;
import seed.minerva.optics.types.Optic;
import seed.minerva.optics.types.Pol;
import seed.minerva.optics.types.RaySegment;
import seed.optimization.BracketingByParameterSpace;
import seed.optimization.ConjugateGradientDirectionFR;
import seed.optimization.GoldenSection;
import seed.optimization.IStoppingCondition;
import seed.optimization.LineSearchOptimizer;
import seed.optimization.MaxIterCondition;
import seed.optimization.Optimizer;
import seed.optimization.genetic.MSHGPPConfig;
import seed.optimization.genetic.MegaSuperHyperGeneticProblemPacifier;
import signals.gmds.GMDSSignal;
import algorithmrepository.Algorithms;
import binaryMatrixFile.BinaryMatrixWriter;

/** Compare/fit ray tracing of known positions against the feature transform fitted to
 * an experimental IMSE image.
 * 
 * Rays are traced through the AUG MSE/IMSE optical model from several sets of points:
 * <ul>
 *  <li>the transform's feature points,</li>
 *  <li>points along the NBI beam axes,</li>
 *  <li>a ring just inside an image-limiting aperture (fired backwards),</li>
 *  <li>a ring at the mirror-box hole.</li>
 * </ul>
 * For each point, the mean and spread of the image-plane hit positions are written as one
 * row of a binary matrix file under {@link #outPath}.
 * 
 * The average view lines of the beam points are intersected pairwise to estimate an
 * effective view position. That position is stored in the transform's point set and saved
 * back to GMDS.
 * 
 * See transformMatch.py
 * */
public class TransformMatch {

	private static String outPath;
	private final static int nRays = 30; //per point, for averaging
	/** Cap on the number of missed attempts for one point (counted over all of its rays) */
	private final static int maxAttempts = 10000;
	private final static int nRaysDraw = 1;
	//private final static double p0 = 0.27; // positions along beam for beam axis
	//private final static double p1 = 0.81; // 
	private final static double p0 = 0.2; // start position along beam for beam axis points
	private final static double p1 = 0.85; // end position along beam for beam axis points
	private final static int nBeams = 4; // number of NBI beams traced (Q1..Q4)
	private final static int nBeamPoints = 20;
	private final static int nImgLimitPoints = 20;
	private final static int nHoleRimPoints = 20;
	
	// IMSE or AUG pulse number to read transform from GMDS data
	//private final static int pulse = 178; // Jan2013 primary
	private final static int pulse = 351; // Apr2013 primary
	private final static String experiment = "AUG";
	
	/** What rays to trace */
	public static double minIntensity = 0.01;
	public static boolean traceReflections = false;
	
	private static final double svgRayWidth = 0.0001;
	private static final double svgOpticWidth = 0.0003;
	
	/** Optical system setup. Note: static initialisation order matters, sys must come first. */
	private static final AugMSESystem sys = new AugMSESystem(new IMSEOptics135_50_rescaled());
	private static final AugMseFaroData faro = new AugMseFaroData();
	
	private static final Plane imagePlane = sys.imseOptics.ccd;
	private static final Element initRaysTarget = sys.mainMirror; //sys.mirrorBox.holeGlassFront;
	
	/** Colour per traced point, indexed by the running point index */
	private static double cols[][];
	
	private final static double wavelen = 653e-9;
	
	private final static double globalUp[] = new double[]{ 0, 0, 1 };
	
	public static void main(String[] args) {
		run();
	}
	
	private static final int PSI_MEAN = 0;
	private static final int PSI_STDEV = 1;
	private static final int CHI_MEAN = 2;
	private static final int CHI_STDEV = 3;
	
	private static VRMLDrawer vrmlOut;
	private static SVGRayDrawing svgOutFlat;
	
	/** Sets up the optics and output drawers, traces all point sets, updates the
	 * effective view point and draws the scene. */
	private static void run(){
		
		outPath = MinervaOpticsSettings.getAppsOutputPath() + "/rayTracing/augImse/transformMatch/p_" + pulse;
		
		sys.setupForPulse(pulse);
		//sys.setupForPulse(-20130701);
		
		OneLiners.TextToFile(outPath + "/optics-hashCode.txt", Integer.toString(sys.hashCode()));
		System.out.println("optics hash: " + sys.hashCode());
		
		vrmlOut = new VRMLDrawer(outPath + "/pictures.vrml");
		vrmlOut.setSmallLineLength(0.0001);
		vrmlOut.setDrawPolarisationFrames(true);
		vrmlOut.addVRML(AugMSESystem.vrmlScaleToAUGDDD);
		//vrmlOut.setSkipRays(9);
	
		FeatureTransformCubic xform = new FeatureTransformCubic(defaultGMDS(), experiment, pulse);
		
		List<FeatureTransform.Point> xformPoints = xform.getPoints();
		int nXFPoints = xformPoints.size();
		int nPoints = nXFPoints + nBeamPoints*nBeams + nImgLimitPoints + nHoleRimPoints;
		cols = ColorMaps.jet(nPoints);
		
		//make mirror box front transparent when mirror is out, otherwise we sometimes get no output 
		//sys.mirrorBox.frontUpper.setInterface(NullInterface.ideal());
		//sys.mirrorBox.holeIris.setInterface(NullInterface.ideal());
		
		//3D SVG output along centre of camera view plane
		double rot[][] = new double[3][]; 
		double bwPos[] = new double[]{ -2.205, -0.505, 0.349 }; // point "diag B L #11", roughly in the middle of the view at the backwall 
		rot[0] = Util.reNorm(Util.minus(bwPos, faro.getCameraPos())); //x along view centre (ish)
		rot[1] = Util.reNorm(Util.cross(globalUp, rot[0])); // y to the left
		rot[2] = Util.reNorm(Util.cross(rot[0], rot[1])); // z up (ish)
		
		svgOutFlat = new SVGRayDrawing(outPath + "/raysFlat", new double[]{ -2, -4, -1, 1, 0, 1 }, true, rot);
		svgOutFlat.generateLineStyles(cols, svgRayWidth, svgOpticWidth);
		
		// The running point index iP selects each point's colour, so the order below must
		// match the layout of nPoints: transform, beams, image limit, hole rim.
		int iP = 0;
		iP = traceTransformPoints(iP, xformPoints);
		
		ArrayList<double[][]> viewLines = new ArrayList<double[][]>();
		iP = traceBeamPoints(iP, viewLines);
		calcViewConvergence(viewLines, xform);
		
		// The rim angles historically carry this index offset; it only rotates where each ring starts.
		int rimPhaseOffset = nXFPoints + nBeamPoints;
		
		// ******* Image plane aperture limit points ********
		Iris imgAperture = sys.tubeOptics.lens2Iris; //alternative: sys.tubeOptics.hwpIris
		// As modelled, but slightly smaller: the radius was redefined inside the optics,
		// so the ring must sit just inside it for the rays to actually miss the edge.
		// (alternatives tried: 0.080/2 smaller mirror box entrance, 0.105/2 slightly smaller L2)
		double imgLimitR = imgAperture.getApatureRadius() * 0.999;
		iP = traceApertureRim(iP, "/out-imageLimit.bin", "imgLimit_", imgAperture, imgLimitR,
				sys.tubeOptics.lens3Front, true, nImgLimitPoints, rimPhaseOffset);
		
		// ******* Hole rim points ********
		// smaller mirror box entrance than modelled (getApatureRadius()); 0.080/2 was too small to be in view
		traceApertureRim(iP, "/out-holeRim.bin", "holeRim_", sys.mirrorBox.holeIris, 0.130/2,
				sys.mainMirror, false, nHoleRimPoints, rimPhaseOffset);
		
		drawScene(xformPoints);
	}
	
	/** Looks up the GMDS fetcher for the configured server (setting "imseProc.cis.gmdsServerID", default "pccis"). */
	private static GMDSFetcher defaultGMDS() {
		String serverID = SettingsManager.defaultGlobal().getProperty("imseProc.cis.gmdsServerID", "pccis");
		return GMDSFetcher.defaultInstance(serverID);
	}
	
	/** Traces every transform feature point toward {@link #initRaysTarget}, writing rows to out-transform.bin.
	 * @param iP index of the first point
	 * @return the next unused point index */
	private static int traceTransformPoints(int iP, List<FeatureTransform.Point> xformPoints) {
		BinaryMatrixWriter binOut = new BinaryMatrixWriter(outPath + "/out-transform.bin", 8);
		try {
			for(FeatureTransform.Point p : xformPoints) {
				double startPos[] = new double[]{ p.x, p.y, p.z };
				double imgPos[] = new double[]{ p.imgX, p.imgY };
				traceFromPoint(iP++, p.name, startPos, imgPos, initRaysTarget, false, binOut);
			}
		} finally {
			binOut.close();
		}
		return iP;
	}
	
	/** Traces nBeamPoints points between p0 and p1 along each NBI beam axis, writing one file per beam.
	 * @param iP index of the first point
	 * @param viewLines receives one { startPos, mean view direction } pair per traced point
	 * @return the next unused point index */
	private static int traceBeamPoints(int iP, List<double[][]> viewLines) {
		for(int iB=0; iB < nBeams; iB++){
			BinaryMatrixWriter binOut = new BinaryMatrixWriter(outPath + "/out-beams-Q"+(iB+1)+".bin", 8);
			try {
				for(int j=0; j < nBeamPoints; j++){
					double p = p0 + (p1 - p0) * j / (nBeamPoints - 1.0); 
					double startPos[] = new double[]{
							sys.nbiStartAll[iB][0] + p * sys.nbiUnitAll[iB][0],
							sys.nbiStartAll[iB][1] + p * sys.nbiUnitAll[iB][1],
							sys.nbiStartAll[iB][2] + p * sys.nbiUnitAll[iB][2],
					};
					double imgPos[] = new double[]{ Double.NaN, Double.NaN };
					double R = FastMath.sqrt(startPos[0]*startPos[0] + startPos[1]*startPos[1]);
					String pointName = "beam-Q" + (iB+1) + "-p_" + p + "-R_" + R;	
					
					double uVec[] = traceFromPoint(iP++, pointName, startPos, imgPos, initRaysTarget, false, binOut);
					viewLines.add(new double[][]{ startPos, uVec });
				}
			} finally {
				binOut.close();
			}
		}
		return iP;
	}
	
	/** Traces a ring of points around an aperture, offset slightly along its normal.
	 * @param iP index of the first point
	 * @param fileName output file name, appended to outPath
	 * @param namePrefix point name prefix; the ring index is appended
	 * @param aperture iris whose centre, up, right and normal vectors define the ring
	 * @param r ring radius
	 * @param target element that rays are initially fired toward
	 * @param fireBackwards passed through to {@link #traceFromPoint}
	 * @param nRimPoints number of points around the ring
	 * @param phaseOffset subtracted from the ring index before conversion to an angle; only rotates the ring's start
	 * @return the next unused point index */
	private static int traceApertureRim(int iP, String fileName, String namePrefix, Iris aperture, double r,
			Element target, boolean fireBackwards, int nRimPoints, int phaseOffset) {
		
		double a[] = aperture.getUp();
		double b[] = aperture.getRight();
		double c[] = aperture.getCentre();
		double n[] = aperture.getNormal();
		double d = 0.001; // offset of the ring along the aperture normal
		
		BinaryMatrixWriter binOut = new BinaryMatrixWriter(outPath + fileName, 8);
		try {
			for(int j=0; j < nRimPoints; j++){
				double theta = ((double)j - phaseOffset) / nRimPoints * 2 * Math.PI;
				double cosTheta = FastMath.cos(theta), sinTheta = FastMath.sin(theta);
				double startPos[] = new double[]{
						c[0] + r*cosTheta * a[0] + r*sinTheta * b[0] + d*n[0],
						c[1] + r*cosTheta * a[1] + r*sinTheta * b[1] + d*n[1],
						c[2] + r*cosTheta * a[2] + r*sinTheta * b[2] + d*n[2],
				};
				double imgPos[] = new double[]{ Double.NaN, Double.NaN };
				traceFromPoint(iP++, namePrefix + j, startPos, imgPos, target, fireBackwards, binOut);
			}
		} finally {
			binOut.close();
		}
		return iP;
	}
	
	/** Draws the optics, the FARO camera position, the beams and the transform points,
	 * then finalises and closes the VRML and SVG outputs. */
	private static void drawScene(List<FeatureTransform.Point> xformPoints) {
		//draw main optics
		vrmlOut.drawOptic(sys);
		svgOutFlat.drawElement(sys);
		
		//draw box at FARO convergence point
		double camSize = 0.05;
		Box box = new Box("cameraPos", faro.getCameraPos(), camSize,camSize,camSize, null, NullInterface.ideal());
		vrmlOut.drawOptic(box);
		svgOutFlat.drawElement(box);
		
		//draw beams		
		Optic beams = AugMSESystem.makeAllBeamCylds(0.05, 0.15, 0.40, 0.65);
		vrmlOut.drawOptic(beams);
		svgOutFlat.drawElement(beams);
		
		//draw boxes at fit points
		vrmlOut.startGroup("TransformPoints");
		svgOutFlat.startGroup("TransformPoints");
		double boxSize = 0.01;
		
		for(FeatureTransform.Point p : xformPoints){
			Box pointBox = new Box("TransformPoints_" + vrmlOut.cleanString(p.name),
						new double[]{ p.x, p.y, p.z }, boxSize,boxSize,boxSize, null, NullInterface.ideal());
			vrmlOut.drawOptic(pointBox);
			svgOutFlat.drawElement(pointBox);			
		}
		vrmlOut.endGroup();
		svgOutFlat.endGroup();
		
		vrmlOut.addVRML("}"); //end of rotate/transform
		vrmlOut.destroy();
		
		svgOutFlat.destroy();
	}

	/** Works out the effective view position from the collection of average view lines.
	 * 
	 * For every ordered pair of valid lines, the point on line A nearest line B is
	 * accumulated, and the result is averaged. Any existing EN_VIEW point in the transform
	 * is set to EN_IGNORE. A point named "view (TransformMatch-&lt;optics hash&gt;)" is then
	 * created or updated with this position, and the point set is saved back to GMDS.
	 * If no usable pair exists, nothing in the transform is changed.
	 * 
	 * @param viewLines { start position, unit direction } pairs; NaN directions are ignored
	 * @param xform transform whose points are updated and saved
	 * @return the effective view position, or NaNs if it could not be determined */
	private static double[] calcViewConvergence(List<double[][]> viewLines, FeatureTransformCubic xform) {

		double cameraPos[] = new double[3];
		int n=0;
				
		for(double lineA[][] : viewLines){
			if(Double.isNaN(lineA[1][0]))
				continue;
			for(double lineB[][] : viewLines){			
				if(lineA == lineB || Double.isNaN(lineB[1][0]))
					continue;
		
				double s = Algorithms.pointOnLineNearestAnotherLine(lineA[0], lineA[1], lineB[0], lineB[1]);
				// a degenerate pair (e.g. parallel lines) must not poison the average
				if(Double.isNaN(s) || Double.isInfinite(s))
					continue;

				cameraPos[0] += lineA[0][0] + s * lineA[1][0];
				cameraPos[1] += lineA[0][1] + s * lineA[1][1];
				cameraPos[2] += lineA[0][2] + s * lineA[1][2];
				n++;
			}
		}
		
		if(n == 0){
			// Don't clobber the stored view point with NaNs when all traces failed
			System.err.println("calcViewConvergence: No usable pairs of view lines, not updating the view point.");
			return new double[]{ Double.NaN, Double.NaN, Double.NaN };
		}

		cameraPos[0] /= n; cameraPos[1] /= n; cameraPos[2] /= n;
		
		System.out.print("New effective view pos: ");
		OneLiners.dumpArray(cameraPos);
		
		//clear existing view point
		for(FeatureTransform.Point p : xform.getPoints())
			if(p.enable == FeatureTransform.EN_VIEW)
				p.enable = FeatureTransform.EN_IGNORE;
		
		String viewPointName = "view (TransformMatch-"+sys.hashCode()+")";
		FeatureTransform.Point p = xform.getPointByName(viewPointName);
				
		if(p == null){
			p = new FeatureTransform.Point(viewPointName);
			xform.getPoints().add(p);
		}
		p.x = cameraPos[0];
		p.y = cameraPos[1];
		p.z = cameraPos[2];
		p.imgX = 0;
		p.imgY = 0;
		p.lat = 0;
		p.lon = 0;
		p.R = 0;
		p.Z = 0;
		p.enable = FeatureTransform.EN_VIEW;		
				
		xform.savePoints(defaultGMDS(), experiment, pulse);
		
		return cameraPos;
	}

	/** Traces nRays successful rays from one start point and records where they land on the image plane.
	 * 
	 * Each ray is fired in a random direction toward {@code target} and retried until it
	 * intersects {@link #imagePlane}. The miss counter is shared by all rays of this point;
	 * once it exceeds {@link #maxAttempts}, each remaining ray gets only a single try.
	 * The first nRaysDraw successful rays are drawn to the VRML and SVG outputs.
	 * 
	 * One row is written to binOut with these values:
	 * <ul>
	 *  <li>startPos,</li>
	 *  <li>imgPos,</li>
	 *  <li>the mean image-plane hit position, divided by the CCD width/height and shifted by 0.5,</li>
	 *  <li>a spread measure built from the hit position sigmas.</li>
	 * </ul>
	 * 
	 * @param j point index, used for the colour and the SVG group name
	 * @param pointName name used in the SVG group name and the console output
	 * @param startPos ray start position (not modified)
	 * @param imgPos expected image position copied into the output row (callers pass NaN when unknown)
	 * @param target element the random rays are fired toward
	 * @param fireBackwards if true, each ray starts 0.020 further along a random direction toward
	 *                      sys.tubeOptics.lens1Front and is fired back along the reverse of that direction
	 * @param binOut output file for the result row
	 * @return Normalised direction from startPos to the mean hit position on the main mirror of the
	 *         successful rays (presumably NaN if no ray hit — TODO confirm against HitPositionAverage) */ 
	private static double[] traceFromPoint(int j, String pointName, double startPos[], double imgPos[], 
			Element target, boolean fireBackwards, BinaryMatrixWriter binOut) {
		
		svgOutFlat.getSVG3D().startGroup("point" + j + "_" + pointName);
		
		int nAttempts = 0; // missed attempts, accumulated over all rays of this point
		int nHit = 0;
		
		IntensityInfo intensityInfo = new IntensityInfo(sys);		
		HitPositionAverage imgPosInfo = new HitPositionAverage();
		HitPositionAverage mirrorPosInfo = new HitPositionAverage();
		
		for(int i=0; i < nRays; i++){
			do{
				Pol.recoverAll();
				
				RaySegment ray = new RaySegment();
				ray.startPos =  startPos.clone();
				ray.dir = Tracer.generateRandomRayTowardSurface(ray.startPos, target);
			
				//(firing backwards) need to turn the ray around and move back a bit
				if(fireBackwards){
					ray.dir = Tracer.generateRandomRayTowardSurface(ray.startPos, sys.tubeOptics.lens1Front);
					ray.startPos = Util.plus(ray.startPos, Util.mul(ray.dir, 0.020));
					ray.dir = Util.mul(ray.dir, -1.0);
				}
				
				ray.length = Double.POSITIVE_INFINITY;
				
				ray.up = Util.cross(Util.reNorm(Util.cross(ray.dir, globalUp)), ray.dir); // This is MSE-like emission
				
				ray.E0 = new double[][]{ { 1,0,0,0 } };
			
				ray.wavelength = wavelen; 
									
				Tracer.trace(sys, ray, 100, minIntensity, traceReflections);
				
				ray.processIntersections(null, intensityInfo);
				ray.processIntersections(imagePlane, imgPosInfo);
				List<Intersection> hits = ray.getIntersections(imagePlane); 
								
				if(hits.size() > 0){
					ray.processIntersections(sys.mainMirror, mirrorPosInfo);
					
					if(nHit < nRaysDraw){
						vrmlOut.drawRay(ray, cols[j]);
						svgOutFlat.drawRay(ray, j);							
					}
					
					nHit++;
					break;
				}
				if(nAttempts > maxAttempts){
					break;
				}
				nAttempts++;
			}while(true);
			
		}
		// if(nHit == 0)
			// intensityInfo.dump(true);
		 		
		double sigR2 = imgPosInfo.getSigmaRR() / sys.imseOptics.ccdWidth;
		double sigU2 = imgPosInfo.getSigmaUU() / sys.imseOptics.ccdHeight;
		double sigD2 = sigR2 + sigU2;
		
		binOut.writeRow(startPos, 
				imgPos[0], imgPos[1],
				(imgPosInfo.getMeanR() / sys.imseOptics.ccdWidth) + 0.5, 
				(imgPosInfo.getMeanU() / sys.imseOptics.ccdHeight) + 0.5,
				FastMath.sqrt(sigD2));
		
		System.out.println("\n---------------------------------------- "+j+" "+pointName+" ----------------------------------------");
		System.out.println(j + ": Hits = " + nHit + " of " + nAttempts);
					
		svgOutFlat.getSVG3D().endGroup();
		
		return Util.reNorm(Util.minus(sys.mainMirror.planeRUToPosXYZ(mirrorPosInfo.getMeanPosRU()), startPos));
	}

}
