package otherSupport;

import java.util.Random;

/** Generate samples from a truncated univariate Gaussian
 * 
 * Algorithm from "Efficient Simulation from the Multivariate Normal and Student-t Distributions Subject to Linear Constraints and the Evaluation of Constraint Probabilities"
 * J. Geweke (1991) - <www.biz.uiowa.edu/faculty/jgeweke/papers/paper47/paper47.pdf> - note: the paper contains typos!
 * 
 * Also, there's (at time of writing) a better description here: <http://athens.src.uchicago.edu/jenni/econ319_2003/042202_TA_sess.pdf>
 * "Pseudo Random Variable Generation (TA Session, Econ 319,4/22/02)" - appears to be some lecture notes.
 * 
 * @author oliford <codes@oliford.co.uk>
 * 
 */
public class TruncatedUnivarGauss extends Random {
	/** sqrt(2*pi): normalising constant of the standard normal density. */
	private static final double ROOT_TWO_PI = Math.sqrt(2 * Math.PI);

	// Geweke (1991) switching constants, used by sampleStandardised() to pick the cheapest scheme.
	/** 0 in [a,b]: normal rejection if the density at either limit is <= T1, otherwise uniform rejection. */
	private static final double T1 = 0.150;
	/** Same-sign finite limits: uniform rejection if phi(inner limit) / phi(outer limit) <= T2. */
	private static final double T2 = 2.18;
	/** Same-sign finite limits, otherwise: half-normal rejection if |inner limit| < T3, else exponential rejection. */
	private static final double T3 = 0.725;
	/** One-sided [a, +inf) (or mirrored (-inf, b]): normal rejection if a <= T4, else exponential rejection. */
	private static final double T4 = 0.45;

	/** Signals that a rejection sampler exceeded {@link #maxAttempts} draws. */
	private static final class TooManyAttemptsException extends Exception {
		private static final long serialVersionUID = 1L;

		TooManyAttemptsException() {
			super("Too many rejections in rejection sampling.");
		}
	}

	/** Maximum number of draws one rejection sampler makes before it gives up and a fallback scheme is tried. */
	public static final int maxAttempts = 10000;

	/**
	 * Diagnostic: the scheme used by the most recent sampler call.
	 * 0 = untruncated {@link #nextGaussian()}, 1 = normal rejection, 2 = half-normal rejection,
	 * 3 = uniform rejection, 4 = exponential rejection.
	 */
	public int method;

	/**
	 * Draws from N(mean, sigma^2) truncated to [min, max].
	 *
	 * @param mean  mean of the untruncated Gaussian
	 * @param sigma standard deviation of the untruncated Gaussian (expected to be > 0)
	 * @param min   lower limit; may be {@link Double#NEGATIVE_INFINITY}
	 * @param max   upper limit; may be {@link Double#POSITIVE_INFINITY}
	 * @return {@code mean + sigma * z}, where z is a standard normal truncated to the standardised limits
	 * @throws IllegalArgumentException if the standardised limits are not strictly increasing or are NaN
	 */
	public final double nextTruncatedGaussian(double mean, double sigma, double min, double max) {
		double l = (min - mean) / sigma;
		double u = (max - mean) / sigma;
		return mean + sigma * nextTruncatedGaussian(l, u);
	}

	/**
	 * Draws a standard normal N(0,1) truncated to [a, b].
	 *
	 * @param a lower limit; may be {@link Double#NEGATIVE_INFINITY}
	 * @param b upper limit; may be {@link Double#POSITIVE_INFINITY}
	 * @return a sample from the truncated distribution
	 * @throws IllegalArgumentException if {@code b > a} does not hold (including when either limit is NaN)
	 * @throws RuntimeException if every sampling scheme, including the fallbacks, exceeds {@link #maxAttempts}
	 */
	public final double nextTruncatedGaussian(double a, double b) {
		// Written as !(b > a) so NaN limits are rejected too; a plain "b <= a" test lets NaN through.
		if (!(b > a))
			throw new IllegalArgumentException("nextTruncatedGaussian must have b > a (a="+a+" b="+b+")");

		method = 0;

		try {
			return sampleStandardised(a, b);
		} catch (TooManyAttemptsException err) {
			System.err.println("WARNING: Too many attempts in Truncated 1D Gaussian sampler (a="+a+", b="+b+").");
			return fallbackSample(a, b);
		}
	}

	/** Selects and runs the Geweke (1991) scheme for N(0,1) truncated to [a,b]; requires b > a. */
	private double sampleStandardised(double a, double b) throws TooManyAttemptsException {
		// Since b > a, an infinite a can only be -inf and an infinite b can only be +inf.
		boolean aInfinite = Double.isInfinite(a);
		boolean bInfinite = Double.isInfinite(b);

		if (aInfinite && bInfinite) // no truncation at all
			return nextGaussian();

		if (aInfinite) { // (-inf, b]: the mirror image of [-b, +inf)
			if (b >= -T4) return normalRejectionSample(a, b);
			return -exponentialRejectionSample(-b, -a);
		}

		if (bInfinite) { // [a, +inf)
			if (a <= T4) return normalRejectionSample(a, b);
			return exponentialRejectionSample(a, b);
		}

		// Both limits finite.
		if (a > 0) { // interval entirely on the positive side
			if (densityRatio(a, b) <= T2) return uniformRejectionSample(a, b);
			if (a < T3) return halfNormalRejectionSample(a, b);
			return exponentialRejectionSample(a, b);
		}
		if (b < 0) { // entirely negative: mirror of the positive case
			if (densityRatio(b, a) <= T2) return uniformRejectionSample(a, b);
			if (b > -T3) return -halfNormalRejectionSample(-b, -a);
			return -exponentialRejectionSample(-b, -a);
		}
		// 0 in [a,b]. TODO: possibly replace with if(a <= -0.3987 || b > 0.3987)
		if (gauss(a) <= T1 || gauss(b) <= T1) return normalRejectionSample(a, b);
		return uniformRejectionSample(a, b); // uniform if both limits are close to 0
	}

