#include "photonmap.h"
#include "color.h"
#include "primitives.h"
#include "reflection.h"
#include "kdtreecc.h"
#include "dal.h"
#include <math.h>

// Allocates storage for numbOfPhotons photons.
// photons is a single array owned by this object (released in the
// destructor). photonTree receives one non-owning pointer per photon,
// initially in array order; balancePhotonMap() later hands photonTree to
// KDTreeCC::balance.
CausticPhotonMap::CausticPhotonMap(int numbOfPhotons)
{
  assert(numbOfPhotons > 0);

  // the vector is created before the array so that a failing allocation
  // here cannot leave an unreleased photon array behind
  photonTree = vector<Photon*>(numbOfPhotons);
  photons = new Photon[numbOfPhotons];
  numberOfPhotons = numbOfPhotons;

  int slot = 0;
  while (slot < numbOfPhotons)
  {
    photonTree[slot] = &photons[slot];
    ++slot;
  }

  // sanity check: the final tree slot refers to the final photon
  assert(photonTree[numbOfPhotons-1] ==
    &photons[numbOfPhotons-1]);
}

// Releases the photon array allocated by the constructor.
// photons was created with new[] (array form), so it must be released with
// delete[]; the scalar delete used previously is undefined behavior for an
// array. photonTree only holds non-owning pointers into photons and cleans
// itself up.
// NOTE(review): this class owns a raw array; unless copying is disabled in
// the header, copying a CausticPhotonMap would double-free photons.
CausticPhotonMap::~CausticPhotonMap()
{
  delete[] photons;
}

  
/*
 * Shoots numberOfPhotons photons from a DirectionalAreaLight built from the
 * scene, records where each one comes to rest (via getPhotonPosition), and
 * then balances the photon kd-tree.
 *
 * Photons are launched from a regular grid of sample positions (u,v) on the
 * light. The grid has UMax = floor(sqrt(N)) columns and VMax = ceil(N/UMax)
 * rows, so UMax*VMax >= N. Rounding VMax up (instead of also using
 * floor(sqrt(N))) keeps every sample inside [0,1)^2 and keeps
 * assert(UCount <= UMax) valid when N is not a perfect square. For a perfect
 * square the grid is N = UMax*UMax as before.
 *
 * Each photon carries Spectrum(100.)/numberOfPhotons of power.
 */
void CausticPhotonMap::generatePhotonMap(Scene &scene)
{
  // there should be no light sources, we create our own!

  // create directional area light source (the sun)
  DirectionalAreaLight dal = DirectionalAreaLight(scene);

  // get photon color/power: the total is split evenly among the photons
  //Spectrum photonPower = Spectrum(1.)/numberOfPhotons;
  Spectrum photonPower = Spectrum(100.)/numberOfPhotons;

  // get light direction
  Vector dir = dal.GetDirection();

  // compute grid size and stepping in u and v.
  // numberOfPhotons > 0 is asserted by the constructor, so UMax >= 1.
  // The explicit double avoids an ambiguous sqrt(int) call with <math.h>.
  int UMax = (int) sqrt((double) numberOfPhotons);
  int VMax = (numberOfPhotons + UMax - 1) / UMax; // ceil(N / UMax)
  Float du = 1.0/UMax;
  Float dv = 1.0/VMax;

  int UCount = 0;
  int VCount = 0;

  Ray ray;
  int numberOfMisses = 0;
  for (int i=0; i< numberOfPhotons; i++)
  {
    // Corner of this photon's grid cell. It is derived from the counters
    // rather than accumulated, so floating-point error does not build up.
    Float currentU = UCount*du;
    Float currentV = VCount*dv;

    // NOTE(review): the sample is deterministic (the random jitter below is
    // commented out), so a rejected sample is retried with the identical
    // ray. Unless getPhotonPosition() is itself randomized, a miss here may
    // never terminate. Consider restoring the jitter or skipping the cell.
    int misses = -1;
    do
    {
      misses++;

      // get point on directional area light
      //Float u[2] = { RandomFloat(), RandomFloat() };
      Float u[2] =
        { currentU /*+ du*RandomFloat()*/,
          currentV /*+ dv*RandomFloat()*/ };
      ray.O = dal.Sample(u);
      ray.D = dir;

      // initialize before calling getPhotonPosition
      photonDepth = 0;

    } while (!(getPhotonPosition(&ray, scene)));

    // advance to the next grid cell (v varies fastest)
    VCount++;
    if (VCount >= VMax)
    {
      VCount = 0;
      UCount++;
    }

    // holds because UMax*VMax >= numberOfPhotons
    assert(UCount <= UMax);

    numberOfMisses += misses;

    photons[i].ray = ray;
    photons[i].power = photonPower;

    if (!((i+1)%1000)) { cerr << "." ; } // progress indicator
  }
  cerr << " gotPhotons." << endl;
  cerr << "Number of Missed Photons: " <<
    numberOfMisses << endl;

  balancePhotonMap();

  printPhotonDim();

  cerr << "du:" << du << " dv:" << dv << endl;
  //printPhotons();
}

