package imseProc.proc.imgFit;

import java.util.Arrays;
import java.util.HashMap;
import org.eclipse.swt.widgets.Composite;
import signals.gmds.GMDSSignal;
import descriptors.gmds.GMDSSignalDesc;
import imseProc.core.ByteBufferImage;
import imseProc.core.IMSEProc;
import imseProc.core.ImagePipeController;
import imseProc.core.Img;
import imseProc.core.ImgPipe;
import imseProc.core.ImgSink;
import imseProc.core.ImgSource;

/** Image transform processor.
 * Fitting a cubic interpolation to each image with defined valid and non valid regions
 * 
 * @author oliford
 */
public class ImageFitProcessor extends ImgPipe implements ImgSource, ImgSink {
	
	private ByteBufferImage imagesOut[] = new ByteBufferImage[0];
	private int imagesInChangeID[] = new int[0];
		
	private ImageFitSWTController controller;
	boolean settingsChanged = false;
		
	/** Image fitting mask */
	private boolean mask[][] = new boolean[100][100];
	
	private boolean isIdle = false;
	
	private int nKnotsX = 5, nKnotsY = 5;
	
	private CubicMaskedImageFitter fitter = new CubicMaskedImageFitter();
		
	/** How far through the image series doCalc() currently is */
	private int calcProgress = -1;
	private boolean autoCalc;
	
	private boolean initFromImage = false;
	private boolean performFit = false;
	private int outputMaskMode = 0;
	
	private static final int OUTPUT_MASK_NONE = 0;	
	private static final int OUTPUT_MASK_ORIG = 1;
	private static final int OUTPUT_INVMASK_ORIG = 2;
	private static final int OUTPUT_MASK_FIT = 3;
	private static final int OUTPUT_INVMASK_FIT = 4;
	
	
	private boolean singleImage = false;
			
	public ImageFitProcessor() {
		for(int y=0; y < mask.length; y++)
			for(int x=0; x < mask[0].length; x++)
				mask[y][x] = true;
	}
	
	public void calc() {
		isIdle = false;
		IMSEProc.ensureFinalUpdate(this, new Runnable() { @Override public void run() { doCalcScan(); } });
	}
	
	/** Guaranteed to be called only once at a time */
	private void doCalcScan(){
		if(connectedSource == null){
			imagesOut = new ByteBufferImage[0];
			imagesInChangeID = new int[0];
			seriesMetaData = new HashMap<String, Object>();
			return;
		}
		
		//seriesMetaData = connectedSource.getSeriesMetaDataMap();
		seriesMetaData = new HashMap<String, Object>();
		
		boolean settingsHadChanged = settingsChanged;
		settingsChanged = false;
		
		
		fitter.setSize(nKnotsX, nKnotsY);			
		
		
		int nImages = singleImage ? 1 : connectedSource.getNumImages();
		
		if(imagesOut.length != nImages){
			for(int i=nImages; i < imagesOut.length; i++){
				if(imagesOut[i] != null)
					imagesOut[i].destroy();					
			}
			imagesOut = Arrays.copyOfRange(imagesOut, 0, nImages);
			imagesInChangeID = Arrays.copyOfRange(imagesInChangeID, 0, nImages);
		}
		
		for(int i=0; i < nImages; i++){
			
			calcProgress = i;
			updateAllControllers();
			
			Img imageIn = connectedSource.getImage(singleImage ? getSelectedSourceIndex() : i);
			doCalc(i, imageIn, settingsHadChanged);			
			
			if(settingsChanged)
				break;
		}
		
		calcProgress = -1;
		updateAllControllers();
		
		isIdle = true;
	}
	
