package seed.minerva.apps.imse;


import jafama.FastMath;


import java.awt.Color;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;

import javax.management.RuntimeErrorException;

import binaryMatrixFile.BinaryMatrixFile;
import binaryMatrixFile.BinaryMatrixWriter;

import otherSupport.ColorMaps;
import otherSupport.RandomManager;
import otherSupport.SettingsManager;

import algorithmrepository.Algorithms;
import algorithmrepository.LinearInterpolation1D;
import oneLiners.OneLiners;
import seed.minerva.MinervaOpticsSettings;
import seed.minerva.aug.mse.AugMSESystem;
import seed.minerva.optics.Util;
import seed.minerva.optics.tracer.Tracer;
import seed.minerva.optics.types.*;
import seed.minerva.optics.surfaces.*;
import seed.minerva.optics.drawing.AsciiOutForWendel;
import seed.minerva.optics.drawing.SVGRayDrawing;
import seed.minerva.optics.drawing.VRMLDrawer;
import seed.minerva.optics.interfaces.Absorber;
import seed.minerva.optics.interfaces.IsoIsoInterface;
import seed.minerva.optics.interfaces.IsoIsoStdFresnel;
import seed.minerva.optics.interfaces.Reflector;
import seed.minerva.optics.materials.IsotropicFixedIndexGlass;
import seed.minerva.optics.materials.SchottSFL6;
import seed.minerva.optics.materials.Vacuum;
import seed.minerva.optics.optics.*;
import svg.SVGSplitView3D;


/** Attempts to autofocus an optical system by moving a given element along a given axis
 * in an attempt to reduce the RMS spread of hit points on a plane.
 * 
 * Each scan steps {@link #focusingElement} through {@link #nAdjustmentsPerScan} evenly
 * spaced offsets (in metres) along {@link #adjustVec}. At each offset it traces {@link #nRays}
 * random rays from {@link #rayStart} towards {@link #initRaysTarget}. It then computes the
 * intensity-weighted variance of the hit positions on {@link #imagePlane}. The next scan is
 * narrowed to one step either side of the best offset found. The element is shifted back to
 * its original position after every scan.
 * 
 * @author oliford
 */
public class AutoFocus {
	
	// NOTE(review): not referenced anywhere in this class.
	final static double reHitTolerence = 1e-6;
	
	public static String outPath = MinervaOpticsSettings.getAppsOutputPath() + "/rayTracing/autoFocus";
	
	/** Optical system setup */
	public static AugMSESystem sys = new AugMSESystem(null);
	
	/** Plane on which the spot size is measured */
	public static Plane imagePlane = sys.tubeOptics.fibreEnds;
	//public static Plane imagePlane = sys.imseOptics.ccd;
	/** The last crossing of this plane before the image plane supplies the polarisation output */
	public static Plane polarisationPlane = sys.tubeOptics.PEMsFront;
	/** Rays are generated toward this element */
	public static Element initRaysTarget = sys.mirrorBox.protectionCoverFront;
	/** Element moved along adjustVec to find focus */
	public static Element focusingElement = imagePlane;
	
	/** n Rays per evaluation */ 
	public static int nRays = 80000;
	
	/** What rays to trace */
	public static double minIntensity = 0.01;
	public static boolean traceReflections = false;

	/** Where to fire rays from */
	public static double rayStart[] =
			sys.calibLampPos[1];
			/*new double[]{
			sys.nbiStart[0] + 0.62*sys.nbiUnit[0],
			sys.nbiStart[1] + 0.62*sys.nbiUnit[1],
			sys.nbiStart[2] + 0.62*sys.nbiUnit[2]
		};             //*/
	
	/** Direction along which the focusing element is moved */
	public static double adjustVec[] = sys.tubeOptics.fibreEnds.getNormal();
	                    	
	/** Initial min/max of focus range (metres along adjustVec) */
	public static double d0 = -0.010;
	public static double d1 = +0.010;
	/** Number of scans, and steps per scan (must be at least 2) */
	public static int nScans = 7;
	public static int nAdjustmentsPerScan = 7;
	