void CausticPhotonMap::printPhotons()
{
   for (int i=0; i< numberOfPhotons; i++)
	 {
	   cerr << photons[i].ray.O.x <<
		   "," << photons[i].ray.O.y <<
		   "," << photons[i].ray.O.z <<
			 endl;
     getchar();
   }
}

void CausticPhotonMap::printPhotonDim()
{
  Float minx = photons[0].ray.O.x;
  Float miny = photons[0].ray.O.y;
  Float minz = photons[0].ray.O.z;
	Float maxx = minx;
	Float maxy = miny;
	Float maxz = minz;

	for (int i=1; i<numberOfPhotons; i++)
	{
	  if (photons[i].ray.O.x < minx)
		{ minx = photons[i].ray.O.x; }
		else if (photons[i].ray.O.x > maxx)
		{ maxx = photons[i].ray.O.x; }
	  
		if (photons[i].ray.O.y < miny)
		{ miny = photons[i].ray.O.y; }
		else if (photons[i].ray.O.y > maxy)
		{ maxy = photons[i].ray.O.y; }

	  if (photons[i].ray.O.z < minz)
		{ minz = photons[i].ray.O.z; }
		else if (photons[i].ray.O.z > maxz)
		{ maxz = photons[i].ray.O.z; }
	}

	cerr << "Photons Dim: {" <<
	  minx << "," << maxx << "} {" <<
	  miny << "," << maxy << "} {" <<
	  minz << "," << maxz << "}" << endl; 

}

// Traces *ray through the scene to find where a photon comes to rest.
//
// On success it returns true. ray->O is set to the world-space hit point
// and ray->D is negated so it points back along the incoming direction.
// If the hit surface's BRDF reports at least one transmission component,
// the photon is bent by sampling transmission component 0 and traced again
// recursively. Once photonDepth exceeds 5 on a transmissive hit, the
// current hit is kept instead.
//
// Returns false if the ray hits nothing ("^" is printed) or hits a
// primitive without a Surface.
//
// REMEMBER: photonDepth is a member incremented once per call; initialize
// it to 0 before the first call to getPhotonPosition.
bool CausticPhotonMap::getPhotonPosition(Ray* ray,
  const Scene& scene)
{
  photonDepth++;

  Float hitDist = INFINITY;
  HitInfo hitInfo;

  // guard: the photon escaped the scene
  if (!scene.Intersect(*ray,1e-4,&hitDist,&hitInfo))
  {
    cerr << "^" ;
    return false;
  }

  ShadeContext shadeContext(&hitInfo,-ray->D);

  // guard: nothing to shade at this hit
  if (!hitInfo.hitPrim->attributes->Surface)
  {
    cerr << "Warning: hit a non surface" << endl;
    return false;
  }

  // NOTE(review): the BRDF returned by Shade() is never released here;
  // confirm whether Shade() transfers ownership (if so, this leaks once per
  // bounce).
  BRDF *brdf =
    hitInfo.hitPrim->attributes->Surface->Shade(shadeContext);

  // record the hit point in world space
  ray->O =
    hitInfo.hitPrim->attributes->ObjectToWorld(hitInfo.Pobj);

  bool transmissive = brdf->TransmissionComponents() >= 1;

  // keep following refracted photons until the depth limit is passed;
  // only the first transmission component is sampled, for now
  if (transmissive && photonDepth <= 5)
  {
    brdf->SampleTransmission(0, &(ray->D));
    return getPhotonPosition(ray,scene);
  }

  if (transmissive)
  {
    cerr <<
      "Max photon depth reached, returning current "
      "position." << endl;
  }

  // diffuse/reflective-only surface, or depth limit reached:
  // store the photon here with its direction pointing outward
  ray->D = -ray->D;
  return true;
}