	public void doCalc(int imgIdx, Img imageIn, boolean settingsHadChanged) {
				
 		if(imageIn == null || !imageIn.isRangeValid()){
			if(imagesOut[imgIdx] != null){
				//imagesOut[imgIdx] = null;
				imagesOut[imgIdx].invalidate();
				notifyImageSetChanged();
			}			
			return;
		}
		
		int newChangeID = imageIn.getChangeID();
		if(newChangeID == imagesInChangeID[imgIdx] && !settingsHadChanged){
			return; //not changed
		}
		
		try{
			int width = imageIn.getWidth();
			int height = imageIn.getHeight();
								
			//overwrite the data and do an in-image update if possible 
			boolean setChanged = false;
			if(imagesOut[imgIdx] == null || imagesOut[imgIdx].getWidth() != width || imagesOut[imgIdx].getHeight() != height){
				if(imagesOut[imgIdx] != null)
					imagesOut[imgIdx].destroy();				
					
				imagesOut[imgIdx] = new ByteBufferImage(this, imgIdx, width, height, ByteBufferImage.DEPTH_DOUBLE, false);
				
				setChanged = true;
			}
			
			fitter.setMask(mask);

			imageIn.startReading();
			try{
				fitter.fit(imageIn, initFromImage, performFit);
				
				imagesOut[imgIdx].startWriting();			
				try{
					for(int iY=0; iY < height; iY++){
						for(int iX=0; iX < width; iX++){
							int mX = (int)((double)iX * mask[0].length / width + 0.5);
							int mY = (int)((double)iY * mask.length / height + 0.5);
							if(mX < 0) mX=0;
							if(mX >= mask[0].length) mX=mask[0].length-1;						
							if(mY < 0) mY=0;
							if(mY >= mask.length) mY=mask.length-1;		
							
							double val;
							
							if(outputMaskMode == OUTPUT_MASK_ORIG){
								val = mask[mY][mX] ? imageIn.getPixelValue(iX, iY) : 0;
								
							}else if(outputMaskMode == OUTPUT_INVMASK_ORIG){
								val = !mask[mY][mX] ? imageIn.getPixelValue(iX, iY) : 0;
								
							}else{
								val = fitter.evalFittedImage(iX, iY);
							
								if((outputMaskMode == OUTPUT_MASK_FIT && !mask[mY][mX]) ||
									(outputMaskMode == OUTPUT_INVMASK_FIT && mask[mY][mX]) ){
										val = Double.NaN;
								}
							}
							
							imagesOut[imgIdx].setPixelValue(iX, iY, val);
						}
					}
				}finally{
					imagesOut[imgIdx].endWriting();
				}
			}finally{
				imageIn.endReading();
			}
				
			imagesOut[imgIdx].imageChanged(false);
			imagesInChangeID[imgIdx] = newChangeID;
				
			if(setChanged){
				notifyImageSetChanged();
			}
									
			//System.out.println("ImageFFTProcessor created/updated image: " + imageOut.toString());
						
		}catch (InterruptedException e) {
			System.err.println("ImageFFTProcessor.doCalc(): Interrupted waiting for image lock");
			imagesOut[imgIdx] = null;
			notifyImageSetChanged();
			return;
		}
		
	}
	
	@Override
	public void notifySourceChanged() {
		super.notifySourceChanged();
		if(autoCalc)
			calc();
	}
	
	@Override
	public void imageChanged(int idx) { 
 		if(autoCalc)
			calc();
 	}
	
	@Override
	public int getNumImages() {
		return (imagesOut != null) ? imagesOut.length : 0;
	}

	@Override
	public Img getImage(int imgIdx) {
		return (imgIdx >= 0 && imgIdx < imagesOut.length) ? imagesOut[imgIdx] : null;
	}
	
	@Override
	public Img[] getImageSet() { return imagesOut; }

	@Override
	public ImgSource clone() { return new ImageFitProcessor(); }

	@Override
	/** For the sink side */
	public ImagePipeController createPipeController(Class interfacingClass, Object args[], boolean asSink) {
		if(interfacingClass == Composite.class){
			controller = new ImageFitSWTController((Composite)args[0], (Integer)args[1], this, asSink);
			controllers.add(controller);
			return controller;
		}
		return null;
	}
			
	public void setAutoCalc(boolean autoCalc) {
		if(!this.autoCalc && autoCalc){
			this.autoCalc = true;
			updateAllControllers();
			calc();			
		}else{
			this.autoCalc = autoCalc;
			updateAllControllers();
		}
	}
	public boolean getAutoCalc() { return this.autoCalc; }

	public int getCalcProgress() { return calcProgress; }
	
	@Override
	public void setSource(ImgSource source) {
		super.setSource(source);
		calc();
	}
	
	@Override
	public void destroy() {
		for(int i=0; i < imagesOut.length; i++){
			if(imagesOut[i] != null)
				imagesOut[i].destroy();
		}
		
		//relase out memory just in case someone holds on to us
		//but if we actually null these, doCalc()s still in the calc queue can get caught out
		imagesOut = new ByteBufferImage[0];
		imagesInChangeID = new int[0];
		super.destroy();
	}
	
	public boolean[][] getMask(){ return mask; }
	
	public void setMaskRect(double x0, double y0, double x1, double y1, boolean maskOn){
		int nX = mask[0].length;
		int nY = mask.length;
		int iX0 = (int)(x0 * nX + 0.5); 
		int iY0 = (int)(y0 * nY + 0.5);
		int iX1 = (int)(x1 * nX + 0.5); 
		int iY1 = (int)(y1 * nY + 0.5);

		if(iX0 < 0) iX0 = 0;
		if(iX1 >= nX) iX1 = nX - 1;
		if(iY0 < 0) iY0 = 0;
		if(iY1 >= nY) iY1 = nY - 1;
		
		
		for(int iY = iY0; iY <= iY1; iY++){
			for(int iX = iX0; iX <= iX1; iX++){
				mask[iY][iX] = maskOn;
			}
		}
		settingsChanged = true;
		calc();
	}
	
