package oscRemoval;

/** Attempts to remove the oscillation seen in both the density (ne) and the mirror-movement signals of
 * the lateral channels, by fitting the oscillation amplitude and initial phase to the mirror movements
 * during the baseline (non-plasma) period.
 */

import iaTools.GeneticAlgorithm;
import iaTools.SimpleSearches;
import mdsPlus.JetMDSReader;
import support.AsciiMatrixFile;
import support.Grids;
import varsAndFuncs.DoubleParam;
import varsAndFuncs.Function;
import varsAndFuncs.LimitedDoubleParam;
import varsAndFuncs.RotationalDoubleParam;
import varsAndFuncs.VFUtils;

/**
 * Fits and removes the mirror-movement oscillation from one lateral polarimeter channel.
 *
 * <p>Inside the non-plasma windows ({@code t0[j] <= t <= t1[j]}) the selected signal is fitted with
 * {@code y0 + y1*t + A*cos(2*pi*mir/period + phi0)}. The {@code period} depends on the input:
 * <ul>
 * <li>{@code lambdaAdjust * 0.5 * lasWavelen} when the PPF mirror movement is used;</li>
 * <li>{@code lambdaAdjust * LID_PER_FRINGE} when fitting directly against the JPF LID
 * ({@code jpfLid}).</li>
 * </ul>
 * The whole analysis (load data, fit, write results) runs from the constructor.
 */
public class OscRemoval implements Function {
	final double lasWavelen = 195e-6; //laser wavelength = osc wavelen in mirror movement

	/** Line-integrated density change that corresponds to one full cycle (2*pi) of the oscillation phase. */
	static final double LID_PER_FRINGE = 1.143e19;
	
	static final int SIGNAL_RMS = 1;
	static final int SIGNAL_RMP = 2;
	static final int SIGNAL_PSD = 3;
	static final int SIGNAL_PSP = 4;
	static final int SIGNAL_R = 5;
	static final int SIGNAL_Rp = 6;
	
	static final int SEARCH_SIMPLEX = 1;
	static final int SEARCH_GA = 2;
	static final int SEARCH_CONJGRAD = 3;
	

	/* ********************** Operating Parameters ********************* */
	final String rootDir = "/work/polarim/";
	
	final int pulse = 75411;				//pulse
	final int ch = 6; 						//channel, 1 based

	//non-plasma windows
	final double t0[] = new double[]{ 28.0 , 67.0 };
	final double t1[] = new double[]{ 40.0, Double.POSITIVE_INFINITY };

	final int signal = SIGNAL_RMS;			//signal to work on
	
	final boolean realTime = true;			//load real-time JPFs rather than normal ones
	final boolean jpfLid = false;			//Try to fit using unprocessed jpf/LID, rather than adding calced mirror movement to ppf/LID
	//final int search = SEARCH_SIMPLEX;			//search method
	final int search = SEARCH_CONJGRAD;			//search method
	final double sigma = 5;
	
	final boolean startFresh = false;
	
	//Genetic Algorithm Parameters
	final int gensPerSave = 500;
	final int totalGens = 10000;
	final int pop = 100;
	final int children = 10;
	
	/* ***************************************************************** */
	
	//The free parameters in the fit:
	DoubleParam A; //Amplitude
	DoubleParam phi0; //fixed phase
	DoubleParam y0, y1 ; //offset at start, offset at end (assume linear drift over time)	
	DoubleParam lambdaAdjust; //wavelength adjustment
	
	//Data handling and signals
	JetMDSReader jmds; //Data signals handler (JET MDS+)
	double mir[]; //mirror movement
	double sig[]; //raw signals 
	double t[]; //time base
	double lid[]; //line integrated density
		
