#include "SSAOShader.h"
#include "../Wrapper.h"
#include <random>

/**
 * Construct an SSAO shader with no GL objects yet; call init() to create them.
 *
 * The GL object names are zeroed so that destroying a shader whose init()
 * was never called (or failed part-way) deletes nothing: glDelete* calls
 * silently ignore the name 0.
 */
SSAOShader::SSAOShader()
{
    framebuffer = 0;
    texture = 0;
    noiseTexture = 0;
}

/**
 * Release the GL objects owned by this shader: the SSAO framebuffer,
 * then the SSAO output texture and the noise texture.
 */
SSAOShader::~SSAOShader()
{
    glDeleteFramebuffers(1, &framebuffer);

    // both textures are released in a single call, output texture first
    const GLuint ownedTextures[] = { texture, noiseTexture };
    glDeleteTextures(2, ownedTextures);
}

bool SSAOShader::init()
{
    program.compile("shader/ssaoVertex.vs", "shader/ssaoFragment.fs");
    if(!program.isValid())
    {
        return false;
    }
    
    // generate framebuffer
    glGenFramebuffers(1, &framebuffer);
    glBindFramebuffer(GL_FRAMEBUFFER, framebuffer);
    
    // color texture
    glGenTextures(1, &texture);
    glBindTexture(GL_TEXTURE_2D, texture);
    glTexImage2D(GL_TEXTURE_2D, 0, GL_RED, Engine::getWidth(), Engine::getHeight(), 0, GL_RGB, GL_FLOAT, NULL);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
    // attache color texture to framebuffer
    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, texture, 0);  
    
    // check if framebuffer is okay
    if(glCheckFramebufferStatus(GL_FRAMEBUFFER) != GL_FRAMEBUFFER_COMPLETE)
    {
	cout << "ssao frame buffer is not complete!" << endl;
        return false;
    }
    // unbind framebuffer
    glBindFramebuffer(GL_FRAMEBUFFER, 0);
    
    // generate noise data
    std::uniform_real_distribution<float> randomF(0.0, 1.0);
    std::default_random_engine gen;
    float noise[48];
    for(int i = 0; i < 16; i++)
    {
        noise[i * 3] = randomF(gen) * 2.0 - 1.0;
        noise[i * 3 + 1] = randomF(gen) * 2.0 - 1.0;
        noise[i * 3 + 2] = 0.0f;
    }  
    
    // noise texture
    glGenTextures(1, &noiseTexture);
    glBindTexture(GL_TEXTURE_2D, noiseTexture);
    glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB16F, 4, 4, 0, GL_RGB, GL_FLOAT, noise);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT); 
    
    // get uniform locations
    unifProjMatrix = glGetUniformLocation(program.getProgram(), "projMatrix");
    unifNumberOfSamples = glGetUniformLocation(program.getProgram(), "numberOfSamples");
    unifRadius = glGetUniformLocation(program.getProgram(), "radius");
    unifWidth = glGetUniformLocation(program.getProgram(), "width");
    unifHeight = glGetUniformLocation(program.getProgram(), "height");
    
    return true;
}

/**
 * Reallocate the SSAO output texture at the current engine window size.
 *
 * The framebuffer attachment refers to the texture object, so only the
 * texture image needs to be re-specified. The previous contents are discarded.
 * Leaves the SSAO texture bound to GL_TEXTURE_2D on the active unit.
 */
void SSAOShader::resize()
{
    glBindTexture(GL_TEXTURE_2D, texture);
    // transfer format GL_RED matches the single-channel GL_RED storage
    glTexImage2D(GL_TEXTURE_2D, 0, GL_RED, Engine::getWidth(), Engine::getHeight(), 0, GL_RED, GL_FLOAT, nullptr);
}

void SSAOShader::preRender(const float* projMatrix)
{
    // bind ssao shader program
    glUseProgram(program.getProgram());
    
    // set projection matrix uniform
    glUniformMatrix4fv(unifProjMatrix, 1, 0, projMatrix);
    // set other uniforms
    glUniform1i(unifNumberOfSamples, numberOfSamples);
    glUniform1f(unifRadius, radius);
    glUniform1i(unifWidth, Engine::getWidth());
    glUniform1i(unifHeight, Engine::getHeight());
    
    // bind ssao framebuffer
    glBindFramebuffer(GL_FRAMEBUFFER, framebuffer);
    
    // clear color buffer
    glClear(GL_COLOR_BUFFER_BIT);
    
    // depth testing is not needed
    glDisable(GL_DEPTH_TEST);
}

void SSAOShader::bindTexture(unsigned int textureUnit)
{
    glActiveTexture(GL_TEXTURE0 + textureUnit);
    glBindTexture(GL_TEXTURE_2D, texture);
}

void SSAOShader::bindNoiseTexture(unsigned int textureUnit)
{
    glActiveTexture(GL_TEXTURE0 + textureUnit);
    glBindTexture(GL_TEXTURE_2D, noiseTexture);
}

/**
 * Set how many samples the SSAO pass uses.
 *
 * @param amount requested sample count; clamped into [0, 64].
 */
void SSAOShader::setNumberOfSamples(int amount)
{
    int clamped = amount;
    if(clamped < 0)
    {
        clamped = 0;
    }
    else if(clamped > 64)
    {
        clamped = 64;
    }
    numberOfSamples = clamped;
}

/// @return the sample count uploaded to the shader by preRender().
int SSAOShader::getNumberOfSamples() const
{
    const int samples = numberOfSamples;
    return samples;
}

void SSAOShader::setSampleRadius(float sampleRadius)
{
    numberOfSamples = min(max(sampleRadius, 0.05f), 20.0f);
}

/// @return the sampling radius uploaded to the shader by preRender().
float SSAOShader::getSampleRadius() const
{
    const float sampleRadius = radius;
    return sampleRadius;
}