#include "heightfield.h"
#include <assert.h>
#include <string.h>
#include <vector>

// Builds a heightfield over an x-by-y grid of height samples zs (copied),
// stored row-major (x index fastest). In object space the grid spans the
// unit square in (x, y), and each sample's height becomes its z coordinate.
// Per-vertex shading normals are computed from finite-difference tangents,
// and the triangles are organized into an HfTree for intersection.
Heightfield::Heightfield(const Transform &o2w, int x, int y,
		float *zs)
	: Shape(o2w) {
  // At least two samples per axis are required. Vertex positions divide by
  // (nx-1) and (ny-1), and every tangent needs a neighbouring sample.
  assert(x >= 2 && y >= 2);
  nx = x;
  ny = y;
  const int nverts = nx*ny;
  z = new float[nverts];
  memcpy(z, zs, nverts*sizeof(float));

  // Object-space vertex positions: x and y uniformly spread over [0,1], z is
  // the height sample. Same row-major layout as zs. (Adapted from Refine.)
  // Loop indices are named ix/iy: redeclaring the parameters x/y in the
  // outermost block of the body is ill-formed C++.
  std::vector<Point> P(nverts);
  for (int iy = 0; iy < ny; ++iy) {
    for (int ix = 0; ix < nx; ++ix) {
      Point &p = P[ix + iy*nx];
      p.x = (float)ix / (float)(nx-1);
      p.y = (float)iy / (float)(ny-1);
      p.z = z[ix + iy*nx];
    }
  }

  // Normal at each grid point = normalized cross product of the unit x and y
  // tangents. Tangents use a central difference in the interior and a
  // one-sided difference on the border (neighbour index clamped to the grid).
  std::vector<Normal> N(nverts);
  for (int iy = 0; iy < ny; ++iy) {
    const int yPrev = (iy > 0) ? iy - 1 : iy;
    const int yNext = (iy < ny - 1) ? iy + 1 : iy;
    for (int ix = 0; ix < nx; ++ix) {
      const int xPrev = (ix > 0) ? ix - 1 : ix;
      const int xNext = (ix < nx - 1) ? ix + 1 : ix;
      Vector Tx = (P[xNext + iy*nx] - P[xPrev + iy*nx]).Hat();
      Vector Ty = (P[ix + yNext*nx] - P[ix + yPrev*nx]).Hat();
      N[ix + iy*nx] = Normal(Cross(Tx, Ty)).Hat();
    }
  }

  // HfTree only reads P and N while building. It hands individual elements to
  // HfTriangle and keeps no pointer to the arrays, so they are released when
  // this constructor returns (previously they leaked).
  // NOTE(review): assumes HfTriangle stores copies of its vertices/normals
  // (it keeps P0 and edge vectors E1/E2); confirm in heightfield.h.
  Tree = new HfTree(&P[0], &N[0], 0, nx - 1, 0, ny - 1, nx, ny);
}

// Releases the intersection tree and the copied height samples.
Heightfield::~Heightfield() {
  delete Tree;
  delete[] z;
}

bool Heightfield::Intersect(const Ray &ray, DifferentialGeometry *dg) const {
  Ray r = WorldToObject(ray);
  bool found = Tree->Intersect(r, dg);
  if (found) {  // could probably do these in all cases
    ray.maxt = r.maxt;
    *dg = ObjectToWorld(*dg);
  }
  return found;
}

bool Heightfield::IntersectP(const Ray &ray) const {
  Ray r = WorldToObject(ray);
  return Tree->IntersectP(r);
}

// Object-space bounding box: the unit square in (x, y), extended in z from
// the lowest to the highest height sample.
BBox Heightfield::Bound() const {
	const int count = nx*ny;
	float lo = z[0];
	float hi = z[0];
	for (const float *s = z + 1; s < z + count; ++s) {
		lo = (*s < lo) ? *s : lo;
		hi = (*s > hi) ? *s : hi;
	}
	return BBox(Point(0, 0, lo), Point(1, 1, hi));
}

