using AGXUnity;
using AGXUnity.Utils;
using System.Collections.Generic;
using System.Linq;
using Unity.Burst;
using Unity.Collections;
using Unity.Jobs;
using UnityEngine;

namespace PolyBag
{
  /// <summary>
  /// Burst-compiled job that converts the vertices of the simulated (AGX) bag mesh to Unity's
  /// coordinate convention via <c>ToHandedVector3</c> so they can be uploaded to the render mesh.
  /// </summary>
  [BurstCompile]
  struct PolyBagUpdateJob : IJob
  {
    /// <summary>Source vertices in AGX convention; only read by the job.</summary>
    [ReadOnly]
    public NativeArray<agx.Vec3> vertices;

    /// <summary>
    /// Destination vertices in Unity convention; only written by the job. Expected to have the
    /// same length as <see cref="vertices"/>.
    /// </summary>
    [WriteOnly]
    public NativeArray<Vector3> unityVertices;

    public void Execute()
    {
      for ( int i = 0; i < vertices.Length; i++ ) {
        var v = vertices[ i ];
        unityVertices[ i ] = v.ToHandedVector3();
      }
    }
  }

  public class PolyBag : ScriptComponent
  {
    /// <summary>Unity material assigned to the MeshRenderer that displays the bag.</summary>
    public Material RenderMaterial;

    /// <summary>
    /// AGXUnity shape material used for the bag's collision geometries when the bag is created.
    /// Setting it only stores the reference; geometries that already exist are not updated.
    /// </summary>
    [AllowRecursiveEditing]
    public ShapeMaterial Material
    {
      get => m_material;
      set => m_material = value;
    }

    /// <summary>
    /// Resolution of the particle grids that make up the bag. CreatePolyBag uses
    /// 3 × (numeric value) particles along the bag's middle axis.
    /// </summary>
    public enum ElementResolution
    {
      LOW = 1,
      MEDIUM = 2,
      HIGH = 3,
    }
    // Inspector toggle for contact reduction; Update() applies changes through SetUseContactReduction.
    // NOTE(review): [SerializeField] is redundant on a public field.
    [SerializeField]
    public bool useContactReduction = false;
    // The toggle value most recently applied by SetUseContactReduction.
    private bool m_useContactReduction = false;


    // Created by InitContactReduction; null until contact reduction is first requested.
    private agxSDK.ContactFilterReducerListener m_contactFilterReducerListener;


    /// <summary>
    /// Applies the contact-reduction toggle: creates the reducer listener on first use and then
    /// enables or disables it.
    /// </summary>
    /// <remarks>
    /// Does nothing until InitParameters has created the polybag assembly, because the listener is
    /// attached to that assembly and filters on the bag's collision group. The applied state is
    /// left unchanged in that case, so Update retries once the bag exists.
    /// </remarks>
    /// <param name="flag">True to enable contact reduction, false to disable it.</param>
    private void SetUseContactReduction( bool flag )
    {
      // InitContactReduction dereferences m_assembly; bail out before the bag exists.
      if ( m_assembly == null )
        return;

      m_useContactReduction = flag;

      if ( m_contactFilterReducerListener == null ) {
        InitContactReduction();
      }
      m_contactFilterReducerListener.setEnable( flag );
    }


    /// <summary>
    /// Collision group ID that CreatePolyBag adds to every geometry of the bag's bodies.
    /// Collisions between two geometries of this group are disabled.
    /// </summary>
    public uint UniqueGroupID => m_uniqueGroupID;

    // MeshFilter that renders the bag; set by InitRenderMeshes.
    private MeshFilter m_polybagMeshFilter;
    // Scaled trimesh deformed by the simulation; UpdateMeshes reads its vertices every step.
    private agxCollide.Trimesh m_bagMesh;
    // Sensor geometry carrying m_bagMesh; UpdateMeshes copies its pose to this transform.
    private agxCollide.Geometry m_trackGeometry;
    // Collision group of all bag geometries (self-collisions disabled in CreatePolyBag).
    private uint m_uniqueGroupID;

    // Unscaled AGX copy of the Unity mesh, built once and used to measure its extents.
    private agxCollide.Trimesh m_investigateMesh = null;
    // Scaled mesh cached by the first CreatePolyBag call; later calls deep-copy it.
    private agxCollide.Trimesh m_scaledMesh = null;

    // Assembly holding all bodies, constraints, geometries and listeners of the bag.
    private agxSDK.Assembly m_assembly;

    /// <summary>
    /// Returns the geometries registered in the polybag assembly.
    /// Requires that InitParameters has created the assembly.
    /// </summary>
    public agxCollide.GeometryRefSetVector GetGeometries() => m_assembly.getGeometries();

    // Reusable buffer (one entry per render-mesh vertex) that LateUpdate uploads to the mesh.
    private Vector3[] m_unityRenderVertices = null;

    // Backing field of the Material property; its Native material is applied to the bag geometries.
    [SerializeField]
    private ShapeMaterial m_material = null;


