#pragma kernel UpdateGrid
#pragma kernel ApplyParticleMass
#pragma kernel CompactFineParticles
#pragma kernel SwapParticleBuffers
#pragma kernel SpawnParticles
#pragma kernel ClearTable
#include "./HashTableRW.cginc"
#include "./ConstantsAndStructs.cginc"

// Background on structured-buffer element layout and performance:
// https://developer.nvidia.com/content/understanding-structured-buffer-performance
// NOTE(review): presumably the reason the voxel indices are int4 rather than int3 - confirm.

// One entry per active voxel; only .xyz (the integer voxel index) is read by the kernels in this file.
StructuredBuffer<int4> activeVoxelIndices;
// Read-only coarse particles; UpdateGrid loops over all of them for every active voxel.
StructuredBuffer<CoarseParticle> coarseParticles;
// Fine particle storage, read and written in place by the kernels in this file.
RWStructuredBuffer<FineParticle> fineParticles;
// Destination of CompactFineParticles; SwapParticleBuffers copies it back into fineParticles.
RWStructuredBuffer<FineParticle> fineParticlesNew;
// Numerically equal to cbrt(4*pi/3) (VOLUME_MOD^3 ~= 4.18879), i.e. the side length of a cube
// with the same volume as a unit-radius sphere. UpdateGrid uses VOLUME_MOD * radius / 2 as the
// half-extent of a coarse particle's axis-aligned box.
#define VOLUME_MOD 1.6119919540164696407169668466392849389446140723238615

// Number of valid entries in activeVoxelIndices (thread-id bound in UpdateGrid / SpawnParticles).
int numActiveVoxels;
// Number of coarse particles iterated by UpdateGrid.
int numCoarseParticles;
// Mixed into the random seed in SpawnParticles. NOTE(review): unit (frame count vs. ms) is not visible here.
int time;
// Edge length of a voxel; a voxel spans (index * voxelSize) +/- voxelSize / 2 on each axis.
float voxelSize;
// Mass represented by one fine particle; UpdateGrid requests ceil(mass / fineParticleMass) particles per voxel.
float fineParticleMass;
// ApplyParticleMass removes an over-capacity particle once |EASE| < 2 * animationSpeed.
float animationSpeed;
// Rounding radius of the "rounded cuboid" tests and the inset applied to the inner bounds.
float nominalRadius;

// Runs once per active voxel.
// Accumulates, over every coarse particle whose axis-aligned box overlaps the voxel:
//   - mass, weighted by the fraction of the particle's box volume that lies inside the voxel,
//   - velocity, as the weight-averaged particle velocity,
//   - the union of the clipped particle boxes (minBound / maxBound).
// The result is stored through InsertIndex together with "inner" bounds, which are the
// union bounds pulled in by nominalRadius on every face that does not coincide with the
// voxel face.
[numthreads(GRP_SIZE_VOXELS, 1, 1)]
void UpdateGrid (uint3 id : SV_DispatchThreadID)
{
    uint voxelSlot = id.x;
    if ((int)voxelSlot >= numActiveVoxels)
        return;

    // Voxel centre and its two extreme corners.
    float3 voxelPos = activeVoxelIndices[voxelSlot].xyz * voxelSize;
    float halfVoxel = voxelSize / 2;
    float3 voxelMin = voxelPos - halfVoxel;
    float3 voxelMax = voxelPos + halfVoxel;

    // Bounds stay at zero when no particle touches the voxel.
    float3 minBound = float3(0, 0, 0);
    float3 maxBound = float3(0, 0, 0);
    bool anyOverlap = false;

    float mass = 0;
    float totalWeight = 0;
    float3 velocity = float3(0, 0, 0);

    for (int p = 0; p < numCoarseParticles; p++)
    {
        float particleRadius = RADIUS(coarseParticles[p]);
        float halfExtent = VOLUME_MOD * particleRadius / 2;
        float3 particlePos = CP_POSITION(coarseParticles[p]);

        // Skip particles whose box does not reach into this voxel.
        float3 toEdge = abs(particlePos - voxelPos) - halfExtent;
        if (!all(halfVoxel > toEdge))
            continue;

        // Particle box clipped against the voxel.
        float3 clippedMax = min(particlePos + halfExtent, voxelMax);
        float3 clippedMin = max(particlePos - halfExtent, voxelMin);
        float3 overlap = clippedMax - clippedMin;

        // Fraction of the particle's box volume, (2 * halfExtent)^3, that lies inside the voxel.
        float weight = overlap.x * overlap.y * overlap.z / (halfExtent * halfExtent * halfExtent * 8);

        totalWeight += weight;
        mass += MASS(coarseParticles[p]) * weight;
        velocity += VELOCITY(coarseParticles[p]) * weight;

        if (anyOverlap)
        {
            minBound = min(minBound, clippedMin);
            maxBound = max(maxBound, clippedMax);
        }
        else
        {
            minBound = clippedMin;
            maxBound = clippedMax;
            anyOverlap = true;
        }
    }

    if (anyOverlap)
    {
        minBound = max(minBound, voxelMin);
        maxBound = min(maxBound, voxelMax);
    }

    // Turn the weighted velocity sum into a weighted average.
    if (totalWeight > 0.0f)
    {
        velocity /= totalWeight;
    }

    // Inset by nominalRadius only on faces that lie strictly inside the voxel.
    float3 r = float3(nominalRadius, nominalRadius, nominalRadius);
    float3 innerMaxBounds = maxBound - (maxBound < voxelMax) * r;
    float3 innerMinBounds = minBound + (minBound > voxelMin) * r;

    int numParticlesWhichFitInVoxel = ceil(mass / fineParticleMass);
    InsertIndex(activeVoxelIndices[voxelSlot].xyz, numParticlesWhichFitInVoxel, voxelPos, mass, velocity, minBound, maxBound, innerMinBounds, innerMaxBounds);
}

