#pragma kernel MoveParticles
#include "./HashTableRead.cginc"
#include "./ConstantsAndStructs.cginc"

// Fine particle buffer. MoveParticles reads each particle's position and ease value
// and writes the updated values back in place (one thread per element).
RWStructuredBuffer<FineParticle> fineParticles;

// Factor the interpolated velocity is multiplied by when a particle's position is advanced.
// NOTE(review): presumably the simulation time step in seconds - confirm against the dispatching script.
float timeStep;
// Voxel size: passed to getVoxelIndexFromWorldPosition and used (squared) to normalise the
// particle-to-voxel distance in the interpolation weight.
float voxelSize;
// Amount added to every particle's ease value each time the kernel runs.
float animationSpeed;

// Scrambles a 32-bit integer with an alternating add-shift / xor-shift cascade.
// All arithmetic wraps modulo 2^32; an input of 0 maps to 0.
// (The shift pattern looks like a single round of Bob Jenkins' one-at-a-time hash.)
uint hash(uint x)
{
    uint mixed = x;
    mixed = mixed + (mixed << 10u);
    mixed = mixed ^ (mixed >> 6u);
    mixed = mixed + (mixed << 3u);
    mixed = mixed ^ (mixed >> 11u);
    mixed = mixed + (mixed << 15u);
    return mixed;
}

// Builds a float in the half-open interval [0, 1) from the low 23 bits of m.
// All-zero bits give 0.0; all-one bits give the largest representable value below 1.0.
float floatConstruct(uint m)
{
    const uint mantissaMask = 0x007FFFFFu; // the 23 fraction bits of an IEEE binary32
    const uint onePattern = 0x3F800000u; // bit pattern of 1.0 in IEEE binary32

    // Splice the fraction bits under the exponent of 1.0, which yields a value in [1, 2) ...
    uint bits = (m & mantissaMask) | onePattern;

    // ... then shift that interval down to [0, 1).
    return asfloat(bits) - 1.0;
}

// Deterministic pseudo-random float in [0, 1) derived from an integer seed:
// the seed is scrambled by hash() and its low 23 bits become the fraction of the result.
float random(uint x)
{
    uint scrambled = hash(x);
    return floatConstruct(scrambled);
}

// Moves every fine particle along a velocity interpolated from the surrounding voxels and
// advances its ease value.
//
// One thread handles one particle; id.x is the particle index into fineParticles.
// For each particle the kernel:
//   1. looks up the particle's own voxel and its six axis neighbours,
//   2. blends the velocities of the voxels found by LookupPosAndVel, weighted by a
//      distance falloff that reaches zero one voxel size away,
//   3. advances the position by the blended velocity times timeStep,
//   4. adds animationSpeed to the ease value, clamps it to at most 1.0, and writes NAN
//      for particles whose ease is still negative while they are practically not moving.
[numthreads(GRP_SIZE_PARTICLES, 1, 1)]
void MoveParticles(uint3 id : SV_DispatchThreadID)
{
    // Runs per fine particle; threads past the end of the particle range have nothing to do.
    uint fpIndex = id.x;
    if (fpIndex >= NUM_FINE_PARTICLES)
    {
        return;
    }

    // Read the position once instead of re-fetching it from the buffer for every neighbour.
    float3 particlePos = FP_POSITION(fineParticles[fpIndex]);
    int3 particleVoxelIndex = getVoxelIndexFromWorldPosition(particlePos, voxelSize);
    float voxelSizeSq = voxelSize * voxelSize;

    float weightSum = 0.0f;
    float3 velocitySum = float3(0, 0, 0);

    for (int i = -1; i < 2; i++)
    {
        for (int j = -1; j < 2; j++)
        {
            for (int k = -1; k < 2; k++)
            {
                // Reduce checked voxels to the particle's own voxel plus one voxel in each
                // direction of each axis (7 voxels). Testing the Manhattan length of the offset
                // is required here: a plain `i + j + k != 0` centre test is also false for edge
                // diagonals such as (1, -1, 0) and would let six extra voxels through.
                if (abs(i) + abs(j) + abs(k) > 1)
                    continue;

                int3 voxelIndex = particleVoxelIndex + int3(i, j, k);
                float3 voxelPos, voxelVel;
                if (LookupPosAndVel(voxelIndex, voxelPos, voxelVel))
                {
                    // Weight = (1 - d2 / voxelSize^2)^2, clamped so it never goes negative
                    // (d2 = distance2(...), presumably the squared distance - confirm in the include).
                    // Squaring by multiplication avoids a general pow() for a fixed exponent of 2.
                    float falloff = max(0.0f, 1.0f - (distance2(voxelPos, particlePos) / voxelSizeSq));
                    float voxelWeight = falloff * falloff;
                    weightSum += voxelWeight;
                    velocitySum += voxelWeight * voxelVel;
                }
                // Voxels for which the lookup fails are deliberately NOT counted towards the total
                // weight: doing so makes a more realistic velocity field but seems to lead to
                // particles falling behind and freezing in inactive voxels.
            }
        }
    }

    // Normalise the weighted sum; with no contributing voxel the particle keeps zero velocity.
    float3 velocity = float3(0.0f, 0.0f, 0.0f);
    if (weightSum != 0.0f)
    {
        velocity = (1.0f / weightSum) * velocitySum;
    }

    // Random jitter is currently disabled. To re-enable, use e.g.
    // float3(random(hash(fpIndex + 1)) * 0.02f - 0.01f, random(hash(fpIndex + 2)) * 0.02f - 0.01f, random(hash(fpIndex + 3)) * 0.02f - 0.01f)
    float3 randomOffset = float3(0.0f, 0.0f, 0.0f);

    FP_POSITION(fineParticles[fpIndex]) += (velocity + randomOffset) * timeStep;

    // Clamp ease to be at most 1.0 and, to avoid dead particles hanging in mid air, mark particles
    // with negative ease and negligible velocity (not counting random offset) for removal next synch.
    // The velocity test uses abs(): comparing the signed components would also flag particles that
    // are moving fast in negative directions.
    float ease = EASE(fineParticles[fpIndex]) + animationSpeed;
    bool isStalled = all(abs(velocity) < 0.001f);
    EASE(fineParticles[fpIndex]) = (ease < 0 && isStalled) ? NAN : min(1.0, ease);
}