    /// <summary>
    /// Makes sure the GameObject has the MeshFilter and MeshRenderer used to draw the bag.
    /// Runs only once; later calls return immediately.
    /// </summary>
    /// <remarks>
    /// Existing components are reused. Unity refuses to add a second MeshFilter/MeshRenderer
    /// (AddComponent returns null), which previously caused a NullReferenceException.
    /// </remarks>
    private void InitRenderMeshes()
    {
      // Already done
      if ( m_polybagMeshFilter )
        return;

      m_polybagMeshFilter = GetComponent<MeshFilter>();
      if ( m_polybagMeshFilter == null )
        m_polybagMeshFilter = this.gameObject.AddComponent<MeshFilter>();

      MeshRenderer polybagMeshRenderer = GetComponent<MeshRenderer>();
      if ( polybagMeshRenderer == null ) {
        polybagMeshRenderer = this.gameObject.AddComponent<MeshRenderer>();
        polybagMeshRenderer.material = RenderMaterial;
      }
      else if ( RenderMaterial != null ) {
        // Don't wipe a material configured on a pre-existing renderer unless one was given.
        polybagMeshRenderer.material = RenderMaterial;
      }
    }

    /// <summary>
    /// AGXUnity initialization hook. Only prepares the render components; the simulated bag is
    /// created by InitParameters.
    /// </summary>
    /// <returns>Always true.</returns>
    protected override bool Initialize()
    {
      // Render components only; physics is set up in InitParameters.
      InitRenderMeshes();

      return true;
    }

    /// <summary>
    /// Applies changes of the inspector toggle <c>useContactReduction</c>.
    /// </summary>
    void Update()
    {
      bool requested = useContactReduction;
      if ( requested == m_useContactReduction )
        return;

      SetUseContactReduction( requested );
    }

    /// <summary>
    /// Creates the contact filter reducer listener and adds it both to the simulation and to the
    /// polybag assembly. Requires m_assembly and m_uniqueGroupID to be set (CreatePolyBag).
    /// </summary>
    private void InitContactReduction()
    {
      var reducer = new agxSDK.ContactFilterReducerListener();
      m_contactFilterReducerListener = reducer;
      GetSimulation().add( reducer );

      // Filter configured with this bag's collision group
      // (presumably restricting the reduction to the bag's contacts).
      var bagGroupFilter = new agxSDK.CollisionGroupFilter();
      bagGroupFilter.addGroup( m_uniqueGroupID );

      // NOTE(review): the meaning of the leading 6 (likely a reduction resolution/level) should be
      // confirmed against the agxSDK.ContactFilterReducerListener documentation.
      reducer.add( 6, bagGroupFilter, (int)agxSDK.StepEventListener.ActivationMask.PRE_STEP );
      m_assembly.add( reducer );
    }

    /// <summary>
    /// Creates the simulated polybag, places it at this transform's pose, adds it to the
    /// simulation and performs an initial mesh update.
    /// </summary>
    /// <param name="resolution">Particle grid resolution.</param>
    /// <param name="length">Bag size along X (CreatePolyBag's sizeX).</param>
    /// <param name="width">Bag size along Y (CreatePolyBag's sizeY).</param>
    /// <param name="height">Bag size along Z (CreatePolyBag's sizeZ).</param>
    /// <param name="compressibility">Scales the linear compliance of the internal lock joints.</param>
    /// <param name="bendability">Scales the rotational compliance of the internal lock joints.</param>
    /// <param name="mass">Total mass of the bag.</param>
    /// <param name="renderMaterial">Material for the MeshRenderer.</param>
    /// <param name="material">Shape material for the collision geometries.</param>
    /// <param name="polybagMesh">Source mesh of the bag.</param>
    /// <param name="useContactReduction">
    /// Initial contact-reduction state. Also written to the <c>useContactReduction</c> field so the
    /// inspector toggle reflects the actual state and Update does not re-apply a stale value.
    /// </param>
    public void InitParameters( ElementResolution resolution, float length, float width, float height,
                               float compressibility, float bendability,
                               float mass, Material renderMaterial, ShapeMaterial material, Mesh polybagMesh, bool useContactReduction )
    {
      RenderMaterial = renderMaterial;
      m_material = material;
      m_assembly = CreatePolyBag( resolution, length, width, height, compressibility, bendability, mass, polybagMesh );

      // InitRenderMeshes (called by CreatePolyBag) returns early if Initialize already created the
      // render components, in which case renderMaterial would otherwise never be applied.
      var polybagMeshRenderer = GetComponent<MeshRenderer>();
      if ( polybagMeshRenderer != null && renderMaterial != null )
        polybagMeshRenderer.material = renderMaterial;

      if ( m_assembly == null ) {
        Debug.LogError( "Failed to initialize polybag. CreatePolyBag returned no assembly." );
        return;
      }

      m_assembly.setPosition( this.transform.position.ToHandedVec3() );
      m_assembly.setRotation( this.transform.rotation.ToHandedQuat() );

      GetSimulation().add( m_assembly );

      // The parameter shadows the field: sync both the inspector field and the applied state.
      this.useContactReduction = useContactReduction;
      m_useContactReduction = useContactReduction;

      // Add contact reduction
      if ( useContactReduction )
        InitContactReduction();

      UpdateMeshes();
    }

