#pragma once

#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>

#include "VecMath.h"
#include "Core.h"

/// Symmetric 4x4 matrix stored as its 10 upper-triangular coefficients,
/// row by row:
///
///   | m[0] m[1] m[2] m[3] |
///   | m[1] m[4] m[5] m[6] |
///   | m[2] m[5] m[7] m[8] |
///   | m[3] m[6] m[8] m[9] |
///
/// Used as an error quadric for edge-collapse mesh simplification.
class SymetricMatrix {

public:

  // Constructor

  /// Set all 10 coefficients to `c` (default: the zero matrix).
  SymetricMatrix(double c = 0)
  {
    for (int i = 0; i < 10; ++i) m[i] = c;
  }

  /// Build from the upper-triangular coefficients m<row><col>, row by row.
  SymetricMatrix(double m11, double m12, double m13, double m14,
    double m22, double m23, double m24,
    double m33, double m34,
    double m44)
  {
    m[0] = m11;  m[1] = m12;  m[2] = m13;  m[3] = m14;
    m[4] = m22;  m[5] = m23;  m[6] = m24;
    m[7] = m33;  m[8] = m34;
    m[9] = m44;
  }

  // Make plane

  /// Quadric of the plane a*x + b*y + c*z + d = 0: the outer product
  /// p * p^T of p = (a, b, c, d).
  SymetricMatrix(double a, double b, double c, double d)
  {
    m[0] = a * a;  m[1] = a * b;  m[2] = a * c;  m[3] = a * d;
    m[4] = b * b;  m[5] = b * c;  m[6] = b * d;
    m[7] = c * c;  m[8] = c * d;
    m[9] = d * d;
  }

  /// Read-only access to coefficient c (0..9, layout above). Not bounds-checked.
  double operator[](int c) const { return m[c]; }

  // Determinant

  /// Determinant of the 3x3 matrix whose entries (row-major) are the
  /// coefficients m[a11] ... m[a33]; full cofactor (Sarrus) expansion.
  /// Only reads the coefficients, so it is callable on const matrices.
  double det(int a11, int a12, int a13,
    int a21, int a22, int a23,
    int a31, int a32, int a33) const
  {
    double det = m[a11] * m[a22] * m[a33] + m[a13] * m[a21] * m[a32] + m[a12] * m[a23] * m[a31]
      - m[a13] * m[a22] * m[a31] - m[a11] * m[a23] * m[a32] - m[a12] * m[a21] * m[a33];
    return det;
  }

  /// Coefficient-wise sum. Returned as a plain (non-const) value so the
  /// result can be moved from.
  SymetricMatrix operator+(const SymetricMatrix& n) const
  {
    SymetricMatrix sum(*this);
    sum += n;
    return sum;
  }

  /// Coefficient-wise in-place sum.
  SymetricMatrix& operator+=(const SymetricMatrix& n)
  {
    m[0] += n[0];   m[1] += n[1];   m[2] += n[2];   m[3] += n[3];
    m[4] += n[4];   m[5] += n[5];   m[6] += n[6];   m[7] += n[7];
    m[8] += n[8];   m[9] += n[9];
    return *this;
  }

  double m[10];
};
///////////////////////////////////////////

namespace Simplify
{
  // Variables & Structures

  /// Mesh face: three vertex indices plus per-face data cached for the
  /// simplifier.
  struct Triangle {
    Triangle() {}

    int v[3]{ 0, 0, 0 };          // vertex indices into MeshData::vertices
    double err[4]{ 0, 0, 0, 0 };  // edge errors (v0-v1, v1-v2, v2-v0); err[3] is their minimum
    int deleted{ 0 };             // non-zero once the face is removed by a collapse
    int dirty{ 0 };               // non-zero if modified by a collapse in the current iteration
    vec3d n;                      // face normal (set by update_mesh on iteration 0)
  };