// Runs once per fine particle.
// Every live particle (EASE is not NaN) subtracts one unit of room from the voxel that
// contains it, then:
//   - if the voxel is over capacity (res < -EXTRA_PARTICLES_IN_VOXEL) the particle is
//     removed (EASE = NAN) when |EASE| < 2 * animationSpeed, otherwise it starts easing
//     out (EASE becomes negative);
//   - otherwise, if the particle is at least nominalRadius away from the voxel's inner
//     bounds, it starts easing out.
// A particle that is already easing out (EASE <= 0) is never negated again, so it cannot
// flip back to easing in on the next dispatch.
// Every thread that passes the bounds check also resets NEW_NUM_FINE_PARTICLES to 0.
// NOTE(review): presumably this prepares the counter for CompactFineParticles - confirm.
[numthreads(GRP_SIZE_PARTICLES, 1, 1)]
void ApplyParticleMass(uint3 id : SV_DispatchThreadID)
{
    // Runs Per fine particle
    uint myId = id.x;

    if (myId >= NUM_FINE_PARTICLES)
        return;

    
    if (!isnan(EASE(fineParticles[myId])))
    {
        int res;
        int3 voxelIndex = getVoxelIndexFromWorldPosition(FP_POSITION(fineParticles[myId]), voxelSize);
        // NOTE(review): res is assumed to report the voxel's room counter around this
        // decrement; whether it is the value before or after is defined in HashTableRW.cginc.
        AddToRoomAtIndex(voxelIndex, -1, res);
        if (res < -EXTRA_PARTICLES_IN_VOXEL)
        {
            if (abs(EASE(fineParticles[myId])) < 2 * animationSpeed)
            {
                //Remove particle
                EASE(fineParticles[myId]) = NAN;
            }
            else if (EASE(fineParticles[myId]) > 0.0f)
            {
                //Ease out particle
                EASE(fineParticles[myId]) = -EASE(fineParticles[myId]);
            }
        }
        else
        {
            // The first two outputs of LookupBounds are not needed here and are discarded.
            float3 innerMinBounds, innerMaxBounds, _;
            LookupBounds(voxelIndex, _, _, innerMinBounds, innerMaxBounds);
            float3 innerBoundsSize = innerMaxBounds - innerMinBounds;
            float3 innerBoundsSizeHalf = innerBoundsSize / 2;
            float3 innerBoundsCenter = innerMinBounds + innerBoundsSizeHalf;

            // Distance from the particle to the inner box (zero when inside it).
            float distanceToInnerBox = length(max(abs(FP_POSITION(fineParticles[myId]) - innerBoundsCenter) - innerBoundsSizeHalf, float3(0, 0, 0)));
            if (distanceToInnerBox >= nominalRadius)
            {
                // Only flip particles that are not already easing out; negating a
                // negative EASE would turn the ease-out back into an ease-in.
                if (EASE(fineParticles[myId]) > 0.0f)
                {
                    //Ease out particle
                    EASE(fineParticles[myId]) = -EASE(fineParticles[myId]);
                }
            }

        }
    }
        
    NEW_NUM_FINE_PARTICLES = 0;
}