	/** Output for individual hit data (useful if only 1 start point).
	 * NOTE(review): never closed in this class; confirm whether BinaryMatrixWriter needs an explicit close to flush. */
	public static BinaryMatrixWriter hitsOut = new BinaryMatrixWriter(outPath + "/hits.bin", 20);
	
	/** Colour table for output */
	public static double col[][] = ColorMaps.jet(nAdjustmentsPerScan);
	
	/** Column layout of each statsCollect row */
	private static final int STAT_INTENSITY = 0;
	private static final int STAT_X = 1;
	private static final int STAT_Y = 2;
	private static final int STAT_RAY = 3;
	
	/** stats storage: the first nHitsInStats rows of statsCollect are valid */
	private static int nHitsInStats;
	private static double statsCollect[][];
	
	private static DecimalFormat fmt = new DecimalFormat("##0.000");
	
	/** Runs the focus scans, printing the spot size at every adjustment step and the best
	 * adjustment found. Outlier rays are drawn to outliers.vrml. */
	public static void main(String[] args) {
		if(nAdjustmentsPerScan < 2)
			throw new IllegalStateException("nAdjustmentsPerScan must be at least 2, got " + nAdjustmentsPerScan);
		
		VRMLDrawer outlierOut = new VRMLDrawer(outPath + "/outliers.vrml", 0.0001);
		try{
			outlierOut.setDrawPolarisationFrames(false);
			
			double nextSD0 = d0;
			double nextSD1 = d1;
			double bestD = Double.NaN;
			double bestSpotVar = Double.NaN;
			
			//we can't set the centre, since in general the element might not have one
			//so we have to keep using .shift() in steps. However, we can check that the
			//bounding sphere centre returns to the same place each time.
			//Cloned in case the element hands back its internal array.
			double origCentre[] = focusingElement.getBoundarySphereCentre().clone();
			
			for(int iS=0; iS < nScans; iS++){
				double sd0 = nextSD0;
				double sd1 = nextSD1;
	
				nextSD0 = Double.NaN;
				nextSD1 = Double.NaN;
				double bestVar = Double.POSITIVE_INFINITY;
				double bestScanD = Double.NaN;
				
				focusingElement.shift(scaled(adjustVec, sd0));
				
				double dD = (sd1 - sd0) / (nAdjustmentsPerScan - 1);
				
				System.out.println("Scan " + iS + ": "+fmt.format(sd0/1e-3)+" mm --> "+fmt.format(sd1/1e-3)+" mm");
				
				for(int iD=0; iD < nAdjustmentsPerScan; iD++){ //for each focusing adjustment 
					double D = sd0 + iD * dD;
					
					//shift ready for next position
					if(iD > 0)
						focusingElement.shift(scaled(adjustVec, dD));
					
					double dirs[][] = new double[nRays][];
			
					nHitsInStats = 0;
					statsCollect = new double[nRays][]; //initial capacity, grown by addStat() if needed
					for(int j=0;j<nRays;j++){
						double startPos[] = rayStart.clone();
						dirs[j] = Tracer.generateRandomRayTowardSurface(startPos, initRaysTarget);
						
						//For now, we're going to ignore the problem of what the MSE emission
						//makes as an initial polarisation and just do things in terms of 'up' and 'right'
						RaySegment ray = createRay(startPos, dirs[j]);
						traceRay(ray);
						
						processRay(iD, iS, D, ray, j);
						
						Pol.recoverAll();
					}
					
					double stats[] = processStats();
					double mX = stats[0], mY = stats[1], var = stats[2];
					
					int nOutliers = drawOutliers(outlierOut, dirs, mX, mY, var);
					
					System.out.println("step = "+iS + ","+iD+",\tadjust = "+fmt.format(D/1e-3)+" mm,\tspotSize = " + fmt.format(FastMath.sqrt(var)/1e-6)+" um,\thits = "+nHitsInStats+" / "+nRays + "\tnOutliers = "+nOutliers);
					
					//NaN (no usable hits) fails both comparisons and is ignored
					if(var > 0 && var < bestVar){
						bestVar = var;
						bestScanD = D;
						nextSD0 = D - dD;
						nextSD1 = D + dD;
					}
				}
				
				//element should have ended up at sd1, so we can take it back to its original position
				focusingElement.shift(scaled(adjustVec, -sd1));
				
				double l = Util.length(Util.minus(origCentre, focusingElement.getBoundarySphereCentre()));
				if(l > (Math.abs(dD)/1000))
					throw new IllegalStateException("Adjustment element doesn't seem to be returning to its start position (offset = " + l + ").");
				
				//Without a usable spot the next range would be NaN, and shifting by NaN would corrupt the geometry
				if(Double.isNaN(bestScanD)){
					System.out.println("Scan " + iS + ": no usable spot found, stopping refinement.");
					break;
				}
				bestD = bestScanD;
				bestSpotVar = bestVar;
			}
			
			if(!Double.isNaN(bestD))
				System.out.println("Best adjustment = "+fmt.format(bestD/1e-3)+" mm,\tspotSize = "+fmt.format(FastMath.sqrt(bestSpotVar)/1e-6)+" um");
			
			outlierOut.drawOptic(sys);
		}finally{
			outlierOut.destroy();
		}
	}
	