  struct Vertex {
    Vertex() :
      p(vec3d()),
      tstart(0),
      tcount(0),
      border(0)
    {
    }

    Vertex(const agx::Vec3& v) :
      p(v),
      tstart(0),
      tcount(0),
      border(0)
    {
    }

    vec3d p; int tstart, tcount; SymetricMatrix q; int border;
  };

  /// Back-reference from a vertex to one triangle corner that uses it.
  struct Ref {
    Ref() {}

    int tid{ 0 };      // index into MeshData::triangles
    int tvertex{ 0 };  // which corner (0..2) of that triangle holds the vertex
  };

  // Working state of a simplification run. simplify_mesh() reads and
  // rewrites all three containers in place.
  struct MeshData {
    // Faces. Collapsed faces are only flagged `deleted` until update_mesh()
    // or compact_mesh() erases them.
    std::vector<Triangle> triangles;
    // Vertices (positions, quadrics, ref ranges, border flags).
    std::vector<Vertex> vertices;
    // Vertex -> triangle-corner back-references; vertex i owns the range
    // [vertices[i].tstart, vertices[i].tstart + vertices[i].tcount).
    std::vector<Ref> refs;
  };


  // Helper functions (forward declarations; definitions follow below)

  // Evaluates the quadric q at the point (x, y, z).
  double vertex_error(const SymetricMatrix& q, double x, double y, double z);
  // Error of collapsing edge (id_v1, id_v2); writes the chosen target position to p_result.
  double calculate_error(MeshData& meshData, int id_v1, int id_v2, vec3d& p_result);
  // True if moving v0 to p would flip or degenerate one of its triangles; fills per-ref flags in `deleted`.
  bool flipped(MeshData& meshData, vec3d p, int i0, int i1, Vertex& v0, Vertex& v1, std::vector<int>& deleted);
  // Re-points v's surviving triangles at vertex i0, deletes the flagged ones and appends new refs.
  void update_triangles(MeshData& meshData, int i0, Vertex& v, std::vector<int>& deleted, int& deleted_triangles);
  // Compacts triangles and rebuilds refs; on iteration 0 also sets borders, quadrics and edge errors.
  void update_mesh(MeshData& meshData, int iteration);
  // Drops deleted triangles and unreferenced vertices, renumbering vertex indices.
  void compact_mesh(MeshData& meshData);
  //
  // Main simplification function
  //
  // Repeatedly collapses low-error edges until at most target_count live
  // triangles remain (or 100 iterations have run), then compacts the mesh.
  //
  // target_count  : target number of triangles
  // agressiveness : exponent of the per-iteration error threshold
  //                 1e-9 * (iteration + 3)^agressiveness; 5..8 are good
  //                 numbers. Smaller values raise the threshold more slowly,
  //                 so more iterations are needed (upstream note: more
  //                 iterations yield higher quality).
  //
  // Defined `inline` because this header may be included by several
  // translation units.
  //
  inline void simplify_mesh(MeshData& meshData, int target_count, double agressiveness = 7)
  {
    // init: clear deletion flags left over from any previous run
    loopi(0, (int)meshData.triangles.size()) meshData.triangles[i].deleted = 0;

    // main iteration loop

    int deleted_triangles = 0;
    // Scratch buffers: one "this triangle disappears" flag per entry of the
    // vertex's ref range, filled by flipped() and consumed by update_triangles().
    std::vector<int> deleted0, deleted1;
    // Count at entry. deleted_triangles is counted relative to it, so the
    // difference stays valid across the compaction done in update_mesh().
    const int triangle_count = (int)meshData.triangles.size();
    // One signed definition of "done", shared by both break checks below
    // (previously one of them cast target_count to size_t).
    auto target_reached = [&]() { return triangle_count - deleted_triangles <= target_count; };

    loop(iteration, 0, 100)
    {
      // target number of triangles reached ? Then break
      if (target_reached()) break;

      // update mesh once in a while
      if (iteration % 5 == 0)
      {
        update_mesh(meshData, iteration);
      }

      // clear dirty flag
      loopi(0, (int)meshData.triangles.size()) meshData.triangles[i].dirty = 0;

      //
      // All triangles with edges below the threshold will be removed
      //
      // The following numbers work well for most models.
      // If they do not, try to adjust the 3 parameters
      // (scale 1e-9, offset 3, agressiveness).
      //
      const double threshold = 0.000000001 * pow(double(iteration + 3), agressiveness);

      // remove vertices & mark deleted triangles
      loopi(0, (int)meshData.triangles.size())
      {
        Triangle& t = meshData.triangles[i];
        if (t.err[3] > threshold) continue;
        if (t.deleted) continue;
        // already modified by a collapse during this iteration
        if (t.dirty) continue;

        loopj(0, 3) if (t.err[j] < threshold)
        {
          const int i0 = t.v[j]; Vertex& v0 = meshData.vertices[i0];
          const int i1 = t.v[(j + 1) % 3]; Vertex& v1 = meshData.vertices[i1];

          // Border check: only collapse edges whose endpoints are both
          // border vertices or both interior vertices
          if (v0.border != v1.border) continue;

          // Compute vertex to collapse to
          vec3d p;
          calculate_error(meshData, i0, i1, p);

          // one flag per ref of each endpoint (see flipped())
          deleted0.resize(v0.tcount);
          deleted1.resize(v1.tcount);

          // don't remove if flipped
          if (flipped(meshData, p, i0, i1, v0, v1, deleted0)) continue;
          if (flipped(meshData, p, i1, i0, v1, v0, deleted1)) continue;

          // not flipped, so remove edge: v1 is merged into v0
          v0.p = p;
          v0.q = v1.q + v0.q;
          const size_t tstart = meshData.refs.size();

          // Both calls re-point surviving triangles at i0 and append their
          // refs to the end of meshData.refs.
          update_triangles(meshData, i0, v0, deleted0, deleted_triangles);
          update_triangles(meshData, i0, v1, deleted1, deleted_triangles);

          const size_t tcount = meshData.refs.size() - tstart;

          if ((int)tcount <= v0.tcount)
          {
            // save ram: the new refs fit into v0's old range, move them back.
            // The ranges do not overlap (v0's range ends at or before tstart).
            if (tcount)
              std::copy_n(meshData.refs.data() + tstart, tcount, meshData.refs.data() + v0.tstart);
          }
          else
          {
            // append: keep the freshly written range
            v0.tstart = (int)tstart;
          }

          v0.tcount = (int)tcount;
          break;
        }
        // done?
        if (target_reached()) break;
      }
    }

    // clean up mesh
    compact_mesh(meshData);

    // ready
  }

