/*******************************************************************
 *                                                                 *
 *  Range Image Registration Testbed Toolkit                       *
 *  by Gerald Dalley                                               *
 *  Copyright (C) 2001 The Ohio State University                   *
 *                                                                 *
 *******************************************************************/
/*=========================================================================

  Program:   Visualization Toolkit
  Language:  C++
  
The authors hereby grant permission to use, copy, and distribute this
software and its documentation for any purpose, provided that existing
copyright notices are retained in all copies and that this notice is included
verbatim in any distributions. Additionally, the authors grant permission to
modify this software and its documentation for any purpose, provided that
such modifications are not distributed without the explicit consent of the
authors and that existing copyright notices are retained in all copies. Some
of the algorithms implemented by this software are patented, observe all
applicable patent law.

IN NO EVENT SHALL THE AUTHORS OR DISTRIBUTORS BE LIABLE TO ANY PARTY FOR
DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT
OF THE USE OF THIS SOFTWARE, ITS DOCUMENTATION, OR ANY DERIVATIVES THEREOF,
EVEN IF THE AUTHORS HAVE BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

THE AUTHORS AND DISTRIBUTORS SPECIFICALLY DISCLAIM ANY WARRANTIES, INCLUDING,
BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE, AND NON-INFRINGEMENT.  THIS SOFTWARE IS PROVIDED ON AN
"AS IS" BASIS, AND THE AUTHORS AND DISTRIBUTORS HAVE NO OBLIGATION TO PROVIDE
MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.


=========================================================================*/
#include "vtkKDTreePointLocator.h"

#include <math.h>
#include <stdlib.h>

#include "vtkMath.h"
#include "vtkIntArray.h"
#include "vtkPolyData.h"

// Right now, we only support 3D kD-trees.
#define NUM_DIMENSIONS 3

// Scratch record used while the tree is being built: the id of a dataset
// point together with a private copy of its coordinates, so the records can
// be sorted without touching the dataset.
struct PointData
{
    int   PointId;  // index of the point in the source dataset
    float X;
    float Y;
    float Z;
};

// Payload of a leaf node: the bucket of dataset point ids that belong to the
// leaf's partition.  The list is released by ~vtkKDTreeNode().
struct TerminalNodeData
{
    vtkIdList *Points;
};

// Payload of an interior node: the splitting plane and the two subtrees it
// separates.
struct InternalNodeData
{
    // Axis perpendicular to the splitting plane.  The numeric values are
    // used directly as coordinate indices by the search code
    // (0 = x, 1 = y, 2 = z), so they must not be renumbered.
    enum SplitAxisType { SPLIT_BY_X = 0, SPLIT_BY_Y = 1, SPLIT_BY_Z = 2 };

    int Point;                  // not referenced by the code visible in this file

    float         SplitValue;   // coordinate of the splitting plane along SplitAxis
    SplitAxisType SplitAxis;    // discriminator

    vtkKDTreeNode *LeftChild;   // searched when the query lies below SplitValue
    vtkKDTreeNode *RightChild;  // searched otherwise
};

// One node of the kD-tree.  Terminal selects which member of NodeData is
// meaningful: NodeData.t for a leaf, NodeData.i for an interior node.
class vtkKDTreeNode
{
public:
    // Releases the leaf bucket, or recursively deletes both children.
    ~vtkKDTreeNode();

    bool Terminal;

    union {
        TerminalNodeData t;  // valid when Terminal is true
        InternalNodeData i;  // valid when Terminal is false
    } NodeData;
};

// Tear down this node and everything below it.
vtkKDTreeNode::~vtkKDTreeNode()
{
    if (!this->Terminal) {
        // Interior node: destroy both subtrees (deleting a null child is a
        // no-op).
        delete this->NodeData.i.LeftChild;
        delete this->NodeData.i.RightChild;
        return;
    }

    // Leaf node: give up the reference to the bucket of point ids.
    this->NodeData.t.Points->Delete();
}

// =======================================================
//    M A I N   M E T H O D S
// =======================================================

// Construct a locator with no tree and a default bucket capacity of 10.
vtkKDTreePointLocator::vtkKDTreePointLocator()
{
    // Root is only assigned a tree by BuildLocator().
    this->Root = NULL;

    // BuildTree() stops subdividing once a partition holds at most this
    // many points.
    this->NumberOfPointsPerBucket = 10;
}

// Destroy the locator, releasing the kD-tree it owns.
vtkKDTreePointLocator::~vtkKDTreePointLocator()
{
    FreeSearchStructure();
}

void vtkKDTreePointLocator::FreeSearchStructure()
{
    delete Root;
}