	/** Returns a new 3-vector equal to v scaled by s. */
	private static double[] scaled(double v[], double s) {
		return new double[]{ s * v[0], s * v[1], s * v[2] };
	}
	
	/** Builds a fresh ray from startPos along dir, with 'up' and 'right' initial polarisation states.
	 * NOTE(review): 'up' is undefined (reNorm of a zero vector) if dir is parallel to z. */
	private static RaySegment createRay(double startPos[], double dir[]) {
		RaySegment ray = new RaySegment();
		ray.startPos = startPos;
		ray.dir = dir;
		double right[] = Util.cross(ray.dir, new double[]{ 0, 0, 1 });
		ray.up = Util.reNorm(Util.cross(right, ray.dir));
		ray.E0 = new double[][]{ {1, 0, 0, 0}, {0, 0, 1, 0} };
		ray.length = Double.POSITIVE_INFINITY;
		ray.wavelength = 653e-9;
		ray.startHit = null;
		return ray;
	}
	
	/** Traces a ray through the system using the configured intensity/reflection settings. */
	private static void traceRay(RaySegment ray) {
		Tracer.trace(sys, ray, 50, minIntensity, traceReflections);
	}
	
	/** Re-traces and draws the rays whose image plane hits lie more than 5 sigma from (mX, mY).
	 * @param dirs Initial direction of every ray traced for the current step, indexed by ray number.
	 * @return The number of outlying hits (a ray with several outlying hits is drawn only once).
	 */
	private static int drawOutliers(VRMLDrawer out, double dirs[][],
			double mX, double mY, double var) {
		
		final double nSigmaOutSq = 5.0*5.0;
		boolean drawn[] = new boolean[dirs.length];
		int nOutliers = 0;
		
		for(int i=0; i < nHitsInStats; i++){
			double stat[] = statsCollect[i];
			double d2 = FastMath.pow2(stat[STAT_X] - mX) + FastMath.pow2(stat[STAT_Y] - mY);
			double s2 = d2 / var;
			
			//NaN (e.g. no valid variance) is never an outlier
			if(s2 > nSigmaOutSq){
				nOutliers++;
				
				//stats rows are not 1:1 with rays, so use the stored originating ray index
				int rayIdx = (int)stat[STAT_RAY];
				if(!drawn[rayIdx]){
					RaySegment ray = createRay(rayStart.clone(), dirs[rayIdx]);
					traceRay(ray);
					out.drawRay(ray);
					drawn[rayIdx] = true;
				}
			}
		}
		
		return nOutliers;
	}

	
	/* **************************** Processing Stuff **************************** */