  // Check if a triangle flips when this edge is removed.
  //
  // Simulates moving v0 to p (collapsing edge i0-i1) and inspects every live
  // triangle in v0's ref range; k indexes that range and `deleted`:
  //  - a triangle that also contains i1 vanishes with the edge: deleted[k] = 1;
  //  - any other triangle gets deleted[k] = 0, and the collapse is rejected
  //    (returns true) if its two edges at p become nearly collinear
  //    (|cos| > 0.999) or its new normal deviates from the stored t.n
  //    (dot < 0.2).
  // `deleted` must hold at least v0.tcount entries. Entries of already-deleted
  // triangles, and those not reached after an early `true`, are left as-is.
  // i0 and v1 are unused; they keep the signature symmetric for the caller.
  inline bool flipped(MeshData& meshData, vec3d p, int /*i0*/, int i1, Vertex& v0, Vertex& /*v1*/, std::vector<int>& deleted)
  {
    loopk(0, v0.tcount)
    {
      const Ref& ref = meshData.refs[v0.tstart + k];
      Triangle& t = meshData.triangles[ref.tid];
      if (t.deleted) continue;

      const int s = ref.tvertex;
      const int id1 = t.v[(s + 1) % 3];
      const int id2 = t.v[(s + 2) % 3];

      if (id1 == i1 || id2 == i1) // delete ?
      {
        deleted[k] = 1;
        continue;
      }
      vec3d d1 = meshData.vertices[id1].p - p; d1.normalize();
      vec3d d2 = meshData.vertices[id2].p - p; d2.normalize();
      // nearly collinear edges: the triangle would degenerate
      if (fabs(d1.dot(d2)) > 0.999) return true;
      vec3d n;
      n.cross(d1, d2);
      n.normalize();
      deleted[k] = 0;
      // new normal too far from the original one: the triangle would flip
      if (n.dot(t.n) < 0.2) return true;
    }
    return false;
  }