void vtkKDTreePointLocator::BuildLocator()
{
    if ( (this->BuildTime > this->MTime) && 
        (this->BuildTime > this->DataSet->GetMTime()) )
    {
        return;
    }

    float *point;
    int numPoints = this->DataSet->GetNumberOfPoints();
    PointData *pts = new PointData[numPoints];
    for (int i=0; i<numPoints; i++) {
        point = this->DataSet->GetPoint(i);
        pts[i].PointId = i;
        pts[i].X = point[0];
        pts[i].Y = point[1];
        pts[i].Z = point[2];
    }
    
    this->Root = this->BuildTree(numPoints, pts);
    delete [] pts;

    this->BuildTime.Modified();
}

int CompareX(const void *e1, const void *e2) 
{
    if (((PointData*)e1)->X > ((PointData*)e2)->X) {
        return 1;
    } else {
        return -1;
    }
}

int CompareY(const void *e1, const void *e2) 
{
    if (((PointData*)e1)->Y > ((PointData*)e2)->Y) {
        return 1;
    } else {
        return -1;
    }
}

int CompareZ(const void *e1, const void *e2) 
{
    if (((PointData*)e1)->Z > ((PointData*)e2)->Z) {
        return 1;
    } else {
        return -1;
    }
}

// Recursively build the subtree covering pts[0 .. numPoints-1].
//
//   numPoints  number of records in pts
//   pts        scratch records; reordered in place (sorted along the split
//              axis chosen at each level)
//
// Returns a newly allocated node owned by the caller.  A partition holding
// no more points than the bucket capacity becomes a leaf whose bucket lists
// the point ids.  Larger partitions are split at the median along the axis
// with the greatest spread: records before the median go to the left child,
// the median and everything after it go to the right child, so every input
// point ends up in exactly one leaf.
vtkKDTreeNode *vtkKDTreePointLocator::BuildTree(int numPoints, PointData *pts) 
{
    int i;

    // Create KD-Tree node
    vtkKDTreeNode *node = new vtkKDTreeNode();

    // A capacity below 1 can never be met by a non-empty partition; treat
    // it as 1 so the recursion is guaranteed to terminate.
    int bucketCapacity = this->NumberOfPointsPerBucket;
    if (bucketCapacity < 1) {
        bucketCapacity = 1;
    }

    // Jump out if we just need to stuff the results 
    // in a terminal bucket
    if (numPoints <= bucketCapacity) {
        node->Terminal = true;
        
        vtkIdList *bucket = vtkIdList::New();
        bucket->SetNumberOfIds(numPoints);

        for (i=0; i<numPoints; i++) {
            bucket->SetId(i, pts[i].PointId);
        }

        node->NodeData.t.Points = bucket;

        return node;
    } 

    node->Terminal = false;

    // Find the bounding box of the partition.
    float xmin = VTK_FLOAT_MAX, ymin = VTK_FLOAT_MAX, zmin = VTK_FLOAT_MAX;
    float xmax = -VTK_FLOAT_MAX, ymax = -VTK_FLOAT_MAX, zmax = -VTK_FLOAT_MAX;

    for (i=0; i<numPoints; i++) {
        const PointData &pt = pts[i];

        if (pt.X < xmin) xmin = pt.X;
        if (pt.Y < ymin) ymin = pt.Y;
        if (pt.Z < zmin) zmin = pt.Z;
        
        if (pt.X > xmax) xmax = pt.X;
        if (pt.Y > ymax) ymax = pt.Y;
        if (pt.Z > zmax) zmax = pt.Z;
    }

    // Pick the axis with the greatest spread (ties fall through to the
    // later axis) and the comparator that sorts along it.
    const float xSpread = xmax - xmin;
    const float ySpread = ymax - ymin;
    const float zSpread = zmax - zmin;

    InternalNodeData::SplitAxisType axis;
    int (*compare)(const void *, const void *);
    if (xSpread > ySpread) {
        if (xSpread > zSpread) {
            axis = InternalNodeData::SPLIT_BY_X;
            compare = CompareX;
        } else {
            axis = InternalNodeData::SPLIT_BY_Z;
            compare = CompareZ;
        }
    } else {
        if (ySpread > zSpread) {
            axis = InternalNodeData::SPLIT_BY_Y;
            compare = CompareY;
        } else {
            axis = InternalNodeData::SPLIT_BY_Z;
            compare = CompareZ;
        }
    }

    qsort((void*)pts, numPoints, sizeof(PointData), compare);

    // The splitting plane passes through the median record.  Read its
    // coordinate now: the recursive calls below reorder pts again.
    const int medianPosition = numPoints/2;
    float splitValue;
    if (axis == InternalNodeData::SPLIT_BY_X) {
        splitValue = pts[medianPosition].X;
    } else if (axis == InternalNodeData::SPLIT_BY_Y) {
        splitValue = pts[medianPosition].Y;
    } else {
        splitValue = pts[medianPosition].Z;
    }
    node->NodeData.i.SplitAxis = axis;
    node->NodeData.i.SplitValue = splitValue;

    // Left: records strictly before the median.  Right: the median itself
    // and everything after it.  The median must be kept in one of the
    // subtrees, otherwise that point could never be found by a search.
    // Both halves are non-empty and smaller than numPoints because
    // numPoints >= 2 here.
    node->NodeData.i.LeftChild = this->BuildTree(medianPosition, pts);
    node->NodeData.i.RightChild = this->BuildTree(
        numPoints - medianPosition, pts + medianPosition);

    return node;
}

