package imseProc.proc.seriesAvg;

import jafama.FastMath;


import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import org.eclipse.swt.widgets.Composite;
import imseProc.core.ByteBufferImage;
import imseProc.core.IMSEProc;
import imseProc.core.ImagePipeController;
import imseProc.core.Img;
import imseProc.core.ImgPipe;
import imseProc.core.ImgProcPipe;
import imseProc.core.ImgSink;
import imseProc.core.ImgSource;

/** Series averaging processor.
 * 
 * Produces one output image per input image. Each output image is built from
 * input images of the same series according to a configuration entry in
 * {@link #getConfigMap()} whose index range [CFG_I0, CFG_I1] contains the
 * output index (CFG_I1 &lt; 0 means "up to the last image"). If several
 * entries match, which one is used depends on HashMap iteration order.
 * <ul>
 *  <li>CFG_SIGMA_I &gt; 0: Gaussian-weighted average (width sigma, in image
 *      indices) of the inputs within +/- 4 sigma of the output index,
 *      restricted to inputs of the same odd/even parity when CFG_INTERLACE != 0.</li>
 *  <li>CFG_SIGMA_I &lt;= 0: the output is just the input image with the same index.</li>
 * </ul>
 * NaN input pixels are left out of the average; output pixels with no valid
 * input become NaN. If CFG_SPIKE_THRES &gt; 0, single-frame spikes (e.g.
 * radiation hits) are replaced by the mean of the neighbouring frames before
 * averaging (see checkForSpike()).
 * 
 * Output images are always allocated with double depth.
 * 
 * @author oliford
 *
 */
public class SeriesAverageProcessor extends ImgProcPipe {
		
	private SeriesAverageSWTController controller;
	
	/** Named configurations. Each value array is indexed by the CFG_* constants. */
	private HashMap<String, double[]> configMap = new HashMap<String, double[]>();
	/** First image index the configuration applies to (inclusive). */
	public final static int CFG_I0 = 0;
	/** Last image index the configuration applies to (inclusive), or &lt; 0 for 'to the end'. */
	public final static int CFG_I1 = 1;
	/** Gaussian smoothing width in image indices; &lt;= 0 disables averaging. */
	public final static int CFG_SIGMA_I = 2;
	/** Non-zero if the series is interlaced (odd and even images are averaged separately). */
	public final static int CFG_INTERLACE = 3;
	/** Spike threshold for radiation removal (or &lt;= 0 to disable). */
	public final static int CFG_SPIKE_THRES = 4;
	
	/** Template for the "default" configuration: whole series, no smoothing,
	 * spike removal with threshold 1500. Always copied before use.
	 * NOTE(review): CFG_INTERLACE is tested with '!= 0', so this -1 counts as
	 * interlaced, which makes spike removal compare against images +/-2 rather
	 * than +/-1. Confirm that is intended. */
	private static final double[] DEFAULT_CONFIG = { 0, -1, -1, -1, 1500 };
	
	public SeriesAverageProcessor() {
		super(ByteBufferImage.class);
		configMap.put("default", DEFAULT_CONFIG.clone());
	}
	
	public SeriesAverageProcessor(ImgSource source, int selectedIndex) {
		super(ByteBufferImage.class, source, selectedIndex);
		configMap.put("default", DEFAULT_CONFIG.clone());
	}
	
	/** Returns the input image indices needed to compute output image outIdx.
	 * 
	 * No matching configuration gives no inputs. Without smoothing (sigma &lt;= 0),
	 * only the input with the same index is needed. Otherwise, the inclusive
	 * +/- 4 sigma window is returned, limited to the configuration range and the
	 * source size, and restricted to the output's parity when interlaced.
	 * 
	 * Spike removal also needs the neighbouring images, but they must not be
	 * averaged in, so they are not requested here; doCalc() fetches them
	 * directly from the connected source.
	 */
	@Override
	protected int[] sourceIndices(int outIdx) {
		double cfg[] = findConfig(outIdx);
		if(cfg == null) //no config means no inputs
			return new int[0];
		
		//invalid sigma means only the same image as input
		double sigmaIdx = cfg[CFG_SIGMA_I];
		if(sigmaIdx <= 0)
			return new int[]{ outIdx };
		
		return gaussianWindowIndices(outIdx, (int)cfg[CFG_I0], (int)cfg[CFG_I1], sigmaIdx,
				cfg[CFG_INTERLACE] != 0, connectedSource.getNumImages());
	}
	
