/***********************************************************************

	File:	volumeRigidRegisterEMSTwStochSampling.c

	A mex file for Matlab (c)
	The calling syntax is:
			  TParams = volumeRigidRegisterEMSTwStochSampling(vol1, vol2, TParams_init, numOfSamples, TStepSize, quantLevels, gamma, numOfIters)

   vol1 = <single> 3D volume -- fixed volume
	 vol2 = <single> 3D volume -- floating volume *** volumes should be same size ***
	 TParams_init = <double> 1 x 6 -- [tx, ty, tz, theta, phi, omega] (translations in voxels, angles in degrees)
	 numOfSamples = <int> scalar, at least 3
	 TStepSize  = <double> scalar
	 quantLevels = <int> scalar
	 gamma		 = <double> scalar \in (0,2)
	 numOfIters  = <int> scalar  

   TParams = 1 x 6 -- [tx, ty, tz, theta, phi, omega]

   Date:	June 2005

  Copyright (c) 2005-2008 by Mert Rory Sabuncu

************************************************************************/

#include <math.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include "mex.h"

/* Readability aliases for C operators. None of these appear to be
 * referenced in this file; kept for compatibility. */
#define NOT	!
#define AND	&&
#define OR	||

#define EQ	==
#define NE	!=

/* Not referenced in this file (as reviewed). */
#define MIN_DISTANCE	1e-5

/* When defined, mexFunction prints progress messages via mexPrintf. */
#define VERBOSE
/* Forward declarations of the file-local helpers defined after mexFunction. */
static void stochSampleVolumes(const float * volume1_in,
                              const float * volume2_in, const int * size_in, const double * TParams_in, 
                   					  int numOfSamples_in, double * samples_out, double * derivatives_out);
static float linearInterpolateVolume(const float * volume, const int * size_in, float x, float y, float z);
static float interpAux(const float x, const float y, const float alpha); 
/*
 * MATLAB gateway: rigidly registers the floating volume vol2 onto the fixed
 * volume vol1 by normalized gradient descent on the cost computed by the
 * MATLAB function "mst_length" (defined outside this file).
 *
 * Each iteration:
 *   1. stochSampleVolumes() draws numOfSamples voxel pairs under the current
 *      parameters, filling an (numOfSamples x 2) intensity matrix and an
 *      (numOfSamples x 6) matrix of intensity derivatives w.r.t. the
 *      parameters;
 *   2. [cost, grad] = mst_length(samples, derivatives, gamma, quantLevels)
 *      is called; only grad (at least 6 doubles) is used;
 *   3. the parameters move a fixed step along -grad / |grad|. Rotations move
 *      by TStepSize (degrees); each translation moves by TStepSize converted
 *      to radians (3.14 is used for pi), times 0.67, times half the summed
 *      extent of the other two axes.
 * Iteration stops after numOfIters steps, when the gradient is zero (or
 * NaN), or when the squared parameter change over the last 10 iterations
 * is below 0.1.
 *
 * prhs: vol1, vol2     single, same size, 3-D (a 2-D array is treated as a
 *                      single slice);
 *       TParams_init   double, 6 elements [tx ty tz theta phi omega];
 *       numOfSamples   >= 3;   TStepSize != 0;   quantLevels (passed through);
 *       gamma != 0;            numOfIters > 0.
 * plhs: plhs[0] = 1 x 6 double, final [tx ty tz theta phi omega].
 */
