﻿using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using AGXUnity;
using Unity.MLAgents.Actuators;

namespace AGXUnity_RobotML.Scripts
{
  public class RobotAgent : Agent
  {
    // Not referenced anywhere else in this class.
    // NOTE(review): presumably superseded by the "nr_episodes_curriculum" environment parameter - confirm before removing.
    public int nrEpisodes = 5;
    // Root object of the robot. Destroyed and re-instantiated from a prefab in OnEpisodeBegin().
    public GameObject robot;
    // Target object. Initialize() reads an AGXUnity.Collide.Box component from it.
    public GameObject goal;
    // Values copied into the continuous action buffer by Heuristic(), one entry per hinge.
    public float[] heuristicActions = new float[ 6 ];

    // Native geometry of the goal box; moved and rotated by ChangeGoalPosition().
    private agxCollide.Geometry m_goalGeometry;
    // Rigid body found under the child named "TCP"; used for observations and reward.
    private RigidBody m_tcp;
    // Rigid body found under the child named "WeldingTool"; used for the self-collision check.
    private RigidBody m_tool;
    // Every RigidBody found in the robot hierarchy; iterated by ToolCollisionSensor().
    private RigidBody[] m_robotBodies;
    // Per-hinge motors and constraints, indexed in the same order as m_hingeNames.
    private RobotMotor[] m_motors;
    private Constraint[] m_constraints;
    // Per-hinge speed bound used to normalize the hinge speed observation (units not stated here - TODO confirm).
    private float[] m_constraint_speed_limit = { 30.0f, 125.0f, 60.0f, 100.0f, 30.0f, 20.0f };
    // Per-hinge (min, max) angle used to normalize the hinge angle observation (presumably radians - TODO confirm).
    private Vector2[] m_constraint_range_limit = {  new Vector2(-5.24f, 5.24f),
                                                    new Vector2(-1.13f, 1.48f),
                                                    new Vector2(-1.22f, 3.14f),
                                                    new Vector2(-5.24f, 5.24f),
                                                    new Vector2(-2.09f, 2.09f),
                                                    new Vector2(-5.24f, 5.24f)  };
    // Per-hinge factor multiplied with the clamped [-1, 1] action to obtain the motor torque.
    private float[] m_constraint_force_limit = { 500.0f, 500.0f, 500.0f, 200.0f, 200.0f, 100.0f };
    // Native collision space; queried for geometry contacts in ToolCollisionSensor().
    private agxCollide.Space m_space;
    // Position of the robot base body as returned by its native body in InitializeRobot(); origin for goal placement.
    private agx.Vec3 m_robotPosition;
    // Names of the hinge children looked up under the robot. The order defines the index used by every per-hinge array.
    private string[] m_hingeNames = { "baseHinge", "hipHinge", "shoulderHinge", "armHinge", "wristHinge2", "wristHinge" };
    // Distance cap used by CalculateReward(): at or beyond this distance the distance reward is zero.
    private float m_limit = 0.7f;
    // Episodes started since the robot was last rebuilt. The initial value exceeds the default m_maxEpisodeNr.
    private int m_episodeNr = 10;
    // Fallback for the "nr_episodes_curriculum" environment parameter: episodes between full robot rebuilds.
    private int m_maxEpisodeNr = 4;
    // True until the robot is first built in OnEpisodeBegin() and again after OnDisable(); while true,
    // CollectObservations() emits zeros only.
    private bool m_isDisabled = true;
    // Cylindrical angle of the previous goal; the next goal angle is an offset from it. Reset to pi/2 on rebuild.
    private float m_lastTargetTheta = 0.5f * Mathf.PI;

    // Environment parameters fetched from the Academy in Initialize(); source of the curriculum values.
    private EnvironmentParameters m_envParameters;