// Runs once per fine particle.
// Copies every particle whose EASE is not NaN into fineParticlesNew, at a slot reserved by
// atomically incrementing NEW_NUM_FINE_PARTICLES. The output order therefore depends on
// the order in which threads perform the atomic add.
[numthreads(GRP_SIZE_PARTICLES, 1, 1)]
void CompactFineParticles(uint3 id : SV_DispatchThreadID)
{
    uint srcIndex = id.x;

    // Past the end of the live range: nothing to copy.
    if (srcIndex >= NUM_FINE_PARTICLES)
        return;

    // NaN ease marks a removed particle; it is dropped by not copying it.
    if (isnan(EASE(fineParticles[srcIndex])))
        return;

    uint dstIndex;
    InterlockedAdd(NEW_NUM_FINE_PARTICLES, 1, dstIndex);
    fineParticlesNew[dstIndex] = fineParticles[srcIndex];
}

// Copies the compacted particles from fineParticlesNew back into fineParticles.
// Thread 0 additionally publishes the new particle count and the matching number of
// particle thread groups, ceil(count / GRP_SIZE_PARTICLES).
// The group count uses a true ceiling division so that it agrees with SpawnParticles,
// which adds one group each time a particle lands on an index that is a multiple of
// GRP_SIZE_PARTICLES. (count / GRP_SIZE_PARTICLES + 1 would report one group too many
// whenever count is a non-zero exact multiple of the group size.)
[numthreads(GRP_SIZE_PARTICLES, 1, 1)]
void SwapParticleBuffers(uint3 id : SV_DispatchThreadID)
{
    uint myId = id.x;
    if (myId == 0)
    {
        NUM_FINE_PARTICLES = NEW_NUM_FINE_PARTICLES;

        // Ceiling division; yields 0 groups for 0 particles.
        const uint groupSize = GRP_SIZE_PARTICLES;
        NUM_PARTICLE_THREAD_GROUPS = (NUM_FINE_PARTICLES + groupSize - 1) / groupSize;
    }
    if (myId >= NEW_NUM_FINE_PARTICLES)
        return;
    
    fineParticles[myId] = fineParticlesNew[myId];
}

// Integer bit mixer: scrambles a 32-bit value through alternating odd multiplications and
// xor-shifts. All arithmetic wraps modulo 2^32.
// x + (x << k) is written as x * (2^k + 1), which is the same value modulo 2^32.
uint hash(uint x)
{
    x *= 1025u;          // x += x << 10
    x ^= (x >> 6u);
    x *= 9u;             // x += x << 3
    x ^= (x >> 11u);
    x *= 32769u;         // x += x << 15
    return x;
}

// Builds a float in the half-open range [0, 1) from the low 23 bits of m.
// The bits are used as the mantissa of a float whose exponent is that of 1.0, which gives
// a value in [1, 2); subtracting 1.0 shifts it to [0, 1).
// All-zero bits yield 0.0; all-one bits yield the largest value this construction can
// produce, which is still below 1.0.
float floatConstruct(uint m)
{
    const uint mantissaMask = 0x007FFFFFu; // binary32 mantissa bits
    const uint oneBits = 0x3F800000u;      // bit pattern of 1.0 in IEEE binary32

    uint bits = (m & mantissaMask) | oneBits;
    return asfloat(bits) - 1.0;
}