/*

void Heightfield::Refine(vector<Shape *> &refined) const {
	int ntris = 2*(nx-1)*(ny-1);
	int *verts = new int[3*ntris];
	Point *P = new Point[nx*ny];
	int x, y;
	P = new Point[nx*ny];
	int pos = 0;
	for (y = 0; y < ny; ++y) {
		for (x = 0; x < nx; ++x) {
			P[pos].x = (float)x / (float)(nx-1);
			P[pos].y = (float)y / (float)(ny-1);
			P[pos].z = z[pos];
			++pos;
		}
	}
	int *vp = verts;
	for (y = 0; y < ny-1; ++y) {
		for (x = 0; x < nx-1; ++x) {
	#define VERT(x,y) ((x)+(y)*nx)
			*vp++ = VERT(x, y);
			*vp++ = VERT(x+1, y);
			*vp++ = VERT(x+1, y+1);
	
			*vp++ = VERT(x, y);
			*vp++ = VERT(x+1, y+1);
			*vp++ = VERT(x, y+1);
		}
	#undef VERT
	}
	refined.push_back(new TriangleMesh(ObjectToWorld, ntris,
		nx*ny, verts, P));
	delete[] P;
	delete[] verts;
}

*/

// Split-plane normal for a split at constant x.
const Normal HfTree::vertSplitN = Normal(1, 0, 0);

// Split-plane normal for a split at constant y.
const Normal HfTree::horSplitN = Normal(0, 1, 0);

// Unit normal of the plane x == y.
const Normal HfTree::diagN(Normal(-1, 1, 0).Hat());

// Builds one node of a BSP over the grid-index rectangle [xi0,xi1] x [yi0,yi1]
// (inclusive). P and N are the full nx-by-ny vertex and normal arrays, indexed
// as P[xi + yi*nx]; they are only read here.
//
// Interior nodes split the longer index range at its midpoint with an
// axis-aligned plane. The split row/column belongs to both children. A node
// covering exactly one grid cell becomes a leaf holding that cell's two
// triangles, separated by the diagonal from (xi0,yi0) to (xi1,yi1).
HfTree::HfTree(Point *P, Normal *N,
	       int xi0, int xi1, int yi0, int yi1, int nx, int ny) {
  assert(xi0 >= 0 && yi0 >= 0 && xi1 < nx && yi1 < ny
	 && xi0 + 1 <= xi1 && yi0 + 1 <= yi1);

  // Bounding box: x/y from the corner samples, z from the extreme heights of
  // every sample in the rectangle. Each row is scanned in storage order
  // (xi fastest) so memory is read contiguously.
  float minz = P[xi0 + yi0*nx].z;
  float maxz = minz;
  for (int yi = yi0; yi <= yi1; ++yi) {
    const Point *row = P + yi*nx;
    for (int xi = xi0; xi <= xi1; ++xi) {
      const float z = row[xi].z;
      if (z > maxz) maxz = z;
      if (z < minz) minz = z;
    }
  }
  Point p0 = P[xi0 + yi0*nx];
  Point p1 = P[xi1 + yi1*nx];
  box = new BBox(Point(p0.x, p0.y, minz), Point(p1.x, p1.y, maxz));

  if (xi0 + 1 == xi1 && yi0 + 1 == yi1) {  // single grid square
    // Split plane through the cell's actual diagonal p0 -> p1, with in-plane
    // normal (-dy, dx, 0). The fixed diagN = (-1, 1, 0) is perpendicular to
    // that diagonal only when the cell is square (dx == dy). For nx != ny it
    // does not separate the two triangles, and Intersect can then test (and
    // return) the farther triangle first.
    splitNormal = Normal(-(p1.y - p0.y), p1.x - p0.x, 0).Hat();
    // Assuming x and y grow with the grid indices, this triangle
    // ((xi0,yi1), (xi1,yi1), (xi0,yi0)) lies on the +splitNormal side...
    aboveSplitPlane = new
      HfTriangle(P[xi0 + yi1*nx], P[xi1 + yi1*nx], P[xi0 + yi0*nx],
		 N[xi0 + yi1*nx], N[xi1 + yi1*nx], N[xi0 + yi0*nx]);
    // ...and this one ((xi0,yi0), (xi1,yi1), (xi1,yi0)) on the - side.
    belowSplitPlane = new
      HfTriangle(P[xi0 + yi0*nx], P[xi1 + yi1*nx], P[xi1 + yi0*nx],
		 N[xi0 + yi0*nx], N[xi1 + yi1*nx], N[xi1 + yi0*nx]);
    return;
  }

  if (xi1 - xi0 > yi1 - yi0) {  // split the longer index range
    splitNormal = vertSplitN;
    int xiSplit = (xi1 - xi0)/2 + xi0;

    belowSplitPlane = new HfTree(P, N,
				 xi0, xiSplit, yi0, yi1, nx, ny);

    aboveSplitPlane = new HfTree(P, N,
				 xiSplit, xi1, yi0, yi1, nx, ny);
  } else {
    splitNormal = horSplitN;
    int yiSplit = (yi1 - yi0)/2 + yi0;

    belowSplitPlane = new HfTree(P, N,
				 xi0, xi1, yi0, yiSplit, nx, ny);

    aboveSplitPlane = new HfTree(P, N,
				 xi0, xi1, yiSplit, yi1, nx, ny);
  }
}