    // Most recently scheduled vertex-conversion job; its NativeArrays are TempJob allocations
    // released in LateUpdate (or OnDestroy).
    PolyBagUpdateJob m_updateJob;
    // Handle of m_updateJob; completed before its output is read.
    JobHandle m_updateJobHandle;
    // Reusable managed buffer for the AGX vertices, reallocated only when the count changes.
    agx.Vec3[] m_agxVerts = null;
    // True while m_updateJob holds buffers that LateUpdate has not yet consumed.
    bool m_updateQueued = false;

    /// <summary>
    /// Reads the current vertices of the simulated bag mesh and schedules a job that converts them
    /// to Unity coordinates; LateUpdate uploads the result to the render mesh. Also moves this
    /// transform to the pose of the tracked geometry.
    /// </summary>
    /// <remarks>
    /// Subscribed to the simulation's PostStepForward callback in OnEnable and also called once from
    /// InitParameters, so it can run several times before the next LateUpdate.
    /// </remarks>
    private void UpdateMeshes()
    {
      // Not initialized (no render components, or the simulated bag has not been created)?
      if ( m_polybagMeshFilter == null || m_bagMesh == null || m_trackGeometry == null )
        return;

      // A previous conversion may still be pending if several steps ran since the last LateUpdate.
      // Finish it and free its TempJob buffers before overwriting m_updateJob; otherwise they leak
      // and the running job's handle is lost. The newer data read below supersedes its result.
      if ( m_updateQueued ) {
        m_updateJobHandle.Complete();
        if ( m_updateJob.vertices.IsCreated )
          m_updateJob.vertices.Dispose();
        if ( m_updateJob.unityVertices.IsCreated )
          m_updateJob.unityVertices.Dispose();
        m_updateQueued = false;
      }

      var meshData = m_bagMesh.getMeshData();
      var vertices = meshData.getVertices();

      var n = vertices.Count;
      if ( m_agxVerts == null || m_agxVerts.Length != n )
        m_agxVerts = new agx.Vec3[ n ];

      vertices.Get( m_agxVerts );

      m_updateJob = new PolyBagUpdateJob()
      {
        vertices = new NativeArray<agx.Vec3>( m_agxVerts, Allocator.TempJob ),
        unityVertices = new NativeArray<Vector3>( n, Allocator.TempJob )
      };

      m_updateJobHandle = m_updateJob.Schedule();
      m_updateQueued = true;

      this.transform.position = m_trackGeometry.getPosition().ToHandedVector3();
      this.transform.rotation = m_trackGeometry.getRotation().ToHandedQuaternion();
    }

    /// <summary>
    /// If a conversion job is pending, waits for it, uploads the converted vertices to the render
    /// mesh and releases the job's TempJob buffers.
    /// </summary>
    private void LateUpdate()
    {
      if ( !m_updateQueued )
        return;

      m_updateJobHandle.Complete();

      m_updateJob.unityVertices.CopyTo( m_unityRenderVertices );
      m_polybagMeshFilter.mesh.SetVertices( m_unityRenderVertices );

      // Release both buffers and reset the fields so they read as not created afterwards.
      m_updateJob.vertices.Dispose();
      m_updateJob.unityVertices.Dispose();
      m_updateJob.vertices = default;
      m_updateJob.unityVertices = default;

      m_updateQueued = false;
    }

    /// <summary>
    /// Enables or disables every rigid body, constraint, event listener and geometry of the polybag
    /// assembly. Does nothing before the assembly has been created.
    /// </summary>
    /// <param name="flag">True to enable, false to disable.</param>
    protected void EnableAssembly( bool flag )
    {
      if ( m_assembly == null )
        return;

      foreach ( var b in m_assembly.getRigidBodies() )
        b.setEnable( flag );

      foreach ( var c in m_assembly.getConstraints() )
        c.setEnable( flag );

      foreach ( var l in m_assembly.getEventListeners() )
        l.get().setEnable( flag );

      foreach ( var g in m_assembly.getGeometries() )
        g.setEnable( flag );

      // The contact reducer is one of the assembly's listeners (InitContactReduction adds it), so
      // the loop above switched it on unconditionally; restore the applied toggle state.
      if ( flag && m_contactFilterReducerListener != null )
        m_contactFilterReducerListener.setEnable( m_useContactReduction );
    }

    /// <summary>
    /// Re-enables the simulated bag if it exists and subscribes UpdateMeshes to the simulation's
    /// PostStepForward callback.
    /// </summary>
    protected override void OnEnable()
    {
      bool hasAssembly = m_assembly != null;
      if ( hasAssembly ) {
        EnableAssembly( true );
      }

      Simulation.Instance.StepCallbacks.PostStepForward += UpdateMeshes;

      base.OnEnable();
    }