// Collects the stored photons strictly closer than maxDistance to point.
//
// Returns a vector allocated with new; the caller owns it and must delete
// it. Each element is (distance SQUARED, photon). Results are ordered by
// sort_heap under comparePhotonDistance.
//
// numberOfPhotonsWanted is only checked against numberOfPhotons and stored
// in the member of the same name; it does not limit the search here.
// A diagnostic "cp" is printed every 10000 calls.
vector< pair<Float,Photon *> >*
CausticPhotonMap::getNearestPhotons(Point &point,
  int numberOfPhotonsWanted, Float maxDistance)
{
  static unsigned int timesCalled = 0;

  assert(numberOfPhotonsWanted <= numberOfPhotons);
  assert(maxDistance > 0);

  // locatePhoton reads its search parameters from these members
  this->maxDistance2 = maxDistance*maxDistance;
  this->numberOfPhotonsWanted = numberOfPhotonsWanted;

  typedef pair<Float, Photon *> DistancePhoton;

  // An empty vector already satisfies the heap property, so locatePhoton
  // can push_heap onto it straight away.
  vector<DistancePhoton> *found = new vector<DistancePhoton>;
  locatePhoton(0, numberOfPhotons-1, point, *found);

  // found is only heap-ordered at this point; sort_heap turns it into a
  // fully ordered sequence under the same comparator
  sort_heap(found->begin(), found->end(), comparePhotonDistance);

#if 0
  // this could happen if you look for photons near points on water
  if (found->size() < (unsigned int )numberOfPhotonsWanted)
  {
    cerr << "Not enough photons found: " << found->size() <<
      " photons found." << endl;
  }
#endif

  ++timesCalled;
  if (timesCalled % 10000 == 0) { cerr << "cp"; }

  return found;
}

// Hands the whole photonTree sequence to KDTreeCC::balance. locatePhoton
// then searches photonTree assuming each range's median element
// (first+last)/2 is that subtree's splitting node.
// Progress is reported on stderr.
void CausticPhotonMap::balancePhotonMap()
{
  cerr << "Balancing Photon Map " ;
  KDTreeCC::balance(photonTree.begin(), photonTree.end(), photonTree);
  cerr << "done" << endl;
}

/*
 * Description:
 * Searches the kd-tree stored implicitly in photonTree[first..last] for
 * photons strictly closer than sqrt(maxDistance2) to point. Each node is
 * the median element (first+last)/2 of its range; the lower and upper
 * halves of the range are its children. Every photon found is pushed onto
 * pq as (distance SQUARED, photon), keeping pq a heap under
 * comparePhotonDistance.
 *
 * The child on the same side of the median's splitting plane as point is
 * searched first. The other child is searched only if the plane itself is
 * within the search radius. The median photon is tested after its
 * children.
 *
 * Returns true if the range is empty or its median photon was added to pq,
 * false otherwise. The return value must NOT be used to prune the search.
 * A child's median being out of range says nothing about this node's
 * median or about photons in the other child. The previous code returned
 * early in that case and missed in-range photons, e.g. photons at x=0,5,10
 * with query x=4 and radius 2 never found the photon at 5.
 */
bool CausticPhotonMap::locatePhoton(int first, int last,
  const Point &point, vector< pair<Float, Photon *> > &pq)
{
  assert(first <= (last+1));
  if (first > last) { return true; } // empty range

  int median = (first+last)/2;
  Photon* medianPhoton = photonTree[median];

  // interior node: search the children before testing the median
  if (first < last)
  {
    Float signdist = 0.;

    switch (medianPhoton->axis)
    {
      case Photon::XAXIS:
        signdist = point.x - medianPhoton->ray.O.x;
        break;
      case Photon::YAXIS:
        signdist = point.y - medianPhoton->ray.O.y;
        break;
      case Photon::ZAXIS:
        signdist = point.z - medianPhoton->ray.O.z;
        break;
      default:
        cerr << "Invalid Axis!!" << endl;
        break;
    }

    // The far child can only contain photons in range if the splitting
    // plane is closer than the search radius. The test is made after the
    // near child is searched, in case maxDistance2 is ever shrunk during
    // the search (see the disabled code below).
    if (signdist < 0.) // point lies on the lower side of the plane
    {
      locatePhoton(first,median-1,point,pq);
      if ((signdist*signdist) < maxDistance2)
      { locatePhoton(median+1,last,point,pq); }
    }
    else  // point lies on the upper side of (or on) the plane
    {
      locatePhoton(median+1,last,point,pq);
      if ((signdist*signdist) < maxDistance2)
      { locatePhoton(first,median-1,point,pq); }
    }
  }

  Float dist2 = DistanceSquared(medianPhoton->ray.O,point);

  if (dist2 < maxDistance2)
  {
    assert (is_heap(pq.begin(),pq.end(),comparePhotonDistance)
      == true);
    pq.push_back(make_pair(dist2,medianPhoton));
    push_heap(pq.begin(),pq.end(),comparePhotonDistance);

#if 0
    // only update the maxDistance if you have at least the
    // required number of photons (reads the SORTED copy, not the heap)
    if (pq.size() >= (unsigned int) numberOfPhotonsWanted)
    {
      vector< pair<Float, Photon *> > sortedPQ(pq);
      sort_heap(sortedPQ.begin(),sortedPQ.end(),
        comparePhotonDistance);
      maxDistance2 = sortedPQ[numberOfPhotonsWanted-1].first;
    }
#endif
    return true;
  }

  return false;
} // locatePhoton