  // Update triangle connections and edge error after an edge is collapsed.
  //
  // For each live triangle in v's ref range (k indexes that range and
  // `deleted`):
  //  - deleted[k] set: mark the triangle deleted and increment deleted_triangles;
  //  - otherwise: point the corner that referenced v at vertex i0, mark the
  //    triangle dirty, recompute its three edge errors (and their minimum),
  //    and append its ref to the end of meshData.refs.
  inline void update_triangles(MeshData& meshData, int i0, Vertex& v, std::vector<int>& deleted, int& deleted_triangles)
  {
    vec3d p; // collapse position written by calculate_error(); not needed here
    loopk(0, v.tcount)
    {
      // Copy instead of referencing: meshData.refs grows below, and
      // push_back may reallocate the storage a reference would point into.
      const Ref r = meshData.refs[v.tstart + k];
      Triangle& t = meshData.triangles[r.tid];
      if (t.deleted) continue;
      if (deleted[k])
      {
        t.deleted = 1;
        deleted_triangles++;
        continue;
      }
      t.v[r.tvertex] = i0;
      t.dirty = 1;
      t.err[0] = calculate_error(meshData, t.v[0], t.v[1], p);
      t.err[1] = calculate_error(meshData, t.v[1], t.v[2], p);
      t.err[2] = calculate_error(meshData, t.v[2], t.v[0], p);
      t.err[3] = std::min(t.err[0], std::min(t.err[1], t.err[2]));
      meshData.refs.push_back(r);
    }
  }