    /// <summary>
    /// Disables the simulated bag, unsubscribes UpdateMeshes and releases any pending conversion job.
    /// </summary>
    protected override void OnDisable()
    {
      // Before InitParameters there is no assembly to disable.
      if ( m_assembly != null )
        EnableAssembly( false );

      if ( Simulation.HasInstance )
        Simulation.Instance.StepCallbacks.PostStepForward -= UpdateMeshes;

      // LateUpdate does not run while the component is disabled, and TempJob allocations must be
      // released within a few frames, so free a pending job's buffers here (its result is dropped).
      if ( m_updateQueued ) {
        m_updateJobHandle.Complete();
        if ( m_updateJob.vertices.IsCreated )
          m_updateJob.vertices.Dispose();
        if ( m_updateJob.unityVertices.IsCreated )
          m_updateJob.unityVertices.Dispose();
        m_updateQueued = false;
      }

      base.OnDisable();
    }



    /// <summary>
    /// Releases the native job buffers and, if the simulation still exists, removes the
    /// contact-reducer listener and the polybag assembly from it.
    /// </summary>
    protected override void OnDestroy()
    {
      // The NativeArrays belong to this component whether or not the simulation is still alive,
      // so release them unconditionally.
      m_updateJobHandle.Complete();

      if ( m_updateJob.vertices.IsCreated )
        m_updateJob.vertices.Dispose();
      if ( m_updateJob.unityVertices.IsCreated )
        m_updateJob.unityVertices.Dispose();
      m_updateQueued = false;

      if ( Simulation.HasInstance ) {
        // InitContactReduction added the listener directly to the simulation as well.
        if ( m_contactFilterReducerListener != null )
          GetSimulation().remove( m_contactFilterReducerListener );

        if ( m_assembly != null )
          GetSimulation().remove( m_assembly );
      }

      base.OnDestroy();
    }
    /// <summary>
    /// Builds the simulated polybag assembly from <paramref name="polybagMesh"/>.
    /// </summary>
    /// <remarks>
    /// Steps:
    /// - The Unity mesh is converted to an AGX trimesh (cached in m_investigateMesh) and its extents
    ///   are measured.
    /// - A copy scaled to (sizeX, sizeY, sizeZ) is created. It is cached in m_scaledMesh and
    ///   deep-copied on later calls.
    /// - Two <see cref="LumpedLayer"/> particle grids are created and locked together. The fill
    ///   layer is offset along the thinnest axis.
    /// - The scaled mesh is attached as a non-colliding sensor geometry to the central bottom-layer
    ///   particle and driven by an agxUtil.TrimeshDeformer.
    /// - Each mesh vertex is assigned to its nearest particle, and one convex collision geometry per
    ///   particle is built from its vertices plus nearby ones.
    /// - <paramref name="mass"/> is distributed over the particles in proportion to the volumes of
    ///   their convex geometries.
    ///
    /// Also assigns <paramref name="polybagMesh"/> to the MeshFilter and sets m_bagMesh,
    /// m_trackGeometry, m_uniqueGroupID and m_unityRenderVertices.
    /// </remarks>
    /// <param name="resolution">Grid resolution; the middle axis gets 3 × (int)resolution particles.</param>
    /// <param name="sizeX">Target size of the bag along X.</param>
    /// <param name="sizeY">Target size of the bag along Y.</param>
    /// <param name="sizeZ">Target size of the bag along Z.</param>
    /// <param name="compressibility">Scales the linear joint compliance (×10 within the bottom layer).</param>
    /// <param name="bendability">Scales the rotational joint compliance (×10 within the bottom layer).</param>
    /// <param name="mass">Total mass of the bag.</param>
    /// <param name="polybagMesh">Source render mesh; indices are read from submesh 0.</param>
    /// <returns>
    /// The assembly, or null if the mesh is missing, has zero extent along an axis, or the scaled
    /// trimesh could not be created.
    /// </returns>
    public agxSDK.Assembly CreatePolyBag( ElementResolution resolution, float sizeX, float sizeY,
                             float sizeZ, float compressibility,
                             float bendability, float mass, Mesh polybagMesh )
    {
      InitRenderMeshes();

      DeformerSingleton.Instance.GetInitialized<DeformerSingleton>();

      if ( polybagMesh == null ) {
        Debug.LogError( "Failed to initialize polybag. No mesh given." );
        return null;
      }

      m_polybagMeshFilter.mesh = polybagMesh;

      var agx_vertices = new agx.Vec3Vector();
      var agx_indices = new agx.UInt32Vector();

      // Convert the Unity mesh to AGX once. On the first call agx_vertices/agx_indices are also
      // reused below to build the scaled mesh.
      if ( m_investigateMesh == null ) {
        var vertices = polybagMesh.vertices;

        var indices = new List<int>();
        polybagMesh.GetIndices( indices, 0 );

        agx_vertices.Capacity = vertices.Length;
        foreach ( var vertex in vertices )
          agx_vertices.Add( vertex.ToHandedVec3() );

        agx_indices.Capacity = indices.Count;
        foreach ( var index in indices )
          agx_indices.Add( (uint)index );

        m_investigateMesh = new agxCollide.Trimesh( agx_vertices, agx_indices, "UnityMesh", 0 );
      }

      // Filled by LateUpdate with the deformed vertices. vertexCount avoids copying the vertex array.
      m_unityRenderVertices = new Vector3[ polybagMesh.vertexCount ];

      var meshData = m_investigateMesh.getMeshData();

      // Find min/max in each direction ([0] = min, [1] = max).
      agx.Vec2 rangeX = new agx.Vec2(agx.agxSWIG.Infinity, -agx.agxSWIG.Infinity);
      agx.Vec2 rangeY = new agx.Vec2(agx.agxSWIG.Infinity, -agx.agxSWIG.Infinity);
      agx.Vec2 rangeZ = new agx.Vec2(agx.agxSWIG.Infinity, -agx.agxSWIG.Infinity);

      var investigate_vertices = meshData.getVertices();
      foreach ( agx.Vec3 vertex in investigate_vertices ) {
        if ( vertex.x < rangeX[ 0 ] )
          rangeX[ 0 ] = vertex.x;
        if ( vertex.x > rangeX[ 1 ] )
          rangeX[ 1 ] = vertex.x;
        if ( vertex.y < rangeY[ 0 ] )
          rangeY[ 0 ] = vertex.y;
        if ( vertex.y > rangeY[ 1 ] )
          rangeY[ 1 ] = vertex.y;
        if ( vertex.z < rangeZ[ 0 ] )
          rangeZ[ 0 ] = vertex.z;
        if ( vertex.z > rangeZ[ 1 ] )
          rangeZ[ 1 ] = vertex.z;
      }

      float spanX = PolybagUtil.Span(rangeX);
      float spanY = PolybagUtil.Span(rangeY);
      float spanZ = PolybagUtil.Span(rangeZ);

      if ( spanX <= 0.0 || spanY <= 0.0 || spanZ <= 0.0 ) {
        Debug.LogError( "Failed to initialize polybag. The mesh has zero extent along at least one axis." );
        return null;
      }

      // Per-axis factors that scale the mesh to fit the requested box.
      float scaleX = (sizeX / spanX);
      float scaleY = (sizeY / spanY);
      float scaleZ = (sizeZ / spanZ);

      agx.Vec3f spanVector = new agx.Vec3f(spanX, spanY, spanZ);
      agx.Vec3f scaleVector = new agx.Vec3f(scaleX, scaleY, scaleZ);

      agxCollide.Trimesh scaledMesh;

      // Mesh scaled to fit the box; this is the mesh that gets deformed.
      if ( m_scaledMesh == null ) {
        var scaleRotate = new agx.Matrix3x3(new agx.Vec3(scaleVector[0], scaleVector[1], scaleVector[2]));
        var n = agx_vertices.Count;
        for ( int i = 0; i < n; i++ )
          agx_vertices[ i ] = scaleRotate.postMult( agx_vertices[ i ] );
        scaledMesh = new agxCollide.Trimesh( agx_vertices, agx_indices, "ScaledMesh", 0 );
        m_scaledMesh = scaledMesh;
      }
      else {
        // Make a copy
        scaledMesh = m_scaledMesh.deepCopy();
      }

      if ( scaledMesh == null ) {
        Debug.LogError( "Failed to initialize polybag. Error creating trimesh." );
        return null;
      }

      // Classify the axes of the requested box: thinnest (min), middle (mid) and longest (max).
      agx.Vec3 sizeVector = new agx.Vec3(sizeX, sizeY, sizeZ);
      int minElement = sizeVector.minElement();
      int maxElement = sizeVector.maxElement();
      Debug.Assert( minElement != maxElement );

      int midElement = minElement == 0 ? (maxElement == 1 ? 2 : 1) : (maxElement == 0 ? (minElement == 1 ? 2 : 1) : 0);

      // Bottom layer: one particle through the thickness and 3 × resolution along the middle axis.
      // The longest axis gets that count scaled by the mesh's longest/middle span ratio.
      int resShort = 1;
      int resMid = (int)resolution * 3;

      float longShortRatio = Mathf.Max(1.0f, spanVector[maxElement] / spanVector[midElement]);

      int resLong = (int)(resMid * longShortRatio);
      int[] resolutionArray = new int[3];
      resolutionArray[ minElement ] = resShort;
      resolutionArray[ midElement ] = resMid;
      resolutionArray[ maxElement ] = resLong;
      LumpedLayer bottomLayer = new LumpedLayer(rangeX * scaleX, rangeY * scaleY, rangeZ * scaleZ, resolutionArray[0], resolutionArray[1], resolutionArray[2], compressibility * 10.0, bendability * 10.0);

      var centralIndex = bottomLayer.CalculateParticleIndex(resolutionArray[0] / 2, resolutionArray[1] / 2, resolutionArray[2] / 2);
      agx.RigidBody centralBody = bottomLayer._bodies[centralIndex];

      // The scaled mesh rides on the central particle as a non-colliding sensor geometry.
      agxCollide.Geometry geom = new agxCollide.Geometry(scaledMesh);
      geom.setSensor( true );
      geom.setEnableCollisions( false );
      centralBody.add( geom, agx.AffineMatrix4x4.translate( -centralBody.getFrame().getLocalTranslate() ) );

      float fillArea = spanVector[midElement] * spanVector[maxElement] * scaleVector[midElement] * scaleVector[maxElement];
      float lengthScale = Mathf.Sqrt(fillArea * spanVector[maxElement] / spanVector[midElement]);
      agx.Vec2 fillRangeXScaled = rangeX * scaleX;
      agx.Vec2 fillRangeYScaled = rangeY * scaleY;
      agx.Vec2 fillRangeZScaled = rangeZ * scaleZ;
      if ( minElement != 0 )
        fillRangeXScaled[ 1 ] -= spanX * scaleX * lengthScale;
      if ( minElement != 1 )
        fillRangeYScaled[ 1 ] -= spanY * scaleY * lengthScale;
      if ( minElement != 2 )
        fillRangeZScaled[ 1 ] -= spanZ * scaleZ * lengthScale;

      // The fill layer is coarser along the two non-thin axes. Work on a copy so resolutionArray
      // keeps describing the bottom layer.
      int[] fillResolutionArray = (int[])resolutionArray.Clone();
      if ( minElement != 0 )
        fillResolutionArray[ 0 ] = resMid - (int)resolution;
      if ( minElement != 1 )
        fillResolutionArray[ 1 ] = resMid - (int)resolution;
      if ( minElement != 2 )
        fillResolutionArray[ 2 ] = resMid - (int)resolution;

      LumpedLayer fillLayer = new LumpedLayer(fillRangeXScaled, fillRangeYScaled, fillRangeZScaled, fillResolutionArray[0], fillResolutionArray[1], fillResolutionArray[2], compressibility, bendability);

      // Offset the fill layer by the bag's thickness along the thinnest axis.
      agx.Vec3 offset = new agx.Vec3();
      offset[ minElement ] = spanVector[ minElement ] * scaleVector[ minElement ];
      foreach ( agx.RigidBody body in fillLayer._bodies ) {
        var frame = body.getFrame();
        frame.setLocalTranslate( frame.getLocalTranslate() + offset );
      }

      List<agx.Constraint> interLayerConstraints = new List<agx.Constraint>();
      ConstrainLayers( bottomLayer, fillLayer, interLayerConstraints, compressibility, bendability );

      List<agx.RigidBody> allBodies = new List<agx.RigidBody>();

      agxSDK.Assembly assembly = new agxSDK.Assembly();
      foreach ( agx.RigidBody body in bottomLayer._bodies ) {
        allBodies.Add( body );
        assembly.add( body );
      }

      foreach ( agx.RigidBody body in fillLayer._bodies ) {
        allBodies.Add( body );
        assembly.add( body );
      }

      foreach ( agx.Constraint c in bottomLayer._constraints )
        assembly.add( c );

      foreach ( agx.Constraint c in fillLayer._constraints )
        assembly.add( c );

      foreach ( agx.Constraint c in interLayerConstraints )
        assembly.add( c );

      agxUtil.TrimeshDeformer deformer = new agxUtil.TrimeshDeformer(scaledMesh, geom);
      DeformerSingleton.Instance.Native.add( deformer );

      // Just add it again.
      GetSimulation().add( DeformerSingleton.Instance.Native );

      DeformerSingleton.Instance.Native.setParallel( true );

      // Assign every vertex of the scaled mesh to the particle closest to it.
      var scaledMeshData = scaledMesh.getMeshData();
      var scaledMeshVertices = scaledMeshData.getVertices();
      var numVertices = scaledMeshVertices.Count;

      var geomTransform = geom.getTransform();
      var vertexToBody = new List<agx.RigidBody>( numVertices );
      foreach ( agx.Vec3 vertex in scaledMeshVertices ) {
        // Transform once per vertex rather than once per (vertex, body) pair.
        var p = geomTransform.postMult( vertex );

        agx.RigidBody closestBody = null;
        float minDistance2 = Mathf.Infinity;
        foreach ( var body in allBodies ) {
          float distance2 = (float)(body.getPosition().distance2(p));
          if ( distance2 < minDistance2 ) {
            minDistance2 = distance2;
            closestBody = body;
          }
        }
        vertexToBody.Add( closestBody );
      }

      uint bagConvexID = GetSimulation().getSpace().getUniqueGroupID();
      float fullVolume = 0;

      // m_material may be unset; skip material assignment instead of throwing.
      bool hasMaterial = m_material != null;
      var nativeMaterial = hasMaterial ? m_material.Native : null;

      // One convex collision geometry per particle, built from its assigned vertices plus any
      // vertex at most extraDistance beyond the farthest assigned one.
      foreach ( var body in allBodies ) {
        agx.Vec3Vector convexVertices = new agx.Vec3Vector();
        float mostFarAway2 = 0.0f;
        for ( int i = 0; i < numVertices; ++i ) {
          if ( vertexToBody[ i ] == body ) {
            float d2 = (float)scaledMeshVertices[i].distance2(body.getLocalPosition());
            if ( d2 > mostFarAway2 ) {
              mostFarAway2 = d2;
            }
            convexVertices.Add( scaledMeshVertices[ i ] );
          }
        }

        float extraDistance = scaleVector[maxElement] * spanVector[maxElement] / 20.0f;
        float mostFarAway = Mathf.Sqrt(mostFarAway2);
        for ( int i = 0; i < numVertices; ++i ) {
          float d2 = (float)(scaledMeshVertices[i].distance2(body.getLocalPosition()));
          if ( d2 > mostFarAway2 ) {
            float d = Mathf.Sqrt(d2);
            if ( d - mostFarAway < extraDistance ) {
              convexVertices.Add( scaledMeshVertices[ i ] );
            }
          }
        }

        agxCollide.Convex convexPart = agxUtil.agxUtilSWIG.createConvex(convexVertices);
        agxCollide.Geometry convexPartGeom = new agxCollide.Geometry(convexPart);
        convexPartGeom.addGroup( bagConvexID );
        if ( hasMaterial )
          convexPartGeom.setMaterial( nativeMaterial );
        assembly.add( convexPartGeom );
        fullVolume += (float)convexPart.getVolume();

        body.add( convexPartGeom, agx.AffineMatrix4x4.translate( -body.getLocalPosition() ) );

        convexPart.ReturnToPool();
      }

      uint groupId = GetSimulation().getSpace().getUniqueGroupID();
      m_uniqueGroupID = groupId;

      // NOTE(review): this uses the fill layer's resolution along the longest axis. Confirm whether
      // the bottom layer's resolutionArray[maxElement] was intended instead.
      float effectDistance = 1.5f * spanVector[maxElement] * scaleVector[maxElement] / (fillResolutionArray[maxElement] + 1.0f);

      float invFullVolume = 1 / fullVolume;
      foreach ( var body in allBodies ) {
        // The convex geometry was the last one added to each body above.
        var bodyGeometries = body.getGeometries();
        float massPercent = (float)bodyGeometries[bodyGeometries.Count - 1].getShapes()[0].getVolume() * invFullVolume;
        body.getMassProperties().setMass( massPercent * mass );

        deformer.addLocalDeformation( centralBody, body, effectDistance, 1.0f );
        foreach ( agxCollide.GeometryRef geoms in bodyGeometries )
          geoms.addGroup( groupId );
      }

      // NOTE(review): this sets the material of the sensor geometry `geom`, not of each body's
      // geometries iterated above; confirm which was intended.
      if ( hasMaterial )
        geom.setMaterial( nativeMaterial );

      GetSimulation().getSpace().setEnablePair( groupId, groupId, false );
      deformer.activateLocalOffsetInterpolation();
      assembly.add( deformer );

      m_bagMesh = scaledMesh;
      m_trackGeometry = geom;

      return assembly;
    }