// Deterministic pseudo-random float in [0, 1) derived from x: the value is scrambled with
// hash() and the low 23 bits of the result are turned into a float by floatConstruct().
float random(uint x)
{
    uint scrambled = hash(x);
    return floatConstruct(scrambled);
}

// Packs three booleans into the low three bits of a uint:
// b0 -> bit 0 (value 1), b1 -> bit 1 (value 2), b2 -> bit 2 (value 4).
uint convertBitsToUint(bool b0, bool b1, bool b2)
{
    return (b0 ? 1u : 0u)
         | (b1 ? 2u : 0u)
         | (b2 ? 4u : 0u);
}

// Runs once per active voxel.
// Reads the voxel's remaining particle room and bounds, then makes particleRoom spawn
// attempts. Each attempt picks a random point inside [minBounds, maxBounds] and keeps it
// only if it lies within nominalRadius of the inner box (the "rounded cuboid"). Rejected
// attempts are not retried, so fewer than particleRoom particles may be created.
// Accepted particles are appended to fineParticles through an atomic increment of
// NUM_FINE_PARTICLES; NUM_PARTICLE_THREAD_GROUPS is bumped each time an append lands on an
// index that is a multiple of GRP_SIZE_PARTICLES.
// NOTE(review): appended indices are not checked against the capacity of fineParticles here.
[numthreads(GRP_SIZE_VOXELS, 1, 1)]
void SpawnParticles(uint3 id : SV_DispatchThreadID)
{
    uint myId = id.x;

    if ((int)myId >= numActiveVoxels)
        return;
    int3 myVoxelIndex = activeVoxelIndices[myId].xyz;

    int particleRoom;
    float3 minBounds;
    float3 maxBounds;
    float3 innerMinBounds;
    float3 innerMaxBounds;
    
    LookupRoomAndBounds(myVoxelIndex.xyz, particleRoom, minBounds, maxBounds, innerMinBounds, innerMaxBounds);

    // UpdateGrid stores all-zero bounds for voxels that no coarse particle overlaps.
    if (all(maxBounds == 0) && all(minBounds == 0))
        return;

    float3 boundsSize = maxBounds - minBounds;
    float3 innerBoundsSize = innerMaxBounds - innerMinBounds;
    float3 innerBoundsSizeHalf = innerBoundsSize / 2;
    float3 innerBoundsCenter = innerMinBounds + innerBoundsSizeHalf;

    // Per-voxel, per-dispatch seed base. Hashing both the voxel slot and the time keeps
    // different voxels, and different time values, from sharing a random sequence.
    // (An additive seed such as time + myId * 1000 + particleRoom * 1000 depends only on
    // myId + particleRoom, so many voxels would reuse identical offsets.)
    uint voxelSeed = hash(myId ^ hash((uint)time));
    
    uint prevNumParticles;
    while (particleRoom > 0)
    {
        // Each attempt consumes up to four consecutive seeds (x, y, z, RAND), so a stride
        // of four keeps the attempts of one voxel from overlapping.
        uint seed = voxelSeed + (uint)particleRoom * 4u;
        
        // Randomise point within bounds
        float3 offset;
        offset.x = random(seed++) * boundsSize.x;
        offset.y = random(seed++) * boundsSize.y;
        offset.z = random(seed++) * boundsSize.z;
        float3 randomPos = minBounds + offset;                
        
        //Check if point is within the rounded cuboid
        if (length(max(abs(randomPos - innerBoundsCenter) - innerBoundsSizeHalf, float3(0, 0, 0))) < nominalRadius)
        {
            InterlockedAdd(NUM_FINE_PARTICLES, 1, prevNumParticles);
            // A new thread group is needed whenever the appended index starts a new group.
            if (prevNumParticles % GRP_SIZE_PARTICLES == 0)
                InterlockedAdd(NUM_PARTICLE_THREAD_GROUPS, 1);
            FP_POSITION(fineParticles[prevNumParticles]) = randomPos;
            EASE(fineParticles[prevNumParticles]) = 0.0f;
            RAND(fineParticles[prevNumParticles]) = random(seed++);
        }
        
        // Subtract particleRoom even if no particle is spawned
        particleRoom -= 1;
    }
}
