package imseProc.core;


import imseProc.core.Img.RangeValidity;


import java.lang.ref.SoftReference;
import java.lang.ref.WeakReference;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.IntBuffer;
import java.nio.ShortBuffer;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import otherSupport.bufferControl.DirectBufferControl;

import algorithmrepository.exceptions.NotImplementedException;

/**
 * Image described by byte data in a (preferably direct) byte buffer.
 *
 * <p>Pixels are stored row-major ({@code index = y * width + x}) with 1, 2, 4 or 8 bytes
 * per pixel depending on {@link #bitDepth}. Integer pixels are read back as signed values.
 *
 * <p>Images created by {@link #allocateBulk} are slices of a shared allocation; that
 * allocation is only freed once every image using it has been destroyed.
 */
public class ByteBufferImage extends Img {
	/** Bit depth code for 64-bit double pixels */
	public static final int DEPTH_DOUBLE = -64;
	/** Bit depth code for 32-bit float pixels.
	 * NOTE(review): not currently accepted by {@link #getBytesPerPixel(int)} */
	public static final int DEPTH_FLOAT = -32;
	
	/** The ByteBuffer for this actual image */
	protected ByteBuffer imageBuffer;       

	/** The ByteBuffer actually allocated, of which imageBuffer is a part.
	 * If null, then imageBuffer was allocated itself. */ 
	private ByteBuffer allocBuffer;
	/** List of all images still using this alloc buffer. Also used as the lock guarding
	 * additions/removals, since it is shared between those images. */ 
	private LinkedList<ByteBufferImage> allImagesInAlloc; 

	/** Bit depth of image, -ve for floating point (so -64 for double, -32 for float) */
	protected int bitDepth;

	/**
	 * Checks whether the given set of images all match the given properties.
	 * If they do, the same array is returned for re-use. If not, all the images in the set
	 * are destroyed and a new, uninitialised matching set is allocated.
	 *
	 * @param source    Source the new images belong to
	 * @param imgs      Existing images to check; may be null
	 * @param width     Required image width in pixels
	 * @param height    Required image height in pixels
	 * @param bitDepth  Required bit depth (see {@link #getBytesPerPixel(int)})
	 * @param nImgs     Required number of images
	 * @param byteOrder Required byte order of each image buffer
	 * @return {@code imgs} if every entry matched, otherwise a newly allocated array
	 */
	public static ByteBufferImage[] checkBulkAllocation(ImgSource source, ByteBufferImage imgs[], int width, int height, int bitDepth, int nImgs, ByteOrder byteOrder){
		if(imgs != null){
			if(imgs.length == nImgs){
				boolean allMatched = true;

				for(int i=0; i < nImgs; i++){
					//destroyed is checked before imageBuffer is touched, since destroy() nulls it
					if(imgs[i] == null || imgs[i].destroyed || imgs[i].getWidth() != width || imgs[i].getHeight() != height 
							|| imgs[i].bitDepth != bitDepth || imgs[i].imageBuffer.order() != byteOrder){
						allMatched = false;
						break;
					}
				}
				if(allMatched){
					//ok, we can re-use		
					return imgs;
				}
			}

			//match failed, need to reallocate
			for(int i=0; i < imgs.length; i++){
				if(imgs[i] != null)
					imgs[i].destroy();
			}
		}

		System.gc(); //good time for this, since the last destroy() should have freed lots of stuff
		return allocateBulk(source, width, height, bitDepth, nImgs, false, byteOrder);          
	}

	/**
	 * Allocate memory for multiple ByteBufferImages in as large single contiguous chunks as possible.
	 * Each image is a slice of a shared direct buffer; if the total size does not fit in one
	 * ByteBuffer, the set is split in half (recursively) across several allocations.
	 *
	 * @param source    Source the new images belong to
	 * @param width     Image width in pixels
	 * @param height    Image height in pixels
	 * @param bitDepth  Bit depth (see {@link #getBytesPerPixel(int)})
	 * @param nImgs     Number of images to allocate
	 * @param init      If true, each image is passed through replaceData() (which calls
	 *                  imageChanged(true)); otherwise min/max/sum are left as NaN
	 * @param byteOrder Byte order applied to every image buffer
	 * @return The new images, with source indices 0 .. nImgs-1
	 * @throws IllegalArgumentException if bitDepth is invalid, or a single image would need
	 *                                  more than Integer.MAX_VALUE bytes
	 */
	public static ByteBufferImage[] allocateBulk(ImgSource source, int width, int height, int bitDepth, int nImgs, boolean init, ByteOrder byteOrder){
		return allocateBulk(source, 0, width, height, bitDepth, nImgs, init, byteOrder);
	}