void
mexFunction(
			int nlhs, 
			mxArray *plhs[], 
			int nrhs, 
			const mxArray *prhs[])
{
	/*Declarations*/
	const mwSize *	dims1;
	const mwSize *	dims2;
	mwSize		ndims1, ndims2;
	int		VolSize[3];

	const float *	Volume1;
	const float *	Volume2;
	const double *	TParams_init;
	int		numOfSamples;
	double		TStepSize;
	double		gammaVal;
	int		numOfIters;

	mxArray *	sampleBin;
	mxArray *	derivativeBin;
	mxArray *	plhs1[2];
	mxArray *	prhs1[4];
	const double *	dLen;

	/* Stack storage: nothing to leak if mexErrMsgTxt aborts the call. */
	double		TParams_cur[6];		/* parameters the current iteration samples with */
	double		TParams_last[6];	/* snapshot for the every-10-iterations test */
	double		dir[6];			/* normalized descent direction */
	double		transStep[3];		/* per-axis translation step before /2 */
	double *	TParams_final;
	double		norm, delta, diff, step_coef;
	int		iter, k, converged;

	(void) nlhs;

	/* Validate everything before reading prhs[] contents or allocating. */
	if (nrhs != 8)
		mexErrMsgTxt("Need 8 inputs: vol1, vol2, TParams_init, numOfSamples, TStepSize, quantLevels, gamma, numOfIters");
	if (!mxIsSingle(prhs[0]) || !mxIsSingle(prhs[1]))
		mexErrMsgTxt("vol1 and vol2 must be of class single!");
	ndims1 = mxGetNumberOfDimensions(prhs[0]);
	ndims2 = mxGetNumberOfDimensions(prhs[1]);
	if (ndims1 > 3 || ndims2 > 3)
		mexErrMsgTxt("vol1 and vol2 must be 3D volumes!");

	/* mxGetDimensions returns mwSize (size_t with -largeArrayDims), so copy
	 * into an int array instead of reinterpreting the pointer. */
	dims1 = mxGetDimensions(prhs[0]);
	dims2 = mxGetDimensions(prhs[1]);
	for (k = 0; k < 3; k++) {
		/* MATLAB drops trailing singleton dimensions; treat them as 1. */
		mwSize d1 = ((mwSize) k < ndims1) ? dims1[k] : 1;
		mwSize d2 = ((mwSize) k < ndims2) ? dims2[k] : 1;

		if (d1 != d2)
			mexErrMsgTxt("Volumes NOT same size!");
		if (d1 == 0)
			mexErrMsgTxt("Volumes must be nonempty!");
		VolSize[k] = (int) d1;
	}
	if (!mxIsDouble(prhs[2]) || mxGetNumberOfElements(prhs[2]) != 6)
		mexErrMsgTxt("Need a 1 x 6 TParams_init as input!");

	Volume1		= (const float *) mxGetData(prhs[0]);
	Volume2		= (const float *) mxGetData(prhs[1]);
	TParams_init	= mxGetPr(prhs[2]);
	numOfSamples	= (int) mxGetScalar(prhs[3]);
	TStepSize	= mxGetScalar(prhs[4]);
	gammaVal	= mxGetScalar(prhs[6]);
	numOfIters	= (int) mxGetScalar(prhs[7]);

#ifdef VERBOSE
	mexPrintf("\n Number of iters: %d, numOfSamples: %d, gamma: %2.2f\n", numOfIters, numOfSamples, gammaVal);
#endif
	if (numOfSamples < 3) mexErrMsgTxt("Need at least 3 samples!");
	if (gammaVal == 0) mexErrMsgTxt("The exponential weight should be nonzero");
	if (TStepSize == 0) mexErrMsgTxt("TStepSize should be nonzero");
	if (numOfIters <= 0) mexErrMsgTxt("Nonpositive number of iterations!");

	for (k = 0; k < 6; k++)
		TParams_cur[k] = TParams_init[k];

	/* The step size is never adapted, so the translation scaling is fixed. */
	step_coef = 0.67 * TStepSize * 3.14 / 180.0;
	transStep[0] = step_coef * (VolSize[1] + VolSize[2]);
	transStep[1] = step_coef * (VolSize[0] + VolSize[2]);
	transStep[2] = step_coef * (VolSize[0] + VolSize[1]);

	sampleBin	= mxCreateDoubleMatrix(numOfSamples, 2, mxREAL);
	derivativeBin	= mxCreateDoubleMatrix(numOfSamples, 6, mxREAL);
	plhs[0]		= mxCreateDoubleMatrix(1, 6, mxREAL);
	TParams_final	= mxGetPr(plhs[0]);

	/* Arguments for mst_length(samples, derivatives, gamma, quantLevels). */
	prhs1[0] = sampleBin;
	prhs1[1] = derivativeBin;
	prhs1[2] = (mxArray *) prhs[6];
	prhs1[3] = (mxArray *) prhs[5];

	srand(1000);	/* fixed seed: identical inputs give identical results */

	for (iter = 0; iter < numOfIters; iter++) {
		stochSampleVolumes(Volume1, Volume2, VolSize, TParams_cur, numOfSamples,
				   mxGetPr(sampleBin), mxGetPr(derivativeBin));
		mexCallMATLAB(2, plhs1, 4, prhs1, "mst_length");

		if (!mxIsDouble(plhs1[1]) || mxGetNumberOfElements(plhs1[1]) < 6)
			mexErrMsgTxt("mst_length must return a 6-element double gradient as its second output!");
		dLen = mxGetPr(plhs1[1]);

		norm = 0.0;
		for (k = 0; k < 6; k++) {
			dir[k] = -dLen[k];
			norm += dir[k] * dir[k];
		}
		mxDestroyArray(plhs1[0]);
		mxDestroyArray(plhs1[1]);
		norm = sqrt(norm);

		/* Zero (or NaN) gradient: no descent direction, keep the estimate
		 * instead of dividing by zero and producing NaN parameters. */
		if (!(norm > 0.0))
			break;
		for (k = 0; k < 6; k++)
			dir[k] = dir[k] / norm;

		/* Convergence test on the pre-step parameters, every 10 iterations. */
		converged = 0;
		if (iter % 10 == 0) {
			if (iter > 0) {
				delta = 0.0;
				for (k = 0; k < 6; k++) {
					diff = TParams_cur[k] - TParams_last[k];
					delta += diff * diff;
				}
				converged = (delta < 0.1);
			}
			for (k = 0; k < 6; k++)
				TParams_last[k] = TParams_cur[k];
		}

		/* Take the step (the final step is kept even when converged). */
		for (k = 0; k < 3; k++)
			TParams_cur[k] = TParams_cur[k] + transStep[k] * dir[k] / 2.0;
		for (k = 3; k < 6; k++)
			TParams_cur[k] = TParams_cur[k] + TStepSize * dir[k];

		if (converged)
			break;

#ifdef VERBOSE
		if (!(iter % 50)) {
			mexPrintf("TParams: tx = % 2.2f, ty = %2.2f, tz = %2.2f \n", TParams_cur[0], TParams_cur[1], TParams_cur[2]);
			mexPrintf("         theta = % 2.2f, phi = %2.2f, omega = %2.2f \n", TParams_cur[3], TParams_cur[4], TParams_cur[5]);
		}
#endif
	}

	for (k = 0; k < 6; k++)
		TParams_final[k] = TParams_cur[k];

	mxDestroyArray(derivativeBin);
	mxDestroyArray(sampleBin);
}

