﻿using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Builds a single flat grid mesh ("Center") as a child of this object, binds the wave cascade
/// textures from <see cref="WavesGenerator"/> to the ocean material, and each frame scales and
/// places the grid so it spans [-lengthScale, lengthScale] on X and Z.
/// </summary>
public class OceanGeometry : MonoBehaviour
{
    [SerializeField] 
    WavesGenerator wavesGenerator;
    [SerializeField]
    Material oceanMaterial;
    // NOTE(review): not read anywhere in this class.
    [SerializeField]
    bool updateMaterialProperties;
    // NOTE(review): not read anywhere in this class.
    [SerializeField]
    bool showMaterialLods;

    // Half-extent of the generated plane on X and Z (applied via the child transform).
    [SerializeField]
    float lengthScale = 10;
    // Grid resolution control; the plane has 2 * (4 * vertexDensity + 1) cells per side.
    [SerializeField, Range(1, 400)]
    int vertexDensity = 30;
    Element center;
    // NOTE(review): initialised in Start but not read anywhere in this class.
    Quaternion[] trimRotations;

    // Vertex density the current mesh was built with; compared every frame so inspector
    // changes to vertexDensity trigger a rebuild.
    int previousVertexDensity;

    /// <summary>Returns the serialized length scale (half-extent of the plane on X and Z).</summary>
    public float getLengthScale()
    {
        return lengthScale;
    }
  
    private void Start()
    {
        oceanMaterial.SetTexture("_Displacement_c0", wavesGenerator.cascade0.Displacement);
        oceanMaterial.SetTexture("_Derivatives_c0", wavesGenerator.cascade0.Derivatives);
        oceanMaterial.SetTexture("_Turbulence_c0", wavesGenerator.cascade0.Turbulence);

        oceanMaterial.SetTexture("_Displacement_c1", wavesGenerator.cascade1.Displacement);
        oceanMaterial.SetTexture("_Derivatives_c1", wavesGenerator.cascade1.Derivatives);
        oceanMaterial.SetTexture("_Turbulence_c1", wavesGenerator.cascade1.Turbulence);

        oceanMaterial.SetTexture("_Displacement_c2", wavesGenerator.cascade2.Displacement);
        oceanMaterial.SetTexture("_Derivatives_c2", wavesGenerator.cascade2.Derivatives);
        oceanMaterial.SetTexture("_Turbulence_c2", wavesGenerator.cascade2.Turbulence);

//        oceanMaterial.EnableKeyword("CLOSE");

        trimRotations = new Quaternion[]
        {
            Quaternion.AngleAxis(180, Vector3.up),
            Quaternion.AngleAxis(90, Vector3.up),
            Quaternion.AngleAxis(270, Vector3.up),
            Quaternion.identity,
        };

        // Records previousVertexDensity, so the first Update does not rebuild again.
        InstantiateMeshes();
    }

    private void Update()
    {
        // The plane mesh is always built with unit cell size; lengthScale is applied through the
        // transform in UpdatePositions. Only a vertexDensity change needs a new mesh.
        if (previousVertexDensity != vertexDensity)
            InstantiateMeshes();

        UpdatePositions();
    }

    private void OnDestroy()
    {
        // Meshes created with `new Mesh()` are not freed together with their GameObject.
        ReleaseCenterMesh();
    }

    /// <summary>
    /// Places and scales the center plane. The mesh is 2 * GridSize() unit cells wide; scaling by
    /// lengthScale / GridSize() makes it 2 * lengthScale wide, and the world position of
    /// (-lengthScale, 0, -lengthScale) centres it on the world origin (for an unscaled parent).
    /// </summary>
    void UpdatePositions()
    {
        float cellSize = lengthScale / GridSize();
        center.Transform.position = new Vector3(-lengthScale, 0, -lengthScale);
        center.Transform.localScale = new Vector3(cellSize, 1, cellSize);
    }

    /// <summary>Number of grid cells per half-side of the plane: 4 * vertexDensity + 1.</summary>
    int GridSize()
    {
        return 4 * vertexDensity + 1;
    }

    /// <summary>
    /// Destroys all existing children (and the previously generated mesh) and creates a fresh
    /// "Center" plane of 2 * GridSize() by 2 * GridSize() unit cells.
    /// </summary>
    void InstantiateMeshes()
    {
        ReleaseCenterMesh();

        // Destroy is deferred to the end of the frame, so enumerating the hierarchy here is safe.
        // Iterating direct children also catches inactive ones; their descendants go with them.
        foreach (Transform child in transform)
            Destroy(child.gameObject);

        int k = GridSize();
        center = InstantiateElement("Center", CreatePlaneMesh(2 * k, 2 * k, 1), oceanMaterial);
        previousVertexDensity = vertexDensity;
    }