	/**
	 * Main program: loads the data for {@code pulse}/{@code ch}, runs the configured fit and writes the results.
	 *
	 * @param args command-line arguments (unused)
	 * @throws IllegalArgumentException if {@code signal} or {@code search} is not a known constant
	 * @throws RuntimeException if the saved parameter file does not match the current parameter set
	 */
	public OscRemoval(String args[]) {
		//jmds = new JetMDSReader(null,-1,"/home/oford/work/data");
		jmds = new JetMDSReader("mdsplus.jet.efda.org",8000,"/home/oford/work/data");
		
		initParams();
		loadSignal();
		loadMirrorAndDensity();
		
		if(!startFresh) //if not starting from scratch, load the current state
			loadSavedParams();
		
		runSearch();
		
		//SimpleSearches.conjGrad(this,100, 10, 10, SimpleSearches.LINESEARCH_GOLDENSECTION);
		
		/* VFUtils.dumpParams(this, false);		 
		 ParamWriter pWriter = new ParamWriter(rootDir + "/"+pulse+"/rmOsc-",getParams(),null);		
		AutoCorrelMCMC mcmc = new AutoCorrelMCMC(this,new SampleHandler[]{ pWriter });		
		mcmc.go(5, 1000, 0.0001, 100, "rmOsc"); //*/
		
		VFUtils.dumpParams(this, false);
		System.out.println("logP = " + evaluate());
		
		dump("",0,0,0);
	}
	
	/** Creates the fit parameters with their initial settings; the ratio signals (R, Rp) get smaller ranges. */
	private void initParams() {
		phi0 = new RotationalDoubleParam(0,-Math.PI,Math.PI,-Math.PI,Math.PI,"phi0",true);
		lambdaAdjust = new LimitedDoubleParam(1.0, 0.9, 1.1, 0.8, 1.2, "wavelen adjust",true);
		if(signal == SIGNAL_R || signal == SIGNAL_Rp){
			A = new LimitedDoubleParam(0.1,-1,1,-4,4,"Amp",true);
			y0 = new LimitedDoubleParam(0,-1,1,-2,2,"y0",true);
			y1 = new LimitedDoubleParam(0,-0.05,0.05,-0.10,0.10,"y1",true);
		}else{
			A = new LimitedDoubleParam(0,0,10,0,500,"Amp",true);
			y0 = new LimitedDoubleParam(0,-1000,1000,-2000,2000,"y0",true);
			y1 = new LimitedDoubleParam(0,-2,2,-10,10,"y1",true);
		}
	}
	
	/**
	 * Loads the time base and raw JPF channels, then builds {@link #sig} for the selected {@code signal}.
	 *
	 * @throws IllegalArgumentException if {@code signal} is not one of the SIGNAL_* constants
	 */
	private void loadSignal() {
		double rms[],rmp[],psd[],psp[];		//raw JPF signals
		
		if(realTime){
			//Real-time data
			t = jmds.getTimeVector(pulse + "/jpf/df/g4r-rms<raw:00"+ch);			
			rms = jmds.get1DData(pulse + "/jpf/df/g4r-rms<raw:00"+ch);
			rmp = jmds.get1DData(pulse + "/jpf/df/g4r-rmp<raw:00"+ch);
			psd = jmds.get1DData(pulse + "/jpf/df/g4r-psd<raw:00"+ch);
			psp = jmds.get1DData(pulse + "/jpf/df/g4r-psp<raw:00"+ch);
		}else{
			//normal JPF data; take the time base of the channel actually being analysed
			t = jmds.getTimeVector(pulse + "/jpf/df/g4-rms"+ch);
			rms = jmds.get1DData(pulse + "/jpf/df/g4-rms"+ch);
			rmp = jmds.get1DData(pulse + "/jpf/df/g4-rmp"+ch);
			psd = jmds.get1DData(pulse + "/jpf/df/g4-psd"+ch);
			psp = jmds.get1DData(pulse + "/jpf/df/g4-psp"+ch);
		}
		
		switch(signal){ //which signal are we working on?
			case SIGNAL_RMS: sig = rms; break;
			case SIGNAL_RMP: sig = rmp; break;
			case SIGNAL_PSD: sig = psd; break;
			case SIGNAL_PSP: sig = psp; break;
			case SIGNAL_R: 
				sig = new double[psd.length];
				for(int i=0;i<psd.length;i++)
					sig[i] = psd[i] / rms[i];
				break;
			case SIGNAL_Rp: 
				sig = new double[psp.length];
				for(int i=0;i<psp.length;i++)
					sig[i] = psp[i] / Math.sqrt(rms[i] * rmp[i]);
				break;
			default: throw new IllegalArgumentException("Unknown signal type: " + signal);
		}
	}
	