/* Uniform random integer in [0, n) drawn with rand(); n must be > 0. */
static int stochRandomIndex(int n)
{
	const double r = (double) rand() / ((double) (RAND_MAX) + (double) (1));
	const int idx = (int) (n * r);

	/* Guard against n * r rounding up to n. */
	return idx >= n ? n - 1 : idx;
}

/*
 * Draws numOfSamples_in random voxel pairs under the rigid transform
 * TParams_in = [tx ty tz theta phi omega] (translations in voxels, angles in
 * degrees, 3.14 used for pi). Also computes the derivatives of the
 * floating-volume intensity with respect to each parameter.
 *
 * A location (locX, locY, locZ) is drawn uniformly in the fixed volume. The
 * volumes are stored with x fastest: index = z*size_x*size_y + y*size_x + x.
 * The location is mapped about the volume centre in four steps:
 *   1. rotate by omega in the y-z plane;
 *   2. rotate by theta in the x-z plane;
 *   3. rotate by phi in the x-y plane;
 *   4. translate.
 *
 * A draw is rejected when the fixed intensity s1 or the interpolated
 * floating intensity s2 is <= 0.1. After maxTries consecutive rejections the
 * last draw is kept anyway. This stops a dim volume, or a transform that has
 * moved off the data, from looping forever.
 *
 * Outputs are column-major with numOfSamples_in rows:
 *   samples_out     (N x 2): [s1 s2]
 *   derivatives_out (N x 6): [dI/dx dI/dy dI/dz dI/dtheta dI/dphi dI/domega]
 * I is vol2 at the mapped point. The spatial gradient is a central
 * difference with a 1-voxel step. The angle columns chain it with
 * d(x,y,z)/d(angle), scaled by 3.14/180.
 */