	/** Recursive worker for the public allocateBulk(); startIdx is the source index of the first image. */
	private static ByteBufferImage[] allocateBulk(ImgSource source, int startIdx, int width, int height, int bitDepth, int nImgs, boolean init, ByteOrder byteOrder){
		int bytesPerPixel = getBytesPerPixel(bitDepth);
		//long math: width*height*bytesPerPixel can overflow an int for large images
		long bytesPerImageLong = (long)width * height * bytesPerPixel;
		if(bytesPerImageLong > Integer.MAX_VALUE){
			//a single image can't be split any further, so bisecting would recurse forever
			throw new IllegalArgumentException("Image of " + width + " x " + height + " x " + bitDepth
					+ " needs " + bytesPerImageLong + " bytes, more than a single ByteBuffer can hold");
		}
		int bytesPerImage = (int)bytesPerImageLong;
		long totalBytes = (long)bytesPerImage * nImgs;

		if(totalBytes > Integer.MAX_VALUE){
			//bisect into two allocations (nImgs >= 2 here, since one image alone fits)
			int n1 = nImgs / 2;
			int n2 = nImgs - n1;
			ByteBufferImage set1[] = null, set2[] = null;

			try{
				set1 = allocateBulk(source, startIdx, width, height, bitDepth, n1, init, byteOrder);
				set2 = allocateBulk(source, startIdx+n1, width, height, bitDepth, n2, init, byteOrder);
			}finally{
				//If set2 OOMs and set1 didn't we need to free set1, but don't need to intercept the exception here
				if(set2 == null && set1 != null){
					for(int i=0; i < set1.length; i++){
						if(set1[i] != null && !set1[i].isDestroyed()){
							set1[i].destroy();
						}
					}
				}
			}

			ByteBufferImage imgs[] = new ByteBufferImage[nImgs];
			System.arraycopy(set1, 0, imgs, 0, n1);
			System.arraycopy(set2, 0, imgs, n1, n2);
			return imgs;
		}
		
		ByteBuffer allBuff = DirectBufferControl.allocateDirect((int)totalBytes);
		System.out.println("ByteBufferImage.allocateBulk() allocated " + (totalBytes/1000000) + " MB");
		allBuff.order(byteOrder);

		ByteBufferImage imgs[] = new ByteBufferImage[nImgs];
		LinkedList<ByteBufferImage> imageList = new LinkedList<ByteBufferImage>();
		for(int i=0; i < nImgs; i++){                   
			//cannot overflow: (i+1)*bytesPerImage <= totalBytes <= Integer.MAX_VALUE
			allBuff.limit((i+1)*bytesPerImage);
			allBuff.position(i*bytesPerImage);
			ByteBuffer imgBuff = allBuff.slice();
			imgBuff.order(byteOrder); //slice() does not inherit the byte order
			imgs[i] = new ByteBufferImage(source, startIdx+i, width, height, bitDepth, imgBuff, allBuff, imageList, init);
			imageList.add(imgs[i]);                 
		}
		//leave the shared buffer covering its whole capacity rather than the last slice
		allBuff.clear();

		return imgs;
	}

	/** Creates a zero-filled, heap-backed image and initialises it (see replaceData). */
	public ByteBufferImage(ImgSource source, int sourceIndex, int width, int height, int bitDepth) {
		this(source, sourceIndex, width, height, bitDepth, true);
	}