// Not implemented: always reports an error and leaves pd untouched.
void vtkKDTreePointLocator::GenerateRepresentation(int vtkNotUsed(level),
                                                   vtkPolyData *pd)
{
    vtkErrorMacro(
        <<"I don't understand what this method is supposed to do, so it's not implemented.");
}

// Convenience overload: packs the three coordinates into an array and
// forwards to FindClosestPoint(float*), returning whatever id it returns.
int vtkKDTreePointLocator::FindClosestPoint(float x, float y, float z)
{
  float query[3] = { x, y, z };

  return this->FindClosestPoint(query);
}

int vtkKDTreePointLocator::FindClosestPoint(float *x)
{
    int closestPointId = -1;
    float closestPointDistanceSquared = VTK_FLOAT_MAX;
    float extent[] = {
        VTK_FLOAT_MIN, VTK_FLOAT_MAX, 
        VTK_FLOAT_MIN, VTK_FLOAT_MAX, 
        VTK_FLOAT_MIN, VTK_FLOAT_MAX
    };
    float closestPoint[3] = {VTK_FLOAT_MAX, VTK_FLOAT_MAX, VTK_FLOAT_MAX};

    this->SearchForClosestPoint(x, this->Root, 
        closestPointId, closestPointDistanceSquared, closestPoint, extent);

    return closestPointId;
}

// Search instrumentation counters, updated by SearchForClosestPoint().
// They are plain globals that are only ever incremented in the visible
// code; nothing here resets them.
long numExecutes = 0;       // calls to SearchForClosestPoint()
long numTerminals = 0;      // terminal (leaf) nodes visited
long numDrips = 0;          // total point ids scanned in terminal nodes
long numDoubleChecks = 0;   // times the farther subtree was searched as well