	/** Main processing, per ray.
	 * For every image plane hit that passed through the polarisation plane, this writes a row
	 * to hitsOut (if set). It also appends {intensity, x, y, rayIdx} to the stats storage.
	 */
	private static void processRay(int iD, int iS, double D, RaySegment ray, int rayIdx) {
		List<Intersection> imagePlaneHits = ray.getIntersections(imagePlane);
		
		//For every hit (sub ray from earlier on etc)
		for(Intersection imgHit : imagePlaneHits) {
			
			//Walk backwards until we find the last time it went through the polarisation sensitive plane
			Intersection polHit = imgHit;
			while(polHit != null && polHit.surface != polarisationPlane)
				polHit = polHit.incidentRay.startHit;
			
			if(polHit == null)
				continue; //this hit never passed the pol plane, but other hits of the same ray might have
			
			//make sure the sense of polarisation on the pol plane incident ray 
			//matches the pol plane's sense of up
			polHit.incidentRay.rotatePolRefFrame(polarisationPlane.getUp());
			
			double imgPos[] = imagePlane.posXYZToPlaneUR(imgHit.pos);
			
			if(hitsOut != null){
				hitsOut.writeRow(
						iD,	iS, D,									// 1: adjustment (iD,iS,D)
						ray.startPos, 								// 4: start posXYZ 
						ray.startPos, 								// 7: start posXYZ (NOTE(review): duplicate of col 4 - intended ray.dir?)
						imgPos[0] + iD * 0.01,  					//10: image position (offset by iD/iS, presumably to separate them in plots)
						imgPos[1] + iS * 0.01,
						col[iD], 									//12: colour
						 Pol.intensity(imgHit.incidentRay.E1[0]),	//15: total intensity
						 Pol.psi(polHit.incidentRay.E1[0]),			//16: rotation dir of light originally 'up'
						 Pol.chi(polHit.incidentRay.E1[0]),			//17: ellipticity of light originally 'up'
						 Pol.intensity(imgHit.incidentRay.E1[1]),	//18: total intensity
						 Pol.psi(polHit.incidentRay.E1[1]),			//19: rotation dir of light originally 'right'
						 Pol.chi(polHit.incidentRay.E1[1])			//20: ellipticity of light originally 'right'
					);
			}
			
			addStat(new double[]{
							imgHit.incidentRay.endIntensity(),
							imgPos[0],
							imgPos[1],
							rayIdx
						});
		}
	}
	
	/** Appends a row to statsCollect, growing it if a ray produced more hits than expected. */
	private static void addStat(double row[]) {
		if(statsCollect == null)
			statsCollect = new double[16][];
		else if(nHitsInStats >= statsCollect.length)
			statsCollect = Arrays.copyOf(statsCollect, Math.max(16, statsCollect.length * 2));
		statsCollect[nHitsInStats++] = row;
	}

	/** Computes the intensity-weighted mean position and variance (sum of x and y variances)
	 * of the collected hits.
	 * @return {mX, mY, var}. All three are NaN if there are no hits or zero total intensity.
	 */
	private static double[] processStats() {
		double mX=0, mY=0, sumI=0;
		for(int i=0; i < nHitsInStats; i++){
			double stat[] = statsCollect[i];
			sumI += stat[STAT_INTENSITY];
			mX += stat[STAT_INTENSITY] * stat[STAT_X];
			mY += stat[STAT_INTENSITY] * stat[STAT_Y];
		}
		
		if(!(sumI > 0))
			return new double[]{ Double.NaN, Double.NaN, Double.NaN };
		
		mX /= sumI; mY /= sumI;
		
		double var=0;
		for(int i=0; i < nHitsInStats; i++){
			double stat[] = statsCollect[i];
			var += stat[STAT_INTENSITY] * (FastMath.pow2(stat[STAT_X] - mX) + FastMath.pow2(stat[STAT_Y] - mY));
		}
		var /= sumI;
						
		return new double[]{ mX, mY, var };
	}	
		
}