    /// <summary>
    /// One-time agent setup: allocates the per-hinge arrays, caches the native goal geometry,
    /// the environment parameters and the collision space, and hooks the Academy step onto the
    /// simulation's post-step callback.
    /// </summary>
    public override void Initialize()
    {
      m_constraints = new Constraint[ m_hingeNames.Length ];
      m_motors = new RobotMotor[ m_hingeNames.Length ];

      var box = goal.GetComponent<AGXUnity.Collide.Box>().GetInitialized<AGXUnity.Collide.Box>().Native.asBox();
      m_goalGeometry = box.getGeometry();

      m_envParameters = Academy.Instance.EnvironmentParameters;

      // The Academy no longer steps itself; it is stepped from the simulation's post-step callback.
      Academy.Instance.AutomaticSteppingEnabled = false;
      Simulation.Instance.StepCallbacks.PostStepForward += Academy.Instance.EnvironmentStep;

      // Guarantee that the first OnEpisodeBegin() builds the robot regardless of how large the
      // "nr_episodes_curriculum" threshold is. A fixed sentinel (previously 10) would skip the
      // build for larger thresholds and leave m_tcp, m_tool and m_robotBodies unset.
      m_episodeNr = int.MaxValue;

      m_space = Simulation.Instance.Native.getSpace();
    }



    /// <summary>
    /// Writes the 30-value observation vector: goal offset and orientation difference relative to
    /// the TCP, TCP position and forward direction, TCP linear and angular velocity, followed by
    /// the normalized angle and speed of every hinge. While the agent is flagged as disabled the
    /// same number of zeros is written instead.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations( VectorSensor sensor )
    {
      // 6 three-component vectors plus (angle, speed) for each of the 6 hinges.
      const int observationCount = 30;
      if ( m_isDisabled ) {
        int written = 0;
        while ( written < observationCount ) {
          sensor.AddObservation( 0.0f );
          written++;
        }
        return;
      }

      Transform robotFrame = robot.transform;
      Transform tcpFrame = m_tcp.gameObject.transform;

      // Goal and TCP expressed in the robot's local frame.
      Vector3 goalInRobot = robotFrame.InverseTransformPoint( goal.transform.position );
      Vector3 tcpInRobot = robotFrame.InverseTransformPoint( tcpFrame.position );

      // Offset from TCP to goal, and difference between their forward directions.
      Vector3 goalOffset = goalInRobot - tcpInRobot;
      sensor.AddObservation( normalize( goalOffset, new Vector3( -7.0f, -7.0f, -3.0f ), new Vector3( 7.0f, 7.0f, 3.0f ) ) );
      Vector3 forwardDelta = goal.transform.forward - tcpFrame.forward;
      sensor.AddObservation( normalize( forwardDelta, -2 * Vector3.one, 2 * Vector3.one ) );

      // TCP position and (un-normalized) forward direction.
      sensor.AddObservation( normalize( tcpInRobot, new Vector3( -3.5f, -3.5f, 0.0f ), new Vector3( 3.5f, 3.5f, 4.5f ) ) );
      sensor.AddObservation( tcpFrame.forward );

      // TCP linear velocity.
      sensor.AddObservation( normalize( m_tcp.LinearVelocity, new Vector3( -100f, -125f, -100f ), new Vector3( 100f, 125f, 100f ) ) );
      // TCP angular velocity.
      sensor.AddObservation( normalize( m_tcp.AngularVelocity, -125.0f * Vector3.one, 125.0f * Vector3.one ) );

      // Angle and speed of each hinge, scaled by that hinge's range and speed limit.
      for ( int hingeIndex = 0; hingeIndex < m_constraints.Length; hingeIndex++ ) {
        var nativeHinge = m_constraints[ hingeIndex ].Native.asHinge();
        Vector2 angleRange = m_constraint_range_limit[ hingeIndex ];
        float speedLimit = m_constraint_speed_limit[ hingeIndex ];

        float angle = (float)nativeHinge.getAngle();
        sensor.AddObservation( normalize( angle, angleRange.x, angleRange.y ) );

        float speed = (float)nativeHinge.getCurrentSpeed();
        sensor.AddObservation( normalize( speed, -speedLimit, speedLimit ) );
      }
    }