	/**
	 * Creates a zero-filled, heap-backed image.
	 * @param init If true, replaceData() is called on the new buffer (which calls imageChanged(true));
	 *             otherwise min/max/sum are set to NaN.
	 */
	public ByteBufferImage(ImgSource source, int sourceIndex, int width, int height, int bitDepth, boolean init) {
		super(source, sourceIndex);

		this.bitDepth = bitDepth;
		this.width = width;
		this.height = height;
		//imageBuffer = DirectBufferControl.allocateDirect(imageSizeBytes());
		imageBuffer = ByteBuffer.allocate(imageSizeBytes());
		if(init){
			replaceData(imageBuffer);
		}else{
			min = Double.NaN; 
			max = Double.NaN;
			sum = Double.NaN;
		}
	}
	
	/** Wraps an existing buffer, which may be part of a shared allocation, without initialising it. */
	protected ByteBufferImage(ImgSource source, int sourceIndex, int width, int height, int bitDepth, 
			ByteBuffer imageBuffer, ByteBuffer allocBuffer, LinkedList<ByteBufferImage> allImagesInAlloc) {
		this(source, sourceIndex, width, height, bitDepth, imageBuffer, allocBuffer, allImagesInAlloc, false);
	}

	/**
	 * Wraps an existing buffer, which may be part of a shared allocation.
	 * The image is marked destroyed if imageBuffer is null.
	 *
	 * @param allocBuffer      The buffer actually allocated, of which imageBuffer is a part (or null)
	 * @param allImagesInAlloc Images sharing allocBuffer; the caller is responsible for adding this image to it
	 * @param init             If false, ranges are not calculated and no modification notification is made
	 *                         (min/max/sum are set to NaN)
	 */
	protected ByteBufferImage(ImgSource source, int sourceIndex, int width, int height, int bitDepth, 
			ByteBuffer imageBuffer, ByteBuffer allocBuffer, LinkedList<ByteBufferImage> allImagesInAlloc, boolean init) {
		super(source, sourceIndex);

		this.width = width;
		this.height = height;
		this.bitDepth = bitDepth;
		this.imageBuffer = imageBuffer;
		this.allocBuffer = allocBuffer;
		this.destroyed = (imageBuffer == null);
		this.allImagesInAlloc = allImagesInAlloc;

		if(init){
			replaceData(imageBuffer);
		}else{
			min = Double.NaN; 
			max = Double.NaN;
			sum = Double.NaN;
		}
	}

	/**
	 * Copy constructor used by some derived classes e.g. ComplexImage.
	 * The new image shares img's pixel buffer; both are registered in a common
	 * allImagesInAlloc list so the buffer is only freed when both are destroyed.
	 */
	protected ByteBufferImage(ByteBufferImage img) {
		super(img.source, img.sourceIndex);
		this.width = img.width;
		this.height = img.height;
		this.bitDepth = img.bitDepth;
		this.imageBuffer = img.imageBuffer;
		this.destroyed = img.destroyed;
		if(img.allocBuffer == null){ //them and us are now sharing buffers
			img.allocBuffer = imageBuffer;
			img.allImagesInAlloc = new LinkedList<ByteBufferImage>();
			img.allImagesInAlloc.add(img);
		}               
		this.allocBuffer = img.allocBuffer;
		this.allImagesInAlloc = img.allImagesInAlloc;
		//same lock destroy() uses, so a concurrent destroy of a sibling can't race this add
		synchronized (allImagesInAlloc) {
			allImagesInAlloc.add(this);
		}
	}

	/** Creates a 16-bit little-endian heap image from rows of shorts (shortArray[y][x]). */
	public ByteBufferImage(ImgSource source, int sourceIndex, short shortArray[][]) {
		super(source, sourceIndex);

		this.height = shortArray.length;
		this.width = shortArray[0].length;
		this.bitDepth = 16;
				
		//ByteBuffer buf = DirectBufferControl.allocateDirect(imageSizeBytes());
		ByteBuffer buf = ByteBuffer.allocate(imageSizeBytes());
		buf.order(ByteOrder.LITTLE_ENDIAN);
		ShortBuffer sBuff = buf.asShortBuffer();
		for(int iY=0; iY < height; iY++) {			
			sBuff.put(shortArray[iY]);
		}
		replaceData(buf);
	}