static void stochSampleVolumes(const float * volume1_in,
						  const float * volume2_in, const int * size_in, const double * TParams_in, 
						  int numOfSamples_in, double * samples_out, double * derivatives_out
						  ){
	const int	size_x = size_in[0];
	const int	size_y = size_in[1];
	const int	size_z = size_in[2];
	const long	maxTries = 100000L;

	const float	tx = (float) TParams_in[0];
	const float	ty = (float) TParams_in[1];
	const float	tz = (float) TParams_in[2];
	const float	theta = (float) TParams_in[3];
	const float	phi = (float) TParams_in[4];
	const float	omega = (float) TParams_in[5];

	const float	centerX = size_x/2.0f;
	const float	centerY = size_y/2.0f;
	const float	centerZ = size_z/2.0f;

	/* Degrees to radians with the same 3.14 used in the derivatives below. */
	const float	ct = (float) cos(theta/180.0*3.14);
	const float	st = (float) sin(theta/180.0*3.14);
	const float	cp = (float) cos(phi/180.0*3.14);
	const float	sp = (float) sin(phi/180.0*3.14);
	const float	co = (float) cos(omega/180.0*3.14);
	const float	so = (float) sin(omega/180.0*3.14);

	int sample_cnt;

	for (sample_cnt = 0; sample_cnt < numOfSamples_in; sample_cnt++) {
		int	locX, locY, locZ;
		float	X, Y, Z;		/* sampled voxel relative to the centre */
		float	r1y, r1z, r2x, r2z, r3x, r3y;
		float	x_new, y_new, z_new;
		double	s1, s2, dx, dy, dz;
		long	tries = 0;

		do {
			/* Separate statements keep the rand() order X, Y, Z. */
			locX = stochRandomIndex(size_x);
			locY = stochRandomIndex(size_y);
			locZ = stochRandomIndex(size_z);

			s1 = (double) volume1_in[((size_t) locZ * (size_t) size_y + (size_t) locY) * (size_t) size_x + (size_t) locX];

			X = locX - centerX;
			Y = locY - centerY;
			Z = locZ - centerZ;

			/* omega: rotation in the y-z plane */
			r1y = -so * Z + co * Y;
			r1z = co * Z + so * Y;
			/* theta: rotation in the x-z plane */
			r2z = ct * r1z + st * X;
			r2x = -st * r1z + ct * X;
			/* phi: rotation in the x-y plane */
			r3y = cp * r1y + sp * r2x;
			r3x = -sp * r1y + cp * r2x;

			x_new = r3x + centerX + tx;
			y_new = r3y + centerY + ty;
			z_new = r2z + centerZ + tz;

			s2 = (double) linearInterpolateVolume(volume2_in, size_in, x_new, y_new, z_new);
			tries++;
		} while ((s1 <= 0.1 || s2 <= 0.1) && tries < maxTries);

		samples_out[sample_cnt] = s1;
		samples_out[numOfSamples_in + sample_cnt] = s2;

		/* Central-difference gradient of vol2 at the mapped point. */
		dx = (double) (0.5 * (linearInterpolateVolume(volume2_in, size_in, x_new+1,y_new,z_new) 
			- linearInterpolateVolume(volume2_in, size_in, x_new-1,y_new,z_new)));
		dy = (double) (0.5 * (linearInterpolateVolume(volume2_in, size_in, x_new,y_new+1,z_new) 
			- linearInterpolateVolume(volume2_in, size_in, x_new,y_new-1,z_new)));
		dz = (double) (0.5 * (linearInterpolateVolume(volume2_in, size_in, x_new,y_new,z_new+1) 
			- linearInterpolateVolume(volume2_in, size_in, x_new,y_new,z_new-1)));

		derivatives_out[sample_cnt] = dx;
		derivatives_out[numOfSamples_in + sample_cnt] = dy;
		derivatives_out[2 * numOfSamples_in + sample_cnt] = dz;

		/* d theta */
		derivatives_out[3 * numOfSamples_in + sample_cnt] = (double) ((3.14/180.0*(
			(-st*co*Z-st*so*Y+ct*X) * dz +
			(-sp*co*ct*Z-sp*ct*so*Y-sp*st*X) * dy +
			(-cp*ct*co*Z-cp*ct*so*Y-st*cp*X) * dx)));

		/* d phi (z does not depend on phi) */
		derivatives_out[4 * numOfSamples_in + sample_cnt] = (double) ((3.14/180.0*(
			((sp*so-cp*st*co)*Z+(-sp*co-cp*st*so)*Y+(cp*ct)*X) * dy + 
			((cp*so+sp*st*co)*Z+(-cp*co+sp*st*so)*Y+(-ct*sp)*X) * dx)));

		/* d omega */
		derivatives_out[5 * numOfSamples_in + sample_cnt] = (double) ((3.14/180.0*(
			((-ct*so)*Z+(ct*co)*Y) * dz +
			((-cp*co+sp*st*so)*Z+(-cp*so - sp*st*co)*Y) * dy +
			((sp*co+cp*st*so)*Z+(sp*so-cp*st*co)*Y) * dx)));
	}
}