	/** Computes the inclusive window of input indices around outIdx used for Gaussian averaging.
	 * 
	 * @param outIdx output image index (window centre)
	 * @param cfgIdx0 first configured index (inclusive)
	 * @param cfgIdx1 last configured index (inclusive), &lt; 0 for 'to the end'
	 * @param sigmaIdx Gaussian width in image indices (&gt; 0); the window is +/- (int)(4*sigma)
	 * @param interlaced if true, only indices with the same parity as outIdx are returned
	 * @param nImagesIn number of images available in the source
	 * @return ascending indices, possibly empty, never outside [0, nImagesIn-1]
	 */
	private static int[] gaussianWindowIndices(int outIdx, int cfgIdx0, int cfgIdx1,
			double sigmaIdx, boolean interlaced, int nImagesIn){
		int halfWidth = (int)(4 * sigmaIdx);
		
		//first and last usable input, limited by the config range and the source size
		int first = Math.max(cfgIdx0, 0);
		int last = (cfgIdx1 < 0) ? (nImagesIn - 1) : Math.min(cfgIdx1, nImagesIn - 1);
		
		//truncate to within sigma range
		int idx0 = Math.max(first, outIdx - halfWidth);
		int idx1 = Math.min(last, outIdx + halfWidth);
		
		//if interlaced, only use inputs with the same odd/even parity as the output,
		//moving the ends inwards so they never leave [idx0, idx1]
		int step = 1;
		if(interlaced){
			step = 2;
			int parity = outIdx & 1;
			if((idx0 & 1) != parity)
				idx0++;
			if((idx1 & 1) != parity)
				idx1--;
		}
		
		if(idx1 < idx0)
			return new int[0];
		
		//both ends are inclusive
		int n = (idx1 - idx0) / step + 1;
		int idxs[] = new int[n];
		for(int i=0; i < n; i++){
			idxs[i] = idx0 + i*step;
		}
		return idxs;
	}

	/** Image allocation: one double-depth output image per input image, of the same size as the inputs.
	 * @return true if the output image set was (re)allocated, false if the existing one was re-used.
	 */
	@Override
	protected boolean checkOutputSet(int nImagesIn){
		outWidth = inWidth;
		outHeight = inHeight;
		int nImagesOut = nImagesIn;
		
		//allocate or re-use
		ByteBufferImage newOutSet[] = ByteBufferImage.checkBulkAllocation(this, (ByteBufferImage[])imagesOut, outWidth, outHeight, ByteBufferImage.DEPTH_DOUBLE, nImagesOut, ByteOrder.LITTLE_ENDIAN);
		if(newOutSet != imagesOut){
			imagesOut = newOutSet;
			return true;
		}
		
		return false;
	}
	
	/** Returns a configuration (named by a non-empty key) whose inclusive range
	 * [CFG_I0, CFG_I1] contains image index i (CFG_I1 &lt; 0 = open-ended), or null if none does. */
	private double[] findConfig(int i){
		for(Entry<String, double[]> entry : configMap.entrySet()){
			String name = entry.getKey();
			double cfg[] = entry.getValue();
			if(name == null || name.isEmpty() || cfg == null)
				continue;
			if(i >= cfg[CFG_I0] && (i <= cfg[CFG_I1] || cfg[CFG_I1] < 0)){
				return cfg;
			}
		}
		return null;
	}
	
	/** Returns source image idx, or null if idx is outside the connected source's range. */
	private Img getNeighbourSourceImage(int idx){
		if(idx < 0 || idx >= connectedSource.getNumImages())
			return null;
		return connectedSource.getImage(idx);
	}
	
	/** Computes one output image as the weighted average of the input images in sourceSet.
	 * 
	 * Each input is weighted by exp(-((inIdx - outIdx)/sigma)^2 / 2) or, if sigma &lt;= 0,
	 * only the input with the same index as the output is used (weight 1).
	 * NaN pixels are excluded from both the weighted sum and the total weight, so an
	 * output pixel with no valid input becomes NaN (0/0).
	 * If spike removal is enabled, each input pixel first goes through checkForSpike(),
	 * using the neighbouring source images (+/-1, or +/-2 if interlaced).
	 * 
	 * @return false (after invalidating the output image) if no configuration covers
	 *         this output index, otherwise true.
	 */
	@Override
	protected boolean doCalc(Img imageOutG, Img[] sourceSet, boolean settingsHadChanged) throws InterruptedException {
		ByteBufferImage imageOut = (ByteBufferImage)imageOutG;
		int imgIdxOut = imageOut.getSourceIndex();
		double cfg[] = findConfig(imgIdxOut);
		
		if(cfg == null){
			imageOut.invalidate();
			return false;
		}
		
		boolean interlaced = cfg[CFG_INTERLACE] != 0;
		double spikeThreshold = cfg[CFG_SPIKE_THRES];
		boolean radRemoval = spikeThreshold > 0;
		double sigmaIdx = cfg[CFG_SIGMA_I];
		int neighbourStep = interlaced ? 2 : 1;
		
		//weighted sums of the valid (non-NaN) values and of their weights
		double sumVal[][] = new double[outHeight][outWidth];
		double sumAmp[][] = new double[outHeight][outWidth];
		
		//for each input image in the input set
		for(Img imageIn : sourceSet){
			if(imageIn == null)
				continue;
			int imgIdxIn = imageIn.getSourceIndex();
			
			// Gaussian amplitude of this image
			double amp;
			if(sigmaIdx > 0){
				double arg = ((double)imgIdxIn - imgIdxOut) / sigmaIdx;
				amp = FastMath.exp(-FastMath.pow2(arg) / 2);
			}else{
				if(imgIdxIn != imgIdxOut)
					continue; //no smoothing: only the matching input contributes
				amp = 1;
			}
			
			//unfortunately, these escape the change detection
			Img imageInP = radRemoval ? getNeighbourSourceImage(imgIdxIn + neighbourStep) : null;
			Img imageInM = radRemoval ? getNeighbourSourceImage(imgIdxIn - neighbourStep) : null;
			
			imageIn.startReading();
			try{
				if(imageInP != null)
					imageInP.startReading();
				try{
					if(imageInM != null)
						imageInM.startReading();
					try{
						for(int oY=0; oY < outHeight; oY++) {
							for(int oX=0; oX < outWidth; oX++) {
								double val = radRemoval
										? checkForSpike(imageIn, imageInP, imageInM, oX, oY, spikeThreshold)
										: imageIn.getPixelValue(oX, oY);
								
								if(!Double.isNaN(val)) {
									sumVal[oY][oX] += val * amp;
									sumAmp[oY][oX] += amp;
								}
							}
						}
					}finally{
						if(imageInM != null)
							imageInM.endReading();
					}
				}finally{
					if(imageInP != null)
						imageInP.endReading();
				}
			}finally{
				imageIn.endReading();
			}
		}
		
		// divide out actual sums
		for(int oY=0; oY < outHeight; oY++) {
			for(int oX=0; oX < outWidth; oX++) {
				imageOut.setPixelValue(oX, oY, sumVal[oY][oX] / sumAmp[oY][oX]);
			}
		}
	
		return true;
	}
	