    /// <summary>
    /// Locks every particle of <paramref name="fillLayer"/> to the particle with the same linear
    /// index in <paramref name="bottomLayer"/>. Each lock joint's spring length is the current
    /// distance between the two bodies. The joints are rebound and appended to
    /// <paramref name="constraints"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the index is computed with the fill layer's resolution and reused for the
    /// bottom layer, which may have a different resolution. Pairs therefore match by index, not
    /// necessarily by grid position. Confirm this is intended.
    /// </remarks>
    private void ConstrainLayers( LumpedLayer bottomLayer, LumpedLayer fillLayer, List<agx.Constraint> constraints,
                                 float compressibility, float bendability )
    {
      for ( int x = 0; x < fillLayer._resX; x++ ) {
        for ( int y = 0; y < fillLayer._resY; y++ ) {
          for ( int z = 0; z < fillLayer._resZ; z++ ) {
            int index = fillLayer.CalculateParticleIndex( x, y, z );
            agx.RigidBody fromFill = fillLayer._bodies[ index ];
            agx.RigidBody fromBottom = bottomLayer._bodies[ index ];
            var restLength = fromFill.getPosition().distance( fromBottom.getPosition() );

            agx.Constraint joint = PolybagUtil.CreateLockJoint( fromFill, fromBottom, compressibility, bendability, restLength );
            joint.rebind();
            constraints.Add( joint );
          }
        }
      }
    }
  }