/*
 * Trilinear interpolation of a single-precision volume stored with x
 * fastest (index = z*size_x*size_y + y*size_x + x) at the continuous
 * position (x, y, z).
 *
 * Positions whose floor lies outside [0, size-2] on any axis are not
 * interpolated. For those, the voxel at the floor indices, clamped to
 * [0, size-1], is returned instead.
 *
 * Coordinates below -1 (or NaN) and above size are pulled into that range
 * before the float-to-int conversion. The conversion would otherwise be
 * undefined behaviour once the transform diverges. The clamping cannot
 * change which voxel is returned for any finite input.
 */
static float linearInterpolateVolume(const float * volume, const int * size_in, float x, float y, float z){

	const int	size_xL = size_in[0];
	const int	size_yL = size_in[1];
	const int	size_zL = size_in[2];
	/* size_t strides avoid int overflow of the linear index. */
	const size_t	strideY = (size_t) size_xL;
	const size_t	strideZ = (size_t) size_xL * (size_t) size_yL;
	const float *	dataL = volume;

	int	x0, y0, z0;
	float	dx, dy, dz;
	size_t	base;
	float	v0, v1, v2, v3, v4, v5, v6, v7;
	float	temp0, temp1, temp2, temp3, temp4, temp5;

	/* !(a >= b) is also true for NaN, which is mapped to -1 (voxel 0). */
	if (!(x >= -1.0f)) x = -1.0f;
	else if (x > (float) size_xL) x = (float) size_xL;
	if (!(y >= -1.0f)) y = -1.0f;
	else if (y > (float) size_yL) y = (float) size_yL;
	if (!(z >= -1.0f)) z = -1.0f;
	else if (z > (float) size_zL) z = (float) size_zL;

	x0 = (int) (floor(x));
	y0 = (int) (floor(y));
	z0 = (int) (floor(z));

	dx = x - x0;
	dy = y - y0;
	dz = z - z0;

	if ((x0 < 0) || (y0 < 0) || (z0 < 0) || (z0 >= (size_zL - 1)) || (y0 >= (size_yL - 1)) || (x0 >= (size_xL - 1))) 
	{
		/* No full 2x2x2 neighbourhood: return the clamped floor voxel. */
		int x1 = x0 < 0 ? 0 : x0;
		int y1 = y0 < 0 ? 0 : y0;
		int z1 = z0 < 0 ? 0 : z0;

		x1 = x1 >= size_xL - 1 ? size_xL - 1 : x1;
		y1 = y1 >= size_yL - 1 ? size_yL - 1 : y1;
		z1 = z1 >= size_zL - 1 ? size_zL - 1 : z1;
		return (dataL[(size_t) z1 * strideZ + (size_t) y1 * strideY + (size_t) x1]);
	}

	base = (size_t) z0 * strideZ + (size_t) y0 * strideY + (size_t) x0;

	/* Exactly on a grid point: no interpolation needed. */
	if ((dx == 0) && (dy == 0) && (dz == 0))
		return (dataL[base]);

	v0 = dataL[base];				/* (x0,   y0,   z0)   */
	v1 = dataL[base + strideZ];			/* (x0,   y0,   z0+1) */
	v2 = dataL[base + strideZ + strideY];		/* (x0,   y0+1, z0+1) */
	v3 = dataL[base + strideY];			/* (x0,   y0+1, z0)   */
	v4 = dataL[base + 1];				/* (x0+1, y0,   z0)   */
	v5 = dataL[base + 1 + strideZ];			/* (x0+1, y0,   z0+1) */
	v6 = dataL[base + 1 + strideZ + strideY];	/* (x0+1, y0+1, z0+1) */
	v7 = dataL[base + 1 + strideY];			/* (x0+1, y0+1, z0)   */

	/* Interpolate along z, then y, then x. */
	temp0 = interpAux(v0, v1, dz);
	temp1 = interpAux(v3, v2, dz);
	temp2 = interpAux(v4, v5, dz);
	temp3 = interpAux(v7, v6, dz);
	temp4 = interpAux(temp0, temp1, dy);
	temp5 = interpAux(temp2, temp3, dy);
	return (interpAux(temp4, temp5, dx));
}

/*
 * Single-precision linear blend: x at alpha == 0, y at alpha == 1,
 * i.e. x + alpha * (y - x).
 */
static float interpAux(const float x, const float y, const float alpha)
{
	const float span = y - x;

	return x + alpha * span;
}