    /// <summary>
    /// Linearly maps <paramref name="v"/> from [min, max] onto [-1, 1]. Values outside the
    /// interval map outside [-1, 1]; no clamping is performed.
    /// </summary>
    /// <param name="v">Value to rescale.</param>
    /// <param name="min">Value that maps to -1.</param>
    /// <param name="max">Value that maps to +1.</param>
    /// <returns>The rescaled value.</returns>
    private float normalize( float v, float min, float max )
    {
      float offset = v - min;
      float span = max - min;
      return 2 * offset / span - 1.0f;
    }



    /// <summary>
    /// Component-wise version of the scalar normalize: each axis of <paramref name="vector"/> is
    /// mapped from its own [min, max] interval onto [-1, 1].
    /// </summary>
    /// <param name="vector">Vector to rescale.</param>
    /// <param name="min">Per-axis values that map to -1.</param>
    /// <param name="max">Per-axis values that map to +1.</param>
    /// <returns>The rescaled vector.</returns>
    private Vector3 normalize( Vector3 vector, Vector3 min, Vector3 max )
    {
      float scaledX = normalize( vector.x, min.x, max.x );
      float scaledY = normalize( vector.y, min.y, max.y );
      float scaledZ = normalize( vector.z, min.z, max.z );
      return new Vector3( scaledX, scaledY, scaledZ );
    }



    /// <summary>
    /// Assigns the step reward, then either ends the episode (when the tool touches the robot)
    /// or converts the continuous actions into motor torques.
    /// </summary>
    /// <param name="vectorAction">Action buffers; one continuous action per hinge is read.</param>
    public override void OnActionReceived( ActionBuffers vectorAction )
    {
      SetReward( CalculateReward() );

      // Self-collision: end the episode and raise the episode counter to the rebuild threshold
      // so that the next OnEpisodeBegin() recreates the robot from scratch.
      if ( ToolCollisionSensor() ) {
        m_episodeNr = (int)m_envParameters.GetWithDefault( "nr_episodes_curriculum", m_maxEpisodeNr );
        EndEpisode();
        return;
      }

      // Actions are continuous values expected in [-1, 1]; clamp, then scale by the hinge's
      // force limit to get the torque applied by that hinge's motor.
      var requested = vectorAction.ContinuousActions;
      for ( int joint = 0; joint < m_hingeNames.Length; joint++ ) {
        float unitAction = Mathf.Clamp( requested[ joint ], -1f, 1f );
        m_motors[ joint ].torque = m_constraint_force_limit[ joint ] * unitAction;
      }
    }



    /// <summary>
    /// Starts a new episode. Once the episode counter has reached the "nr_episodes_curriculum"
    /// threshold the robot is destroyed and re-instantiated from its prefab; in every case the
    /// goal is moved to a new pose and the episode counter is incremented.
    /// </summary>
    public override void OnEpisodeBegin()
    {
      int rebuildThreshold = (int)m_envParameters.GetWithDefault( "nr_episodes_curriculum", m_maxEpisodeNr );
      bool rebuildRobot = m_episodeNr >= rebuildThreshold;

      if ( rebuildRobot ) {
        // Drop the old robot (if any) and let the native simulation clean up after it.
        if ( robot != null ) {
          DestroyImmediate( robot );
        }
        Simulation.Instance.Native.garbageCollect();

        m_isDisabled = false;
        var robotPrefab = Resources.Load<GameObject>( "IRB6700_220_265_SW6_LeanID" );
        robot = Instantiate( robotPrefab );
        InitializeRobot();

        // Fresh robot: restart the goal angle and the episode count.
        m_lastTargetTheta = 0.5f * Mathf.PI;
        m_episodeNr = 0;
      }

      ChangeGoalPosition();
      m_episodeNr += 1;
    }