	/** Returns the value of imageIn at (x,y), or, if it looks like a single-frame spike,
	 * the mean of the same pixel in imageInP and imageInM.
	 * 
	 * A pixel is a spike if it exceeds by more than spikeThreshold both the same pixel in
	 * the two neighbouring images, and at least 18 of the 24 other pixels in the
	 * surrounding 5x5 box of imageIn.
	 * Pixels in the outer 2 rows/columns of the image, or with a missing neighbour
	 * image, are returned unchanged.
	 */
	private final double checkForSpike(Img imageIn, Img imageInP, Img imageInM, int x, int y, double spikeThreshold) {
		 
		double val = imageIn.getPixelValue(x, y);
		
		//can't do anything without neighbouring images or a full 5x5 neighbourhood
		if(imageInP == null || imageInM == null || x <= 1 || y <= 1 || x >= imageIn.getWidth() - 2 || y >= imageIn.getHeight() - 2)
			return val;
		
		//to call it a spike, we need it to be above both neighbouring times, and above at least 18 of the 24 surrounding pixels
		double valTP = imageInP.getPixelValue(x, y);
		double valTM = imageInM.getPixelValue(x, y);
		
		boolean spTP = val > valTP + spikeThreshold;
		boolean spTM = val > valTM + spikeThreshold;
		
		int spXY = 0;
		for(int yy=-2; yy <= 2; yy++){
			for(int xx=-2; xx <= 2; xx++){
				if(xx == 0 && yy == 0)
					continue;
				if(val > imageIn.getPixelValue(x+xx, y+yy) + spikeThreshold)
					spXY++;
			}
		}
		
		if(spTP && spTM && spXY >= 18) {
			return (valTP + valTM) / 2;
		}
		
		return val;
	}
	
	/** Creates a new processor on the same source and selected index.
	 * NOTE: the clone starts with only the default configuration; the configMap
	 * entries of this instance are not copied. */
	@Override
	public SeriesAverageProcessor clone() { return new SeriesAverageProcessor(connectedSource, getSelectedSourceIndex()); }

	/** For the sink side: creates an SWT controller when interfacingClass is Composite
	 * (args[0] = parent Composite, args[1] = Integer style), otherwise returns null. */
	@Override
	public ImagePipeController createPipeController(Class interfacingClass, Object args[], boolean asSink) {
		if(interfacingClass == Composite.class){
			controller = new SeriesAverageSWTController((Composite)args[0], (Integer)args[1], this, asSink);
			controllers.add(controller);
			return controller;
		}
		return null;
	}
	
	/** Returns the live (mutable) configuration map; call configModified() after changing it. */
	public HashMap<String, double[]> getConfigMap(){ return configMap; }
	
	/** Marks settings as changed, refreshes all controllers and recalculates if autoCalc is set. */
	public void configModified(){
		settingsChanged = true;
		updateAllControllers();
		if(autoCalc)
			calc();
	}

	/** Saving the configuration is not implemented yet; this method does nothing.
	 * TODO: persist configMap, e.g. as GMDS signals "TimeAvg/pointNames" and
	 * "TimeAvg/pointData", as an earlier disabled draft sketched. */
	public void saveConfig(){
	}
	
	/** Loading the configuration is not implemented yet; this method does nothing.
	 * TODO: restore configMap (from overridePulse if &gt;= 0, else the source's pulse,
	 * falling back to pulse 0), then updateAllControllers() and calc(). */
	public void loadMapPoints(int overridePulse) {
	}
	
	/** Marks settings as changed so that outputs are recalculated on the next calc. */
	public void invalidate() {
		settingsChanged = true;
	}

}