	/** Creates a 32-bit little-endian heap image from rows of ints (intArray[y][x]). */
	public ByteBufferImage(ImgSource source, int sourceIndex, int intArray[][]) {
		super(source, sourceIndex);

		this.height = intArray.length;
		this.width = intArray[0].length;
		this.bitDepth = 32;
				
		//ByteBuffer buf = DirectBufferControl.allocateDirect(imageSizeBytes());
		ByteBuffer buf = ByteBuffer.allocate(imageSizeBytes());
		buf.order(ByteOrder.LITTLE_ENDIAN);
		IntBuffer iBuff = buf.asIntBuffer();
		for(int iY=0; iY < height; iY++) {			
			iBuff.put(intArray[iY]);
		}
		replaceData(buf);
	}

	/** Creates a double-precision little-endian heap image from rows of doubles (dblArray[y][x]). */
	public ByteBufferImage(ImgSource source, int sourceIndex, double dblArray[][]) {
		super(source, sourceIndex);
		this.height = dblArray.length;
		this.width = dblArray[0].length;
		this.bitDepth = DEPTH_DOUBLE;

		//ByteBuffer buf = DirectBufferControl.allocateDirect(imageSizeBytes());
		ByteBuffer buf = ByteBuffer.allocate(imageSizeBytes());
		buf.order(ByteOrder.LITTLE_ENDIAN);
		DoubleBuffer dBuff = buf.asDoubleBuffer();
		for(int iY=0; iY < height; iY++) {                      
			dBuff.put(dblArray[iY]);
		}
		replaceData(buf);
	}

	/** Creates an image that wraps (does not copy) the given raw byte array. */
	public ByteBufferImage(ImgSource source, int sourceIndex, int width, int height, int bitDepth, byte data[]) {
		super(source, sourceIndex);

		this.width = width;
		this.height = height;
		this.bitDepth = bitDepth;
		replaceData(data);
	}

	/**
	 * Number of bytes of pixel data for this image: width * height * getBytesPerPixel().
	 * Computed in long so oversized images fail loudly instead of wrapping to a small/negative int.
	 * @throws IllegalArgumentException if the size exceeds Integer.MAX_VALUE
	 */
	private int imageSizeBytes() {
		long size = (long)width * height * getBytesPerPixel();
		if(size > Integer.MAX_VALUE)
			throw new IllegalArgumentException("Image of " + width + " x " + height + " x " + bitDepth
					+ " needs " + size + " bytes, more than a single ByteBuffer can hold");
		return (int)size;
	}
	
	/**
	 * Replaces the pixel data by wrapping (not copying) the given array, then calls imageChanged(true).
	 * Note the wrapped buffer uses the default (big-endian) byte order.
	 * @throws IllegalArgumentException if data.length doesn't match width * height * bytes-per-pixel
	 */
	public void replaceData(byte data[]){
		if(data.length != imageSizeBytes())
				throw new IllegalArgumentException("Image data length " +data.length+ 
						" invalid for "+width+" x "+height+" x "+bitDepth+" image");
		imageBuffer = ByteBuffer.wrap(data);
		imageBuffer.position(0);
		imageBuffer.limit(imageBuffer.capacity());
		
		imageChanged(true);
	}
	
	/**
	 * Replaces the pixel data with the given buffer (kept by reference), resets its
	 * position/limit to cover the whole buffer, then calls imageChanged(true).
	 * @throws IllegalArgumentException if data.capacity() doesn't match width * height * bytes-per-pixel
	 */
	public void replaceData(ByteBuffer data){
		if(data.capacity() != imageSizeBytes())
				throw new IllegalArgumentException("Image data length " +data.capacity()+ 
						" invalid for "+width+" x "+height+" x "+bitDepth+" image");
		this.imageBuffer = data;
		imageBuffer.position(0);
		imageBuffer.limit(imageBuffer.capacity());
		imageChanged(true);
	}
	
	/** @return Bytes per pixel for this image's bit depth */
	public int getBytesPerPixel() {
		return getBytesPerPixel(bitDepth);
	}
	