    /// <summary>
    /// Manual control: copies the values of <c>heuristicActions</c> into the continuous action
    /// buffer, one per hinge.
    /// </summary>
    /// <param name="actionsOut">Action buffers whose continuous segment is filled in.</param>
    public override void Heuristic( in ActionBuffers actionsOut )
    {
      var actions = actionsOut.ContinuousActions;
      if ( heuristicActions == null )
        return;

      // heuristicActions is a public array and can be resized outside this class, and the action
      // buffer size is configured elsewhere. Never index past either of them, and never write
      // more than one value per hinge. Entries beyond the copied range are left untouched.
      int count = Mathf.Min( m_hingeNames.Length, Mathf.Min( actions.Length, heuristicActions.Length ) );
      for ( int i = 0; i < count; i++ ) {
        actions[ i ] = heuristicActions[ i ];
      }
    }



    // Flags the agent as disabled and then runs the base class OnDisable().
    protected override void OnDisable()
    {
      // While this flag is set, CollectObservations() writes zeros instead of reading the robot.
      // It is only cleared again when OnEpisodeBegin() rebuilds the robot.
      m_isDisabled = true;
      base.OnDisable();
    }



    /// <summary>
    /// Caches references into the current robot instance: the constraint and motor of every hinge
    /// listed in <c>m_hingeNames</c>, the "TCP" and "WeldingTool" bodies, the position of the base
    /// body and the list of all rigid bodies in the hierarchy.
    /// </summary>
    private void InitializeRobot()
    {
      Transform root = robot.transform;

      // Constraints and motors share the index of their name in m_hingeNames.
      int hingeIndex = 0;
      foreach ( string hingeName in m_hingeNames ) {
        var constraint = root.Find( hingeName ).GetComponent<Constraint>().GetInitialized<Constraint>();
        m_constraints[ hingeIndex ] = constraint;
        m_motors[ hingeIndex ] = constraint.GetComponent<RobotMotor>().GetInitialized<RobotMotor>();
        hingeIndex++;
      }

      m_tcp = root.Find( "TCP" ).GetComponent<RigidBody>().GetInitialized<RigidBody>();
      m_tool = root.Find( "WeldingTool" ).GetComponent<RigidBody>().GetInitialized<RigidBody>();

      // Position of the base body; goal positions are generated relative to it.
      var baseBody = root.Find( "IRB6700_235-265_IRC5_rev00_BASE_CAD" ).GetComponent<RigidBody>().GetInitialized<RigidBody>();
      m_robotPosition = baseBody.Native.getPosition();

      m_robotBodies = robot.gameObject.GetComponentsInChildren<RigidBody>();
    }



    /// <summary>
    /// Checks whether the tool body currently has geometry contacts with any other robot body.
    /// Bodies whose name contains "TCP" or "Tool" are excluded from the check.
    /// </summary>
    /// <returns>True as soon as one contact is found, otherwise false.</returns>
    private bool ToolCollisionSensor()
    {
      var contacts = new agxCollide.GeometryContactPtrVector();

      for ( int bodyIndex = 0; bodyIndex < m_robotBodies.Length; bodyIndex++ ) {
        RigidBody candidate = m_robotBodies[ bodyIndex ];

        // Skip the tool itself and the TCP body.
        string candidateName = candidate.name;
        if ( candidateName.Contains( "TCP" ) || candidateName.Contains( "Tool" ) )
          continue;

        if ( m_space.getGeometryContacts( contacts, m_tool.Native, candidate.Native ) > 0 )
          return true;
      }

      return false;
    }