	/** Loads the mirror movement and line-integrated density and regrids both onto {@link #t}. */
	private void loadMirrorAndDensity() {
		if(jpfLid){ //use JPF line integrated density, which has not had mirror movements removed so we can use it all-in-one
			double rawLID[] = jmds.get1DData(pulse+"/jpf/df/g1v-ml<"+(ch-4));
			double rawLIDT[] = jmds.getTimeVector(pulse+"/jpf/df/g1v-ml<"+(ch-4));
			lid = Grids.reGrid(t,rawLIDT,rawLID,false);
			mir = lid; //deliberately the same array: the model is driven by the raw LID
		}else{ //treat mirror movement and LID separately, get LID from PPF
			double rawMir[] = jmds.get1DData(pulse+"/ppf/kg1v/mir"+ch);
			double rawMirT[] = jmds.getTimeVector(pulse+"/ppf/kg1v/mir"+ch);
			mir = Grids.reGrid(t,rawMirT,rawMir,false);
			
			double rawLID[] = jmds.get1DData(pulse+"/ppf/kg1v/lid"+ch);
			double rawLIDT[] = jmds.getTimeVector(pulse+"/ppf/kg1v/lid"+ch);
			lid = Grids.reGrid(t,rawLIDT,rawLID,false);		
			for(int i=0;i<lid.length;i++)
				lid[i] = lid[i] * 2; //someone divided the PPF LID by 2 for laterals
		}
	}
	
	/**
	 * Restores the parameter values saved by a previous {@link #dump} call.
	 *
	 * @throws RuntimeException if the file is empty or holds a different number of parameters
	 */
	private void loadSavedParams() {
		double tmp[][] = AsciiMatrixFile.mustLoad(rootDir + "/" + pulse + "/rmOsc-params.txt",true);
		DoubleParam params[] = getParams();
		if(tmp.length == 0 || tmp[0].length != params.length)
			throw new RuntimeException("Wrong number of params in file, need to startFresh.");
		DoubleParam.setArray(params, tmp[0]);
	}
	
	/**
	 * Runs the configured search algorithm on this cost function.
	 *
	 * @throws IllegalArgumentException if {@code search} is not one of the SEARCH_* constants
	 */
	private void runSearch() {
		switch(search){ //chose a search algorithm
		case SEARCH_CONJGRAD:
			SimpleSearches.conjGrad(this, 100, 10, 10, SimpleSearches.LINESEARCH_NEWTONRAPHSON);
			break;
		case SEARCH_SIMPLEX:
			SimpleSearches.downhillSimplexSearch(this, 50000, 1e-10);
			break;
		case SEARCH_GA:
			GeneticAlgorithm ga = new GeneticAlgorithm(this);
			if(startFresh)
				ga.initRandom(pop, children, true);
			else
				ga.initFromCurrent(pop, children, true);
			for(int i=0;i<(totalGens/gensPerSave);i++){ //checkpoint every gensPerSave generations
				ga.go(gensPerSave, 100);
				VFUtils.dumpParams(this, false);
				System.out.println("logP = " + evaluate());				
				dump("",0,0,0);
			}
			break;
		default: throw new IllegalArgumentException("Unknown search type"); 
		}
	}
	
	/** Linear drift term {@code y0 + y1*t} at sample {@code i}. */
	private double baseline(int i) {
		return y0.get() + y1.get() * t[i];
	}
	
	/** Oscillation phase in radians (excluding {@code phi0}) driven by {@code mir[i]}. */
	private double oscPhase(int i) {
		double period = jpfLid
				? lambdaAdjust.get() * LID_PER_FRINGE
				: lambdaAdjust.get() * 0.5 * lasWavelen;
		return 2 * Math.PI * mir[i] / period;
	}
	
	/** Model of the signal at sample {@code i}: drift plus the mirror-movement oscillation. */
	private double mirrorModel(int i) {
		return baseline(i) + A.get() * Math.cos(oscPhase(i) + phi0.get());
	}
	