	/**
	 * Last-resort sampling after the preferred scheme gave up: uniform rejection (finite limits only),
	 * then plain normal rejection, then failure.
	 */
	private double fallbackSample(double a, double b) {
		// Uniform proposals need a finite interval: with an infinite limit a + U*(b-a) is NaN/infinite,
		// and a NaN proposal would pass the acceptance test and be returned as a "sample".
		if (!Double.isInfinite(a) && !Double.isInfinite(b)) {
			try {
				return uniformRejectionSample(a, b);
			} catch (TooManyAttemptsException ignored) {
				// fall through to normal rejection
			}
		}
		try {
			return normalRejectionSample(a, b);
		} catch (TooManyAttemptsException err) {
			throw new RuntimeException("Too many attempts in Truncated 1D Gaussian sampler (a="+a+", b="+b+").", err);
		}
	}

	/**
	 * Draws standard normals until one falls inside [a, b].
	 *
	 * @throws TooManyAttemptsException after more than {@link #maxAttempts} draws
	 */
	public final double normalRejectionSample(double a, double b) throws TooManyAttemptsException {
		method = 1;
		double x;
		int attempts = 0;
		do {
			x = nextGaussian();
			attempts++;
			if (attempts > maxAttempts)
				throw new TooManyAttemptsException();
		} while (x < a || x > b);
		return x;
	}

	/**
	 * Draws |N(0,1)| until one falls inside [a, b]; intended for 0 <= a < b.
	 *
	 * @throws TooManyAttemptsException after more than {@link #maxAttempts} draws
	 */
	private double halfNormalRejectionSample(double a, double b) throws TooManyAttemptsException {
		method = 2;
		double x;
		int attempts = 0;
		do {
			x = StrictMath.abs(nextGaussian());
			attempts++;
			if (attempts > maxAttempts)
				throw new TooManyAttemptsException();
		} while (x < a || x > b);
		return x;
	}

	/**
	 * Proposes x uniformly on [a, b] and accepts it with probability phi(x)/phi(peak), where peak is
	 * the point of [a, b] nearest 0 (where the density is largest). Requires finite limits.
	 *
	 * @throws TooManyAttemptsException after more than {@link #maxAttempts} draws
	 */
	private double uniformRejectionSample(double a, double b) throws TooManyAttemptsException {
		method = 3;
		double peak = (a > 0) ? a : ((b < 0) ? b : 0.0);
		double u, x;
		int attempts = 0;
		do {
			x = a + nextDouble() * (b - a); // x ~ U[a,b] (approximately; nextDouble is in [0,1))
			u = nextDouble();               // u ~ U[0,1)
			attempts++;
			if (attempts > maxAttempts)
				throw new TooManyAttemptsException();
		} while (u > densityRatio(x, peak)); // ratio computed directly so far tails don't underflow to 0/0
		return x;
	}

	/**
	 * Samples N(0,1) truncated to [a, b] (b may be +infinity) for a > 0, using exponential
	 * proposals x = a + Exp(rate a), accepted with probability exp(-(x-a)^2 / 2).
	 *
	 * @throws IllegalArgumentException if a is not strictly positive (the proposal rate would be invalid)
	 * @throws TooManyAttemptsException after more than {@link #maxAttempts} proposals
	 */
	public final double exponentialRejectionSample(double a, double b) throws TooManyAttemptsException {
		if (!(a > 0))
			throw new IllegalArgumentException("exponentialRejectionSample requires a > 0 (a="+a+")");
		method = 4;
		double u, x;
		int attempts = 0;

		do {
			do {
				x = a + nextExponential(a);
				attempts++;
				if (attempts > maxAttempts)
					throw new TooManyAttemptsException();
			} while (x > b);
			u = nextDouble();
			// exp(-(x-a)^2/2) == exp(-0.5*(x^2 + a^2) + a*x), formed without large cancelling terms
			double d = x - a;
			if (u <= StrictMath.exp(-0.5 * d * d))
				return x;
		} while (true);
	}

	/**
	 * Samples an exponential variate with rate l (density l*exp(-l*x), x >= 0) by inverting the CDF.
	 * The caller must supply l > 0.
	 */
	public final double nextExponential(double l) {
		// 1 - nextDouble() lies in (0, 1], so the logarithm is always finite.
		return -StrictMath.log(1 - nextDouble()) / l;
	}

	/** Standard normal density phi(x). */
	private static double gauss(double x) {
		return Math.exp(-(x * x) / 2.0) / ROOT_TWO_PI;
	}

	/**
	 * Returns phi(near) / phi(far) = exp((far^2 - near^2) / 2). It is computed directly, without forming
	 * either density, so it does not become 0/0 = NaN once both densities underflow.
	 */
	private static double densityRatio(double near, double far) {
		return StrictMath.exp(0.5 * (far - near) * (far + near));
	}

}