    /// <summary>
    /// Computes the step reward from how close the tool centre point (TCP) is to the goal, both
    /// in position and in rotation. The reward is highest when the TCP is at the goal with the
    /// goal's rotation, and zero when either the distance or the rotation difference is at its cap.
    /// </summary>
    /// <returns>The product of distance reward and rotation reward, scaled by the step budget.</returns>
    private float CalculateReward()
    {
      // The distance reward is based on the distance to the goal, capped at m_limit and normalized
      // by it, raised to the power of 0.4 so the reward rises more steeply close to zero distance.
      var relGoalPosition = robot.transform.InverseTransformPoint( goal.transform.position );
      var relTcpPosition = robot.transform.InverseTransformPoint( m_tcp.gameObject.transform.position );
      var diff = relGoalPosition - relTcpPosition;

      var rewardDiff = Mathf.Min( diff.magnitude, m_limit );
      var distReward = 1f - Mathf.Pow( rewardDiff / m_limit, 0.4f );

      // Rotation between the two quaternions. The length of the vector part of the relative
      // rotation is 0 when the rotations coincide and grows towards 1 as they diverge.
      var q = goal.transform.rotation * Quaternion.Inverse( m_tcp.transform.rotation );
      var angleDiff = Mathf.Min( Mathf.Sqrt( Mathf.Pow( q.x, 2 ) + Mathf.Pow( q.y, 2 ) + Mathf.Pow( q.z, 2 ) ), 1f );
      var angleReward = 1f - Mathf.Pow( angleDiff, 0.4f );

      // Scale so that the total reward over a full episode does not become unreasonably large.
      // MaxStep == 0 means "no step limit"; dividing by it would give an infinite scale and an
      // Infinity/NaN reward, so the per-step reward is left unscaled in that case.
      float rewardScaling = MaxStep > 0 ? 100f / MaxStep : 1f;
      return rewardScaling * distReward * angleReward;
    }



    /// <summary>
    /// Moves the goal to a new random pose. The position is generated in cylindrical coordinates
    /// (height, angle, radius) around the robot base; the rotation makes the goal's local -Z axis
    /// point roughly back at the robot, with a random perturbation. Height and angle step can be
    /// overridden through the "goal_height_curriculum" and "goal_theta_curriculum" parameters.
    /// </summary>
    private void ChangeGoalPosition()
    {
      // Height: environment parameter if present, otherwise random.
      float height = m_envParameters.GetWithDefault( "goal_height_curriculum", Random.Range( 0.6f, 2.6f ) );

      // Angle: step away from the previous goal angle, kept inside the training range.
      float thetaStep = m_envParameters.GetWithDefault( "goal_theta_curriculum", Random.Range( -3.14f / 4, 3.14f / 4 ) );
      float theta = Mathf.Clamp( m_lastTargetTheta + thetaStep, 0.2f * Mathf.PI, 0.8f * Mathf.PI );

      // Radius: the allowed interval depends on the height. Below 2 the minimum is raised,
      // from 2 upwards the maximum is lowered.
      float minRadius = height < 2f ? 1.4f : 1.0f;
      float maxRadius = height >= 2 ? 2f : 2.6f;
      float radius = Random.Range( minRadius, maxRadius );

      // Cylindrical to Cartesian: sin(theta) goes to x, cos(theta) goes to z, height is y.
      float offsetZ = radius * Mathf.Cos( theta );
      float offsetX = radius * Mathf.Sin( theta );
      agx.Vec3 targetPosition = m_robotPosition + new agx.Vec3( offsetX, height, offsetZ );
      m_goalGeometry.setPosition( targetPosition );

      // Start from the identity rotation.
      m_goalGeometry.setRotation( new agx.Quat() );

      // Direction from the goal back to the robot, plus two directions used to perturb it:
      // a horizontal one perpendicular to it and the world Y axis.
      var toRobot = m_robotPosition - targetPosition;
      var sideways = new agx.Vec3( -toRobot.z, 0.0f, toRobot.x );
      var upward = agx.Vec3.Y_AXIS();

      // Random perturbation of the aim direction (upward amount is drawn first, then sideways).
      float upwardAmount = Random.Range( 0.0f, 5.0f );
      float sidewaysAmount = Random.Range( -2.0f, 2.0f );
      var aimDirection = toRobot + upwardAmount * upward + sidewaysAmount * sideways;

      // World direction of the goal's local -Z axis (its "green side").
      var currentFacing = m_goalGeometry.getFrame().transformVectorToWorld( -agx.Vec3.Z_AXIS() );

      // Rotation taking the current facing direction onto the perturbed aim direction,
      // applied on top of the goal's current rotation.
      var alignment = new agx.Quat( currentFacing, aimDirection );
      m_goalGeometry.setRotation( m_goalGeometry.getRotation() * alignment );

      m_lastTargetTheta = theta;
    }
  }
}