	/** True if {@code time} lies in any of the (inclusive) non-plasma fit windows {@code [t0[j], t1[j]]}. */
	private boolean inFitWindow(double time) {
		for(int j=0;j<t0.length;j++)
			if(time >= t0[j] && time <= t1[j])
				return true;
		return false;
	}
	
	/** Path component under which the corrected signal is cached for the selected {@code signal}. */
	private String signalCacheName() {
		switch(signal){
			case SIGNAL_RMS: return "rms";
			case SIGNAL_RMP: return "rmp";
			case SIGNAL_PSD: return "psd";
			case SIGNAL_PSP: return "psp";
			case SIGNAL_R: return "r";
			case SIGNAL_Rp: return "rp";
			default: throw new IllegalArgumentException("Unknown signal type: " + signal);
		}
	}
	
	/**
	 * Writes the current fit, corrected signals and parameters to disk and to the MDS+ local cache.
	 *
	 * <p>Columns written to {@code rmOsc.txt}:
	 * <ul>
	 * <li>0-3: t, mir, sig, lid;</li>
	 * <li>4: mirror-only model;</li>
	 * <li>5: sig with the mirror model's oscillation removed;</li>
	 * <li>6: mirror+LID model;</li>
	 * <li>7: sig with that model's oscillation removed;</li>
	 * <li>8: 1 inside a fit window, else 0.</li>
	 * </ul>
	 * The arguments are ignored; the output locations are fixed by the operating parameters.
	 */
	public void dump(String fileName, double minX, double maxX, int steps) {
		double data[][] = new double[9][];
		data[0] = t;
		data[1] = mir;
		data[2] = sig;
		data[3] = lid;
		for(int k=4;k<data.length;k++)
			data[k] = new double[t.length];
		
		for(int i=0;i<t.length;i++){
			double yMir = mirrorModel(i);
			double yLid;
			if(jpfLid){
				yLid = yMir; //same thing: mir and lid are the same array in this mode
			}else{
				//add the phase from the density itself on top of the mirror movement
				yLid = baseline(i)
						+ A.get() * Math.cos(oscPhase(i)
											- 2 * Math.PI * lid[i] / LID_PER_FRINGE
											+ phi0.get());
			}
			
			data[4][i] = yMir;
			data[5][i] = sig[i] - yMir + y0.get() + y1.get() * t[i]; //keep the drift, remove only the oscillation
			data[6][i] = yLid;
			data[7][i] = sig[i] - yLid + y0.get() + y1.get() * t[i];
			if(inFitWindow(t[i]))
				data[8][i] = 1;
		}
		
		AsciiMatrixFile.mustWrite(rootDir + "/" + pulse + "/rmOsc.txt",data,true);
		AsciiMatrixFile.mustWrite(rootDir + "/" + pulse + "/rmOsc-params.txt",
				new double[][]{ DoubleParam.toArray(getParams()) },true);
		
		jmds.writeToCache(pulse + "/local/rmosc/ch"+ch+"/params",
				new double[(getParams()).length],
				DoubleParam.toArray(getParams()));
		
		jmds.writeToCache(pulse + "/local/rmosc/ch"+ch+"/" + signalCacheName(), t, data[7]);
	}
	
	/**
	 * The fitting cost function (optimise this value).
	 *
	 * @return {@code -sum((model - sig)^2 / (2*sigma^2))} over samples in the fit windows; each sample counted once
	 */
	public double evaluate() {
		double logP = 0;
		
		for(int i=0;i<t.length;i++){
			if(!inFitWindow(t[i]))
				continue;
			double resid = mirrorModel(i) - sig[i];
			logP -= resid * resid / ( 2* sigma*sigma );
		}
		
		return logP;
	}
	
	//Stuff for the fitters
	public double[] dumpLine(double minX, double maxX, int steps) { throw new RuntimeException("Not implemented"); }
	public DoubleParam[] getParams() { return new DoubleParam[]{ A,phi0,lambdaAdjust,y0,y1 }; }
	public static void main(String[] args) { new OscRemoval(args); }
}