	/**
	 * @param bitDepth 1..8, 9..16, 17..32 or DEPTH_DOUBLE
	 * @return 1, 2, 4 or 8 bytes per pixel respectively
	 * @throws IllegalArgumentException for any other bit depth (including DEPTH_FLOAT)
	 */
	public static int getBytesPerPixel(int bitDepth) {
		if(bitDepth > 0 && bitDepth <= 8){
			return 1;
		}else if(bitDepth > 8 && bitDepth <= 16){
			return 2;
		}else if(bitDepth > 16 && bitDepth <= 32){
			return 4;
		}else if(bitDepth == DEPTH_DOUBLE){
			return 8;
		}else
			throw new IllegalArgumentException("Invalid bit depth " + bitDepth);
	}

	/**
	 * @return A read-only view of the pixel data with the same byte order as the image buffer,
	 *         or null if the image has been destroyed
	 */
	public ByteBuffer getReadOnlyBuffer() {
		if(destroyed)
			return null;
		//asReadOnlyBuffer() always yields a BIG_ENDIAN view, so restore the image's actual order
		return imageBuffer.asReadOnlyBuffer().order(imageBuffer.order());
	}
		
	/**
	 * @return The underlying pixel buffer itself, or null if the image has been destroyed
	 * @throws RuntimeException if the caller does not hold the write lock
	 */
	public ByteBuffer getWritableBuffer() {
		if(!writing)
			throw new RuntimeException("getWritableBuffer() requires write lock");

		return destroyed ? null : imageBuffer;
	}

	public int getBitDepth() {	return bitDepth; }

	/**
	 * Returns 2^bitDepth - 1 truncated to an int. For 32-bit images the int cast saturates to
	 * Integer.MAX_VALUE, which is also the largest (signed) value getPixelValue() can return.
	 * NOTE(review): for DEPTH_DOUBLE this evaluates to -1, which is not a meaningful maximum;
	 * confirm what callers expect before changing it.
	 */
	@Override
	public double getMaxPossibleValue() {
		return (int)(Math.pow(2, bitDepth) - 1);
	}

	/**
	 * Reads pixel (x, y) using the buffer's byte order. Integer pixels are interpreted as signed.
	 * @return The pixel value, or NaN if the image has been destroyed
	 */
	@Override
	public double getPixelValue(int x, int y) {
		if(destroyed)
			return Double.NaN;

		if(bitDepth > 0 && bitDepth <= 8){
			return (double)imageBuffer.get(y * width + x);

		}else if(bitDepth > 8 && bitDepth <= 16){
			short signedVal = imageBuffer.getShort((y * width + x) * 2);
			//return (double)(signedVal < 0 ? (0x10000 + signedVal) : signedVal);
			return (double)signedVal;

		}else if(bitDepth > 16 && bitDepth <= 32){
			return (double)imageBuffer.getInt((y * width + x) * 4);

		}else if(bitDepth == DEPTH_DOUBLE){
			return imageBuffer.getDouble((y * width + x)*8);

		}else{
			throw new IllegalArgumentException("Invalid bit depth");
		}		
	}
	   
	/**
	 * Writes pixel (x, y), narrowing value to the pixel type with a plain Java cast,
	 * and invalidates the cached ranges. Does nothing if the image has been destroyed.
	 * @throws RuntimeException if the caller does not hold the write lock
	 */
	public void setPixelValue(int x, int y, double value) {	
		if(destroyed) return;
		if(!writing)
			throw new RuntimeException("Img.setPixelValue() called without write lock");

		if(bitDepth > 0 && bitDepth <= 8){
			imageBuffer.put(y * width + x, (byte)value);
			
		}else if(bitDepth > 8 && bitDepth <= 16){
			imageBuffer.putShort((y * width + x)*2, (short)value);

		}else if(bitDepth > 16 && bitDepth <= 32){
			imageBuffer.putInt((y * width + x)*4, (int)value);

		}else if(bitDepth == DEPTH_DOUBLE){
			imageBuffer.putDouble((y * width + x)*8, value);

		}else{
			throw new IllegalArgumentException("Invalid bit depth");
		}
		rangeValid = RangeValidity.invalid;
	}
	
