/***********************************************************************



	File:	derivativeEMST.c

	A mex file for Matlab (c)

	The calling syntax is:

			

			  [dLength] = derivativeEMST(rr, drr, E_in, lamda)

	

	 E_in  = 2 -by- Number of input edges							<Edges (referenced by start & end indices -Matlab index, i.e. starts at 1- of the points) to be included in the EMST>	

	 rr	= Point dimension -by- Number of Samples					<Sample values>

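	 drr   = Point Dimension/2 * der_dim -by- Number of Samples		<Derivative of the second half of each sample's coordinates w.r.t. the der_dim transformation parameters, stored as der_dim consecutive blocks of Point Dimension/2 rows>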

	 lamda = Double type scalar										<Nonzero exponential weight on Euclidean edge lengths>



	 dLength = 1 -by- der_dim										<Derivative of the total weighted EMST length (sum of |edge|^lamda) w.r.t. each transformation parameter>

	 	

	June 2005



	Copyright (c) 2005 by Mert Rory Sabuncu





************************************************************************/



#include <math.h>

#include <stddef.h>

#include <stdio.h>

#include <stdlib.h>

#include "mex.h"



/* Operator-spelling aliases. None of them is referenced by mexFunction below. */
#define NOT	!

#define AND	&&

#define OR	||

#define EQ	==

#define NE	!=

/* Floor applied to the squared edge length before it is raised to the power
 * lamda/2 - 1 in mexFunction. This keeps pow() from being evaluated at zero;
 * that exponent is negative whenever lamda < 2. */
#define MIN_DISTANCE	1e-5



/*
 * MEX entry point:  [dLength] = derivativeEMST(rr, drr, E_in, lamda)
 *
 * Each edge (p, q) listed in E_in has weighted length (d2)^(lamda/2), where
 * d2 = |p - q|^2 is taken over all `dim` coordinates and floored at
 * MIN_DISTANCE. The derivative of the summed weighted length with respect to
 * transformation parameter j is accumulated into dLength[j]. It is computed by
 * the chain rule through the second half of each point's coordinates
 * (rows h .. dim-1 of rr, with h = dim/2), whose derivatives are given in drr:
 *
 *   dLength[j] += lamda * d2^(lamda/2 - 1)
 *                 * sum_k (p[h+k] - q[h+k]) * (dp[j*h+k] - dq[j*h+k])
 *
 * Inputs:
 *   prhs[0] rr     real double, dim x N, with dim even and >= 2
 *   prhs[1] drr    real double, (h * der_dim) x N; column layout is der_dim
 *                  consecutive blocks of h rows
 *   prhs[2] E_in   real double, 2 x M one-based point indices (may be empty)
 *   prhs[3] lamda  non-empty nonzero scalar exponent
 * Output:
 *   plhs[0]        1 x der_dim derivative of the total weighted length
 *
 * Every input is validated before it is dereferenced. Invalid input is
 * reported through mexErrMsgTxt, which aborts the MEX call and does not return.
 */

void

mexFunction(

			int nlhs, 

			mxArray *plhs[], 

			int nrhs, 

			const mxArray *prhs[])

{
	const double	*sampleValues, *derivativeValues, *edges;
	double			*dLength;
	double			lamda;
	int				dim, npoints, dim2, npoints2, nedges, edgeCheck;
	int				half, der_dim;
	int				i, j, k;

	/* Check the argument count before touching prhs[]: reading prhs[i] for
	 * i >= nrhs is out of bounds. */
	if (nrhs != 4)  mexErrMsgTxt("Must have four input arguments");
	if (nlhs > 1)   mexErrMsgTxt("Too many output arguments");

	/* mxGetPr reinterprets the data as doubles. Any other class (or complex
	 * storage) would be misread or overrun by the loops below. */
	for (i = 0; i < 3; i++) {
		if (!mxIsDouble(prhs[i]) || mxIsComplex(prhs[i]))
			mexErrMsgTxt("rr, drr and E_in must be real double arrays");
	}
	if (mxIsEmpty(prhs[3])) mexErrMsgTxt("The exponential weight must be a scalar");

	sampleValues = mxGetPr(prhs[0]);
	dim          = (int) mxGetM(prhs[0]);
	npoints      = (int) mxGetN(prhs[0]);

	derivativeValues = mxGetPr(prhs[1]);
	dim2             = (int) mxGetM(prhs[1]);
	npoints2         = (int) mxGetN(prhs[1]);

	edges     = mxGetPr(prhs[2]);
	edgeCheck = (int) mxGetM(prhs[2]);
	nedges    = (int) mxGetN(prhs[2]);

	lamda = mxGetScalar(prhs[3]);

	if (npoints != npoints2) mexErrMsgTxt("Number of sample values and derivative values don't match!");

	/* Also rejects dim == 0, which would otherwise divide by zero below. */
	if (dim < 2) mexErrMsgTxt("Dimension of input points should be larger than 1!");

	/* The derivative pairs with the second half of the coordinates, so the
	 * split into two halves must be exact. */
	if (dim % 2 != 0) mexErrMsgTxt("Dimension of input points should be even");

	half = dim / 2;
	if (dim2 % half != 0)
		mexErrMsgTxt("Number of rows of derivative values must be a multiple of half the point dimension");
	der_dim = dim2 / half;

	/* Edges are read with a stride of two (start, end) per column. */
	if (nedges > 0 && edgeCheck != 2) mexErrMsgTxt("Edge list must have exactly two rows");

	if (lamda == 0) mexErrMsgTxt("The exponential weight should be nonzero");

	/* mxCreateDoubleMatrix zero-initialises, so dLength can be accumulated into. */
	plhs[0] = mxCreateDoubleMatrix(1, der_dim, mxREAL);
	dLength = mxGetPr(plhs[0]);

	for (i = 0; i < nedges; i++) {
		const double	startValue = edges[2 * i];
		const double	endValue   = edges[2 * i + 1];
		const double	*pStart, *pEnd, *dpStart, *dpEnd;
		double			distance = 0.0;
		double			delta, scale;
		size_t			startIndex, endIndex;

		/* Range-check the doubles before casting: an out-of-range or NaN
		 * double-to-int cast is undefined. The !(...) form also rejects NaN.
		 * The accepted range [1, npoints + 1) is exactly the set of values
		 * that truncate to a valid one-based index. */
		if (!(startValue >= 1.0 && startValue < (double) npoints + 1.0) ||
			!(endValue >= 1.0 && endValue < (double) npoints + 1.0))
			mexErrMsgTxt("Edge indices must lie between 1 and the number of samples");

		/* MATLAB one-based index -> zero-based column. */
		startIndex = (size_t) ((int) startValue - 1);
		endIndex   = (size_t) ((int) endValue - 1);

		pStart  = &sampleValues[startIndex * (size_t) dim];
		pEnd    = &sampleValues[endIndex * (size_t) dim];
		dpStart = &derivativeValues[startIndex * (size_t) dim2];
		dpEnd   = &derivativeValues[endIndex * (size_t) dim2];

		for (k = 0; k < dim; k++) {
			delta = pStart[k] - pEnd[k];
			distance += (delta * delta);
		}
		distance = distance > MIN_DISTANCE ? distance : MIN_DISTANCE;

		/* d/dx (d2)^(lamda/2) = lamda * d2^(lamda/2 - 1) * delta. The power
		 * term depends only on the edge, so compute it once per edge. */
		scale = lamda * pow(distance, lamda / 2.0 - 1.0);

		for (j = 0; j < der_dim; j++) {
			for (k = 0; k < half; k++) {
				delta = pStart[half + k] - pEnd[half + k];
				dLength[j] += scale * delta * (dpStart[j * half + k] - dpEnd[j * half + k]);
			}
		}
	}
}