  /// <summary>
  /// A regular 3D grid of small, non-colliding particle bodies spanning the given ranges. Each
  /// particle is locked to its preceding neighbour along every grid axis.
  /// Bodies are stored in <see cref="_bodies"/> in the order given by
  /// <see cref="CalculateParticleIndex"/>; joints in <see cref="_constraints"/>.
  /// </summary>
  public class LumpedLayer
  {
    public int _resX;
    public int _resY;
    public int _resZ;
    public List<agx.RigidBody> _bodies = new List<agx.RigidBody>();
    public List<agx.LockJoint> _constraints = new List<agx.LockJoint>();

    /// <summary>
    /// Creates the layer.
    /// </summary>
    /// <param name="rangeX">[min, max] extent along X (likewise rangeY, rangeZ).</param>
    /// <param name="rangeY">[min, max] extent along Y.</param>
    /// <param name="rangeZ">[min, max] extent along Z.</param>
    /// <param name="resX">Particle count along X (likewise resY, resZ).</param>
    /// <param name="resY">Particle count along Y.</param>
    /// <param name="resZ">Particle count along Z.</param>
    /// <param name="compressibility">Forwarded to PolybagUtil.CreateLockJoint.</param>
    /// <param name="bendability">Forwarded to PolybagUtil.CreateLockJoint.</param>
    public LumpedLayer( agx.Vec2 rangeX, agx.Vec2 rangeY, agx.Vec2 rangeZ,
                       int resX, int resY, int resZ,
                       double compressibility, double bendability )
    {
      _resX = resX;
      _resY = resY;
      _resZ = resZ;
      CreateLayer( rangeX, rangeY, rangeZ, compressibility, bendability );
    }