  // compact triangles, compute edge error and build reference list
  //
  //  - iteration > 0 : drop triangles flagged as deleted;
  //  - always        : rebuild meshData.refs so that vertex i's triangle
  //                    corners are refs[tstart, tstart + tcount);
  //  - iteration == 0: identify border vertices, then initialise vertex
  //                    quadrics, triangle normals and edge errors.
  inline void update_mesh(MeshData& meshData, int iteration)
  {
    if (iteration > 0) // compact triangles
    {
      int dst = 0;
      loopi(0, (int)meshData.triangles.size())
        if (!meshData.triangles[i].deleted)
        {
          meshData.triangles[dst++] = meshData.triangles[i];
        }
      meshData.triangles.resize(dst);
    }

    // Init Reference ID list: count the triangle corners using each vertex...
    loopi(0, (int)meshData.vertices.size())
    {
      meshData.vertices[i].tstart = 0;
      meshData.vertices[i].tcount = 0;
    }
    loopi(0, (int)meshData.triangles.size())
    {
      Triangle& t = meshData.triangles[i];
      loopj(0, 3) meshData.vertices[t.v[j]].tcount++;
    }
    // ...then turn the counts into start offsets (running sum). tcount is
    // reset and counted up again while the references are written.
    int tstart = 0;
    loopi(0, (int)meshData.vertices.size())
    {
      Vertex& v = meshData.vertices[i];
      v.tstart = tstart;
      tstart += v.tcount;
      v.tcount = 0;
    }

    // Write References
    meshData.refs.resize(meshData.triangles.size() * 3);
    loopi(0, (int)meshData.triangles.size())
    {
      Triangle& t = meshData.triangles[i];
      loopj(0, 3)
      {
        Vertex& v = meshData.vertices[t.v[j]];
        meshData.refs[v.tstart + v.tcount].tid = i;
        meshData.refs[v.tstart + v.tcount].tvertex = j;
        v.tcount++;
      }
    }

    if (iteration == 0)
    {
      // Identify boundary : vertices[].border=0,1
      //
      // This must run BEFORE the edge errors below. calculate_error() reads
      // the border flags, and computing errors first would use default (or
      // stale, from a previous run) flags.
      //
      // For each vertex v, count how often every vertex id occurs among v's
      // triangles. An id seen exactly once shares a single triangle with v
      // (an open edge), so that vertex is marked as border.
      std::vector<int> vcount, vids;

      loopi(0, (int)meshData.vertices.size())
        meshData.vertices[i].border = 0;

      loopi(0, (int)meshData.vertices.size())
      {
        Vertex& v = meshData.vertices[i];
        vcount.clear();
        vids.clear();
        loopj(0, v.tcount)
        {
          int kx = meshData.refs[v.tstart + j].tid;
          Triangle& t = meshData.triangles[kx];
          loopk(0, 3)
          {
            size_t ofs = 0, id = t.v[k];
            while (ofs < vcount.size())
            {
              if (vids[ofs] == (int)id)break;
              ofs++;
            }
            if (ofs == vcount.size())
            {
              vcount.push_back(1);
              vids.push_back((int)id);
            }
            else
              vcount[ofs]++;
          }
        }
        loopj(0, (int)vcount.size()) if (vcount[j] == 1)
          meshData.vertices[vids[j]].border = 1;
      }

      //
      // Init Quadrics by Plane & Edge Errors
      //
      // required at the beginning ( iteration == 0 )
      // recomputing during the simplification is not required,
      // but mostly improves the result for closed meshes
      //
      loopi(0, (int)meshData.vertices.size())
        meshData.vertices[i].q = SymetricMatrix(0.0);

      loopi(0, (int)meshData.triangles.size())
      {
        Triangle& t = meshData.triangles[i];
        vec3d n, p[3];
        loopj(0, 3) p[j] = meshData.vertices[t.v[j]].p;
        n.cross(p[1] - p[0], p[2] - p[0]);
        n.normalize();
        t.n = n;
        // each corner accumulates the quadric of the triangle's plane
        loopj(0, 3) meshData.vertices[t.v[j]].q =
          meshData.vertices[t.v[j]].q + SymetricMatrix(n.x, n.y, n.z, -n.dot(p[0]));
      }
      loopi(0, (int)meshData.triangles.size())
      {
        // Calc Edge Error (border flags are valid at this point)
        Triangle& t = meshData.triangles[i]; vec3d p;
        loopj(0, 3) t.err[j] = calculate_error(meshData, t.v[j], t.v[(j + 1) % 3], p);
        t.err[3] = std::min(t.err[0], std::min(t.err[1], t.err[2]));
      }
    }
  }

  // Finally compact mesh before exiting.
  //
  // Removes deleted triangles and every vertex no longer used by a remaining
  // triangle, then renumbers the triangles' vertex indices. Only the position
  // `p` of each kept vertex is moved into its new slot; the other fields of
  // the packed vertices (q, border, tstart, tcount) are left as they were in
  // the destination slot.
  inline void compact_mesh(MeshData& meshData)
  {
    // Pass 1: keep live triangles; reuse tcount as an "is referenced" flag.
    int dst = 0;
    loopi(0, (int)meshData.vertices.size())
    {
      meshData.vertices[i].tcount = 0;
    }
    loopi(0, (int)meshData.triangles.size())
      if (!meshData.triangles[i].deleted)
      {
        Triangle& t = meshData.triangles[i];
        meshData.triangles[dst++] = t;
        loopj(0, 3)meshData.vertices[t.v[j]].tcount = 1;
      }
    meshData.triangles.resize(dst);

    // Pass 2: pack referenced vertices; reuse tstart as the old->new index
    // map. Only p is copied, so tstart of an already-visited slot survives.
    dst = 0;
    loopi(0, (int)meshData.vertices.size())
      if (meshData.vertices[i].tcount)
      {
        meshData.vertices[i].tstart = dst;
        meshData.vertices[dst].p = meshData.vertices[i].p;
        dst++;
      }

    // Pass 3: remap triangle corners through the map built in pass 2.
    loopi(0, (int)meshData.triangles.size())
    {
      Triangle& t = meshData.triangles[i];
      loopj(0, 3)t.v[j] = meshData.vertices[t.v[j]].tstart;
    }
    meshData.vertices.resize(dst);
  }