bool vtkKDTreePointLocator::SearchForClosestPoint(
    float *x, vtkKDTreeNode *node, 
    int &closestPointId, float &closestPointDistanceSquared, 
    float *closestPoint, float extent[6])
{
    float *tmpClosestPoint;
    
numExecutes++;
    if (node->Terminal)
    {
numTerminals++;
        // If we have hit a terminal node, look for the closest closestPoint
        // in that node to x.

        float distSquared;
        int numIds = node->NodeData.t.Points->GetNumberOfIds();
        int *id = node->NodeData.t.Points->GetPointer(0);
numDrips += numIds;
        
        for (int idNum=0; idNum<numIds; idNum++, id++)
        {
            tmpClosestPoint = this->DataSet->GetPoint(*id);
            distSquared = vtkMath::Distance2BetweenPoints(x, tmpClosestPoint);

            if (distSquared < closestPointDistanceSquared)
            {
                closestPointDistanceSquared = distSquared;
                closestPointId = *id;
                closestPoint = tmpClosestPoint;
            }
        }
    } else {
        float tmpExtent;
        bool done;

        // We're not at a terminal node, so we need to figure out
        // which subtree to search
        
        if (x[node->NodeData.i.SplitAxis] < node->NodeData.i.SplitValue)
        {
            // First case: x is contained in the left subtree
            // Search for the closest point in that subtree.
            tmpExtent = extent[node->NodeData.i.SplitAxis*2+1];
            extent[node->NodeData.i.SplitAxis*2+1] = 
                node->NodeData.i.SplitValue;
            done = this->SearchForClosestPoint(x, node->NodeData.i.LeftChild,
                    closestPointId, closestPointDistanceSquared, 
                    closestPoint, extent);
            extent[node->NodeData.i.SplitAxis*2+1] = tmpExtent;
                
            // Closest point definitively found in the subtree.  We
            // can jump out and return here.
            if (done) return true;


            // Second case: we did not definitively find x in the
            // closer subtree, so we'll check the farther one.

            // Check other subtree if there could be a closer point
            // in it.
            tmpExtent = extent[node->NodeData.i.SplitAxis*2];
            extent[node->NodeData.i.SplitAxis*2] = 
                node->NodeData.i.SplitValue;
            /*
            if (node->NodeData.i.SplitValue - 
                    closestPoint[node->NodeData.i.SplitAxis] >
                fabs(x[node->NodeData.i.SplitAxis] -
                    closestPoint[node->NodeData.i.SplitAxis])
            */
            if (this->BoundsOverlapBall(x, 
                closestPointDistanceSquared, extent))
            {
numDoubleChecks++;
                this->SearchForClosestPoint(x, node->NodeData.i.RightChild,
                    closestPointId, closestPointDistanceSquared, 
                    closestPoint, extent);
            }
            extent[node->NodeData.i.SplitAxis*2] = tmpExtent;
        }
        else 
        {
            // First case: x is contained in the right subtree
            // Search for the closest point in that subtree.
            tmpExtent = extent[node->NodeData.i.SplitAxis*2];
            extent[node->NodeData.i.SplitAxis*2] = 
                node->NodeData.i.SplitValue;
            done = this->SearchForClosestPoint(x, node->NodeData.i.RightChild,
                    closestPointId, closestPointDistanceSquared, 
                    closestPoint, extent);
            extent[node->NodeData.i.SplitAxis*2] = tmpExtent;
                
            // Closest point definitively found in the subtree.  We
            // can jump out and return here.
            if (done) return true;


            // Second case: we did not definitively find x in the
            // closer subtree, so we'll check the farther one.

            // Check other subtree
            tmpExtent = extent[node->NodeData.i.SplitAxis*2+1];
            extent[node->NodeData.i.SplitAxis*2+1] = 
                node->NodeData.i.SplitValue;
            /*
            if (closestPoint[node->NodeData.i.SplitAxis] - 
                    node->NodeData.i.SplitValue >
                fabs(x[node->NodeData.i.SplitAxis] -
                    closestPoint[node->NodeData.i.SplitAxis])
            */
            if (this->BoundsOverlapBall(x, 
                closestPointDistanceSquared, extent))
            {
numDoubleChecks++;
                this->SearchForClosestPoint(x, node->NodeData.i.LeftChild,
                    closestPointId, closestPointDistanceSquared, 
                    closestPoint, extent);
            }
            extent[node->NodeData.i.SplitAxis*2+1] = tmpExtent;
        }
    }

    // Finally, see if we can guarantee that the closest closestPoint is
    // closer to x than any of the edges of the current kD-tree
    // node.  If it is not, then some other node may contain a
    // closestPoint closer than the one just found.  We return false if
    // more work needs to be done to guarantee that we found the closest
    // point.
    float xToPoint;
    float xToBoundary;
    
    float *ePtr = extent;
    float *xPtr = x;
    for (int d=0; d<NUM_DIMENSIONS; d++)
    {
        xToPoint = fabs(closestPoint[d] - (*xPtr));

        // Check min boundary
        xToBoundary = (*xPtr) - (*ePtr); ePtr++;
        if (xToBoundary < xToPoint) return false;

        // Check max boundary
        xToBoundary = (*ePtr) - (*xPtr); ePtr++;
        if (xToBoundary < xToPoint) return false;

        xPtr++;
    }
    return true;
}

// Tell whether the box described by extent ({xmin, xmax, ymin, ymax, zmin,
// zmax}) comes closer to x than the current best candidate.
//
// Accumulates the squared distance from x to the box: along each axis only
// the amount by which x lies below the minimum or above the maximum
// contributes.  Returns true when that squared distance is strictly smaller
// than closestPointDistanceSquared.
bool vtkKDTreePointLocator::BoundsOverlapBall(
    float *x, float closestPointDistanceSquared, float extent[6])
{
    float distSquaredToBox = 0.0;

    for (int axis=0; axis<NUM_DIMENSIONS; axis++) 
    {
        // Negative when x is below the box on this axis.
        const float fromMin = x[axis] - extent[2*axis];
        if (fromMin < 0.0) {
            distSquaredToBox += fromMin * fromMin;
        }

        // Positive when x is above the box on this axis.
        const float fromMax = x[axis] - extent[2*axis+1];
        if (fromMax > 0.0) {
            distSquaredToBox += fromMax * fromMax;
        }
    }

    return (distSquaredToBox < closestPointDistanceSquared);
}