    // Distance between neighbouring particles; the whole span when there is at most one particle.
    private static float Spacing( agx.Vec2 range, int resolution )
    {
      float span = PolybagUtil.Span( range );
      return resolution <= 1 ? span : span / ( resolution - 1 );
    }

    // A tiny sphere body with collisions disabled and a fixed mass, placed at the given position.
    private static agx.RigidBody CreateParticle( agx.Vec3 position )
    {
      var particle = new agx.RigidBody();
      var geometry = new agxCollide.Geometry( new agxCollide.Sphere( 0.001 ) );
      geometry.setEnableCollisions( false );
      particle.add( geometry );
      particle.getMassProperties().setMass( 0.01 );
      particle.setLocalPosition( position );
      return particle;
    }

    // Locks the particle to the already created body at neighbourIndex.
    private void Link( agx.RigidBody particle, int neighbourIndex, double restLength,
                       double compressibility, double bendability )
    {
      agx.RigidBody neighbour = _bodies[ neighbourIndex ];
      _constraints.Add( PolybagUtil.CreateLockJoint( particle, neighbour, compressibility, bendability, restLength ) );
    }

    private void CreateLayer( agx.Vec2 rangeX, agx.Vec2 rangeY, agx.Vec2 rangeZ,
                             double compressibility, double bendability )
    {
      float dX = Spacing( rangeX, _resX );
      float dY = Spacing( rangeY, _resY );
      float dZ = Spacing( rangeZ, _resZ );

      var x = rangeX[ 0 ];
      for ( int i = 0; i < _resX; ++i, x += dX ) {
        var y = rangeY[ 0 ];
        for ( int j = 0; j < _resY; ++j, y += dY ) {
          var z = rangeZ[ 0 ];
          for ( int k = 0; k < _resZ; ++k, z += dZ ) {
            var particle = CreateParticle( new agx.Vec3( x, y, z ) );
            _bodies.Add( particle );

            // Connect to the previous particle along each axis (X, then Y, then Z).
            if ( i > 0 )
              Link( particle, CalculateParticleIndex( i - 1, j, k ), dX, compressibility, bendability );
            if ( j > 0 )
              Link( particle, CalculateParticleIndex( i, j - 1, k ), dY, compressibility, bendability );
            if ( k > 0 )
              Link( particle, CalculateParticleIndex( i, j, k - 1 ), dZ, compressibility, bendability );
          }
        }
      }
    }