  // Error between vertex and Quadric.
  //
  // Evaluates v^T * q * v for the homogeneous point v = (x, y, z, 1), using
  // the 10 stored upper-triangular coefficients of q; off-diagonal terms
  // appear twice, hence the factors of 2.
  inline double vertex_error(const SymetricMatrix& q, double x, double y, double z)
  {
    return   q[0] * x * x + 2 * q[1] * x * y + 2 * q[2] * x * z + 2 * q[3] * x
      + q[4] * y * y + 2 * q[5] * y * z + 2 * q[6] * y
      + q[7] * z * z + 2 * q[8] * z
      + q[9];
  }

  // Error for one edge.
  //
  // Returns the quadric error of collapsing edge (id_v1, id_v2) and writes the
  // position to collapse to into p_result. The edge quadric q is the sum of
  // both vertex quadrics.
  //  - If the upper-left 3x3 block of q has a non-zero determinant and the two
  //    vertices are not both border vertices, p_result is the solution of that
  //    3x3 linear system (Cramer's rule), i.e. the optimal position.
  //  - Otherwise the best of the two endpoints and their midpoint is used. On
  //    ties the later candidate wins: midpoint, then id_v2, then id_v1.
  inline double calculate_error(MeshData& meshData, int id_v1, int id_v2, vec3d& p_result)
  {
    // compute interpolated vertex

    SymetricMatrix q = meshData.vertices[id_v1].q + meshData.vertices[id_v2].q;
    bool   border = (meshData.vertices[id_v1].border & meshData.vertices[id_v2].border) != 0;
    double error = 0;
    double det = q.det(0, 1, 2, 1, 4, 5, 2, 5, 7);

    // NOTE: exact comparison; a nearly singular system still takes this branch.
    if (det != 0 && !border)
    {
      // q_delta is invertible
      p_result.x = -1 / det * (q.det(1, 2, 3, 4, 5, 6, 5, 7, 8));  // vx = A41/det(q_delta)
      p_result.y = 1 / det * (q.det(0, 2, 3, 1, 5, 6, 2, 7, 8));   // vy = A42/det(q_delta)
      p_result.z = -1 / det * (q.det(0, 1, 3, 1, 4, 6, 2, 5, 8));  // vz = A43/det(q_delta)
      error = vertex_error(q, p_result.x, p_result.y, p_result.z);
    }
    else
    {
      // det = 0 (or border edge) -> pick the best of endpoints and midpoint
      vec3d p1 = meshData.vertices[id_v1].p;
      vec3d p2 = meshData.vertices[id_v2].p;
      vec3d p3 = (p1 + p2) / 2;
      double error1 = vertex_error(q, p1.x, p1.y, p1.z);
      double error2 = vertex_error(q, p2.x, p2.y, p2.z);
      double error3 = vertex_error(q, p3.x, p3.y, p3.z);
      error = std::min(error1, std::min(error2, error3));
      if (error1 == error) p_result = p1;
      if (error2 == error) p_result = p2;
      if (error3 == error) p_result = p3;
    }
    return error;
  }
}
///////////////////////////////////////////