	public void setMaskCircle(double x0, double y0, double radius, boolean maskOn){
		int nX = mask[0].length;
		int nY = mask.length;
		
		
		int iX0 = (int)((x0-radius) * nX - 0.5); 
		int iY0 = (int)((y0-radius) * nY - 0.5);
		int iX1 = (int)((x0+radius) * nX + 0.5); 
		int iY1 = (int)((y0+radius) * nY + 0.5);
		
		if(iX0 < 0) iX0 = 0;
		if(iX1 >= nX) iX1 = nX - 1;
		if(iY0 < 0) iY0 = 0;
		if(iY1 >= nY) iY1 = nY - 1;
		
		double aSq = radius*radius;
		for(int iY = iY0; iY <= iY1; iY++){
			for(int iX = iX0; iX <= iX1; iX++){
				double x = (double)iX / nX;   
				double y = (double)iY / nY;
				
				double rSq = (x-x0)*(x-x0) + (y-y0)*(y-y0);
				if(rSq < aSq){
					mask[iY][iX] = maskOn;
				}
			}	
		}
		settingsChanged = true;
		calc();
	}
	

	public void saveMask(){
		if(connectedSource == null) return;
		
		byte maskData[][] = new byte[mask.length][mask[0].length];
		
		for(int y=0; y < mask.length; y++){			
			for(int x=0; x < mask[0].length; x++){
				maskData[y][x] = (byte)(mask[y][x] ? -1 : 0);
			}
		}
		
		String experiment = IMSEProc.getMetaExp(connectedSource);
		int pulse = IMSEProc.getMetaPulse(connectedSource);
		
		try{	
			GMDSSignalDesc sigDesc = new GMDSSignalDesc(pulse, experiment, "ImageFit/mask");
			GMDSSignal sig = new GMDSSignal(sigDesc, maskData);
			IMSEProc.globalGMDS().writeToCache(sig);
			
		}catch(RuntimeException e){
			System.err.println("Transform: Error saving mask to exp " + experiment + 
								", pulse " + pulse + " because " + e.getMessage());
		}
	}
	
	public void loadMask(int overridePulse) {
		if(connectedSource == null) return;
		
		String experiment = IMSEProc.getMetaExp(connectedSource);
		int pulse ;
		try{
			pulse = overridePulse >= 0 ? overridePulse : IMSEProc.getMetaPulse(connectedSource);
		}catch(RuntimeException e){
			pulse = 0;
		}
		
		if(experiment == null || experiment.length() <= 0){
			experiment = "AUG";
		}
		
		//try reading from this specific pulse first		
		byte maskData[][] = null;
		try{	
			
			GMDSSignalDesc sigDesc = new GMDSSignalDesc(pulse, experiment, "ImageFit/mask");		
			GMDSSignal sig = (GMDSSignal)IMSEProc.globalGMDS().getSig(sigDesc);
			maskData = (byte[][])sig.get2DData();
						
		}catch(RuntimeException e){
			System.err.println("Transform: Couldn't load mask for exp " + experiment +
									", pulse " + pulse + "because: "+ e.getMessage() + ". Trying pulse 0");
			
			try{
				GMDSSignalDesc sigDesc = new GMDSSignalDesc(0, experiment, "ImageFit/mask");		
				GMDSSignal sig = (GMDSSignal)IMSEProc.globalGMDS().getSig(sigDesc);
				maskData = (byte[][])sig.get2DData();				
				
			}catch(RuntimeException e2) {
				System.err.println("Transform: Couldn't load points from pulse 0.");
				return;
			}
		}

		mask = new boolean[maskData.length][maskData[0].length];				
		for(int y=0; y < mask.length; y++){			
			for(int x=0; x < mask[0].length; x++){
				mask[y][x] = maskData[y][x] != 0;
			}
		}
		
		updateAllControllers();
		calc();
	}
	
	public void invalidate() {
		settingsChanged = true;
	}

	public void setNKnots(int nKnotsX, int nKnotsY) {
		this.nKnotsX = nKnotsX;
		this.nKnotsY = nKnotsY;
 		settingsChanged = true;
		if(autoCalc){
			calc();
		}
	}

	public void setModes(boolean initFromImage, boolean performFit, boolean singleImage, int outputMaskMode) {
		this.initFromImage = initFromImage;
		this.performFit = performFit;
		this.singleImage  = singleImage;
		this.outputMaskMode = outputMaskMode;
		settingsChanged = true;
		if(autoCalc)
			calc();		
	}
	
	@Override
	public boolean isIdle() { return isIdle; 	}
}