    /// <summary>Linear index of grid cell (i, j, k); k varies fastest.</summary>
    public int CalculateParticleIndex( int i, int j, int k )
    {
      return ( i * _resY + j ) * _resZ + k;
    }
  }

  /// <summary>Small helpers shared by the polybag construction code.</summary>
  static class PolybagUtil
  {
    /// <summary>Length of a [min, max] range (max - min), computed in single precision.</summary>
    public static float Span( agx.Vec2 range )
    {
      float lower = (float)range[ 0 ];
      float upper = (float)range[ 1 ];
      return upper - lower;
    }

    /// <summary>
    /// Creates a lock joint between <paramref name="b1"/> and <paramref name="b2"/>, with frames
    /// computed at the origin with axis (0, 1, 0). Compliance is set per degree of freedom:
    /// translational (0.001 + 0.05·compressibility)·springLength and rotational
    /// (0.1 + 4.9·bendability)·springLength. Damping is 0.3 and the solve type is DIRECT_AND_ITERATIVE.
    /// </summary>
    public static agx.LockJoint CreateLockJoint( agx.RigidBody b1, agx.RigidBody b2,
                                                double compressibility, double bendability, double springLength )
    {
      var frame1 = new agx.Frame();
      var frame2 = new agx.Frame();
      agx.Constraint.calculateFramesFromBody( new agx.Vec3( 0, 0, 0 ), new agx.Vec3( 0, 1, 0 ), b1, frame1, b2, frame2 );

      var joint = new agx.LockJoint( b1, frame1, b2, frame2 );
      joint.setSolveType( agx.Constraint.SolveType.DIRECT_AND_ITERATIVE );

      double linearCompliance = ( 0.001 + compressibility * 0.05 ) * springLength;
      double torsionCompliance = ( 0.1 + bendability * 4.9 ) * springLength;

      var translational = new[] { agx.LockJoint.DOF.TRANSLATIONAL_1, agx.LockJoint.DOF.TRANSLATIONAL_2, agx.LockJoint.DOF.TRANSLATIONAL_3 };
      foreach ( var dof in translational )
        joint.setCompliance( linearCompliance, (long)dof );

      var rotational = new[] { agx.LockJoint.DOF.ROTATIONAL_1, agx.LockJoint.DOF.ROTATIONAL_2, agx.LockJoint.DOF.ROTATIONAL_3 };
      foreach ( var dof in rotational )
        joint.setCompliance( torsionCompliance, (long)dof );

      joint.setDamping( 0.3 );
      return joint;
    }
  }
}