	/**
	 * More or less the same as the default implementation Img.calcRanges()
	 * but does integer math in the hope that it's faster.
	 * Double images are delegated entirely to the superclass implementation.
	 */
	@Override
	public void calcRanges(){
		if(bitDepth == DEPTH_DOUBLE){
			//it's double, so do the normal one. The integer loop below can't read
			//double pixels (it would throw "Invalid bit depth"), so stop here.
			super.calcRanges();
			return;
		}

		if(rangeValid == RangeValidity.inCalc)
			System.out.println("calcRanges() called during range calc");
		else if(rangeValid == RangeValidity.valid){
			return;
		}

		if(destroyed){
			min = Double.NaN; 
			max = Double.NaN;
			sum = Double.NaN;
			return;
		}
		
		rangeValid = RangeValidity.inCalc;
		int lo = Integer.MAX_VALUE;
		int hi = Integer.MIN_VALUE;
		long total = 0;

		//do the math, working through the appropriate buffer
		// (we do it all in ints here)
		int nPixels = height * width;
		for(int i=0; i < nPixels; i++) {
			int val;
			if(bitDepth > 0 && bitDepth <= 8){
				val = imageBuffer.get(i);
			}else if(bitDepth > 8 && bitDepth <= 16){
				val = imageBuffer.getShort(i*2);
			}else if(bitDepth > 16 && bitDepth <= 32){
				val = imageBuffer.getInt(i*4);
			}else{
				throw new IllegalArgumentException("Invalid bit depth");
			}

			if(val > hi) hi = val;
			if(val < lo) lo = val;
			total += val;
		}	

		this.min = lo;
		this.max = hi;
		this.sum = total;
		if(rangeValid == RangeValidity.inCalc) //if it hasn't been overwritten mid-calc
			rangeValid = RangeValidity.valid;
	}
	
	@Override
	public String toString() {
		return super.toString() + "x" + bitDepth + " [" + min + "-" + max + "]";
	}

	/** Expose as public, since the public may fiddle with the data and should tell us */
	public void imageChanged(boolean fast){
		super.imageChanged(fast);
	}
	
	/** Releases the write lock, then resets the buffer position/limit to cover the whole image. */
	@Override
	public void endWriting() {
		super.endWriting();     
		if(!destroyed){
			//and make sure the buffer is setup correctly
			imageBuffer.position(0);
			imageBuffer.limit(imageBuffer.capacity());
		}
	}

	/**
	 * Releases this image's pixel data. Takes the write lock first so that no writer is still
	 * using the buffer when it is freed. If the buffer is shared (bulk allocation or copy
	 * constructor), the shared direct buffer is only freed once the last image using it
	 * has been destroyed.
	 */
	@Override
	public void destroy() {
		if(isDestroyed())
			return;
		//we have to wait for others to stop writing, because it segfaults the JVM
		try{
			startWriting(); 
		}catch(InterruptedException e){ 
			//restore the interrupt flag so callers further up still see the interruption
			Thread.currentThread().interrupt();
			System.err.println("WARNING: Interrupted during wait for write lock on image destroy, won't destroy it - MEMORY LEAK!");
			destroyed = true; //mark it anyway
			return; //but we can't risk freeing the buffer
		}		
		
		destroyed = true; //mark destroyed regardless of what happens		
		endWriting(); //we can already unlock, so others can continue, but they'll soon discover it's got no buffer

		//and now we can try to cleanup buffers
		if(allocBuffer != null){
			//see if no one is using the alloced buffer anymore, and if not, free it
			synchronized (allImagesInAlloc) {
				allImagesInAlloc.remove(this);                           
				if(allImagesInAlloc.isEmpty()){
					if(allocBuffer.isDirect())
						DirectBufferControl.freeBuffer(allocBuffer);
				}
			}
			allocBuffer = null;

		}else if(imageBuffer != null && imageBuffer.isDirect()){
			DirectBufferControl.freeBuffer(imageBuffer);
		}

		imageBuffer = null;
		super.destroy(); //mark it destroyed   
	}

	/** Last-resort cleanup if the image was never explicitly destroyed. */
	@Override
	protected void finalize() throws Throwable {		
		destroy();
		super.finalize();
	}

	/** @return true if img is a ByteBufferImage with the same width, height and bit depth */
	@Override
	public boolean isMemoryCompatible(Img img) {
		return img != null && (img instanceof ByteBufferImage) && img.getWidth() == width && img.getHeight() == height && bitDepth == ((ByteBufferImage)img).getBitDepth();
	}
}