// Recursively frees both children, then this node's bounding box.
HfTree::~HfTree() {
  delete aboveSplitPlane;
  delete belowSplitPlane;
  delete box;
}

bool HfTree::Intersect(const Ray &ray, DifferentialGeometry *dg) const {
  Ray dupRay(ray.O, ray.D, ray.mint, ray.maxt);
  if (!box->IntersectP(dupRay)) return false;
  if (Dot(splitNormal, ray.D) > 0) {
    if (belowSplitPlane->Intersect(ray, dg)) return true;
    return aboveSplitPlane->Intersect(ray, dg);
  } else {
    if (aboveSplitPlane->Intersect(ray, dg)) return true;
    return belowSplitPlane->Intersect(ray, dg);
  }
}

bool HfTree::IntersectP(const Ray &ray) const {
  Ray dupRay(ray.O, ray.D, ray.mint, ray.maxt);
  if (!box->IntersectP(dupRay)) return false;
  if (belowSplitPlane->IntersectP(ray)) return true;
  return aboveSplitPlane->IntersectP(ray);
}

// adapted from Triangle::Intersect
bool HfTriangle::Intersect(const Ray &ray, DifferentialGeometry *dg) const {
  Vector S_1 = Cross( ray.D, E2 );
  Float divisor = Dot( S_1, E1 );
  if (divisor == 0.) {
    return false;
  }
  Float invDivisor = 1. / divisor;
  Vector T = ray.O - P0;
  Float u = Dot( T, S_1 ) * invDivisor;
  if (u < 0. || u > 1.0) {
    return false;
  }
  Vector S_2 = Cross( T, E1 );
  Float v = Dot( ray.D, S_2 ) * invDivisor;
  if (v < 0 || u + v > 1.0) {
    return false;
  }
  Float t = Dot( E2, S_2 ) * invDivisor;
  if (t < ray.mint || t > ray.maxt) {
    return false;
  }
  Normal N = Normal(u*N1 + v*N2 + (1-u-v)*N0).Hat();
  Vector B1 = Cross(Vector(N), Vector(1,0,0)).Hat();
  Point hit = ray(t);
  *dg = DifferentialGeometry(hit, N.Hat(), B1,
			     Cross(Vector(N), B1).Hat(), hit.x, hit.y);
  /* Vector B2 = Cross(Vector(N), B1).Hat();
  if (RandomFloat() > 0.999) printf("\nN = (%.4f, %.4f, %.4f)\n"
				    "S = (%.4f, %.4f, %.4f)\n"
				    "T = (%.4f, %.4f, %.4f)\n",
				    N.Hat().x, N.Hat().y, N.Hat().z,
				    B1.x, B1.y, B1.z,
				    B2.x, B2.y, B2.z);
  */
  ray.maxt = t;
  return true;
}

// adapted from the above in turn
bool HfTriangle::IntersectP(const Ray &ray) const {
  Vector S_1 = Cross( ray.D, E2 );
  Float divisor = Dot( S_1, E1 );
  if (divisor == 0.)
    return false;
  Float invDivisor = 1. / divisor;
  Vector T = ray.O - P0;
  Float u = Dot( T, S_1 ) * invDivisor;
  if (u < 0. || u > 1.0)
    return false;
  Vector S_2 = Cross( T, E1 );
  Float v = Dot( ray.D, S_2 ) * invDivisor;
  if (v < 0 || u + v > 1.0)
    return false;
  Float t = Dot( E2, S_2 ) * invDivisor;
  if (t < ray.mint || t > ray.maxt)
    return false;
  return true;
}