    /// <summary>Destroys the runtime-created mesh owned by the current center element, if any.</summary>
    void ReleaseCenterMesh()
    {
        if (center != null && center.Mesh != null)
        {
            Destroy(center.Mesh);
            center.Mesh = null;
        }
    }

    /// <summary>
    /// Creates a child GameObject named <paramref name="name"/> that renders <paramref name="mesh"/>
    /// with <paramref name="mat"/>. Shadow casting is off and shadow receiving is on.
    /// </summary>
    Element InstantiateElement(string name, Mesh mesh, Material mat)
    {
        GameObject go = new GameObject(name);
        go.transform.SetParent(transform);
        go.transform.localPosition = Vector3.zero;

        MeshFilter meshFilter = go.AddComponent<MeshFilter>();
        // sharedMesh: assign the generated mesh directly rather than through the instancing accessor.
        meshFilter.sharedMesh = mesh;

        MeshRenderer meshRenderer = go.AddComponent<MeshRenderer>();
        meshRenderer.shadowCastingMode = UnityEngine.Rendering.ShadowCastingMode.Off;
        meshRenderer.receiveShadows = true;
        meshRenderer.motionVectorGenerationMode = MotionVectorGenerationMode.Camera;
        // sharedMaterial: render with oceanMaterial itself, so later property changes to it apply.
        meshRenderer.sharedMaterial = mat;
        meshRenderer.allowOcclusionWhenDynamic = false;
        return new Element(go.transform, meshRenderer, mesh);
    }

    /// <summary>
    /// Builds a flat grid in the XZ plane with <paramref name="width"/> x <paramref name="height"/>
    /// quads, vertices at (x, 0, z) * <paramref name="cellSize"/>, and all normals pointing up.
    /// The diagonal of each quad alternates in a checkerboard pattern, offset by
    /// <paramref name="trianglesShift"/>.
    /// </summary>
    Mesh CreatePlaneMesh(int width, int height, float cellSize, int trianglesShift = 0)
    {
        Mesh mesh = new Mesh();
        mesh.name = "Clipmap plane";

        int rowLength = width + 1;
        int vertexCount = rowLength * (height + 1);
        // 16-bit indices can address at most 65536 vertices.
        if (vertexCount >= 256 * 256)
            mesh.indexFormat = UnityEngine.Rendering.IndexFormat.UInt32;

        Vector3[] vertices = new Vector3[vertexCount];
        Vector3[] normals = new Vector3[vertexCount];
        int[] triangles = new int[width * height * 2 * 3];

        for (int z = 0; z <= height; z++)
        {
            for (int x = 0; x <= width; x++)
            {
                int index = x + z * rowLength;
                vertices[index] = new Vector3(x, 0, z) * cellSize;
                normals[index] = Vector3.up;
            }
        }

        int tris = 0;
        for (int z = 0; z < height; z++)
        {
            for (int x = 0; x < width; x++)
            {
                // Corners of the quad: v00 = (x, z), v10 = (x + 1, z), v01 = (x, z + 1), v11 = (x + 1, z + 1).
                int v00 = x + z * rowLength;
                int v10 = v00 + 1;
                int v01 = v00 + rowLength;
                int v11 = v01 + 1;

                if ((z + x + trianglesShift) % 2 == 0)
                {
                    // Diagonal v00 -> v11.
                    triangles[tris++] = v00;
                    triangles[tris++] = v01;
                    triangles[tris++] = v11;

                    triangles[tris++] = v00;
                    triangles[tris++] = v11;
                    triangles[tris++] = v10;
                }
                else
                {
                    // Diagonal v10 -> v01.
                    triangles[tris++] = v00;
                    triangles[tris++] = v01;
                    triangles[tris++] = v10;

                    triangles[tris++] = v10;
                    triangles[tris++] = v01;
                    triangles[tris++] = v11;
                }
            }
        }

        mesh.vertices = vertices;
        mesh.triangles = triangles;
        mesh.normals = normals;
        return mesh;
    }

    /// <summary>A generated child: its transform, renderer, and the runtime mesh it owns.</summary>
    class Element
    {
        public Transform Transform;
        public MeshRenderer MeshRenderer;
        // Mesh created by CreatePlaneMesh; must be destroyed explicitly when the element is replaced.
        public Mesh Mesh;

        public Element(Transform transform, MeshRenderer meshRenderer, Mesh mesh = null)
        {
            Transform = transform;
            MeshRenderer = meshRenderer;
            Mesh = mesh;
        }
    }
}


