#include <fstream>
#include <iostream>

#include "wrapper/Shader.h"
#include "wrapper/GL.h"

// Loads, compiles and links a vertex/fragment shader pair from the given files.
// Failures are printed to stdout and the constructor returns with `program`
// left at (or reset to) 0, so hasError() reports the problem.
Shader::Shader(const char* vertexPath, const char* fragmentPath) : vertexShader(0), fragmentShader(0), program(0) {
    if(readFileAndCompile(vertexPath, vertexShader, GL_VERTEX_SHADER) ||
            readFileAndCompile(fragmentPath, fragmentShader, GL_FRAGMENT_SHADER)) {
        // program was never created and is still 0.
        return;
    }
    program = glCreateProgram();
    glAttachShader(program, vertexShader);
    glAttachShader(program, fragmentShader);
    glLinkProgram(program);

    bool linkFailed = GL::checkAndPrintError("cannot link");
    if(!linkFailed) {
        // Initialized so that a query which does not write the value can never
        // be mistaken for a successful link.
        GLint linked = GL_FALSE;
        glGetProgramiv(program, GL_LINK_STATUS, &linked);
        if(!linked) {
            ErrorLog log;
            glGetProgramInfoLog(program, log.getLength(), nullptr, log.begin());
            std::cout << "linker log: " << log.begin() << "\n";
            linkFailed = true;
        }
    }
    if(linkFailed) {
        // An unlinked program is unusable: release it and reset the name to 0
        // so hasError() (which checks for 0 names) reports the failure.
        glDeleteProgram(program);
        program = 0;
    }
}

// Releases the GL objects held by this Shader: vertex shader, fragment shader,
// then the program. Per the OpenGL reference pages, a name of 0 is silently
// ignored by glDeleteShader/glDeleteProgram, so partially-constructed objects
// are safe to destroy.
// NOTE(review): this class defines a destructor but no copy/move operations in
// this file; if copies are allowed by the header, two objects would delete the
// same GL names — confirm copying is deleted.
Shader::~Shader() {
    const GLuint stages[] = {vertexShader, fragmentShader};
    for(const GLuint stage : stages) {
        glDeleteShader(stage);
    }
    glDeleteProgram(program);
}

bool Shader::readFileAndCompile(const char* path, GLuint& shader, GLenum shaderType) {
    std::cout << "shader: " << path << '\n';
    Code code;
    if(readFile(code, path)) {
        return true;
    }
    return compile(shader, code, shaderType);
}

// Reads the whole file at `path` into `code` as a null-terminated string.
// Returns true on failure (file cannot be opened, a read error occurs, or the
// file does not fit in the buffer), false on success.
bool Shader::readFile(Code& code, const char* path) const {
    std::ifstream in;
    in.open(path);
    if(!in.good()) {
        std::cout << "cannot read file\n";
        return true;
    }
    // Keep one slot for the terminating '\0': compile() passes the buffer to
    // glShaderSource with a null length array, which expects a terminated string.
    // read() (unlike get(..., EOF)) has no delimiter, so a 0xFF byte no longer
    // ends the read early.
    const std::streamsize capacity = static_cast<std::streamsize>(code.getLength()) - 1;
    in.read(code.begin(), capacity);
    const std::streamsize count = in.gcount();
    code.begin()[count] = '\0';
    if(in.bad()) {
        std::cout << "cannot read file\n";
        return true;
    }
    // A full buffer with more data still pending means the source was truncated;
    // report it instead of compiling a cut-off shader.
    if(count == capacity && in.peek() != std::ifstream::traits_type::eof()) {
        std::cout << "shader file is too long\n";
        return true;
    }
    return false;
}

// Returns true when any of the three GL object names (vertex shader,
// fragment shader, program) is still 0.
bool Shader::hasError() const {
    const bool allCreated = vertexShader != 0 && fragmentShader != 0 && program != 0;
    return !allCreated;
}

// Creates a shader object of `shaderType` in `shader`, uploads the
// null-terminated source in `code` and compiles it.
// Returns true on failure (a GL error, or a failed compile, in which case the
// info log is printed to stdout), false on success. `shader` keeps the created
// name even on failure.
bool Shader::compile(GLuint& shader, const Code& code, GLenum shaderType) {
    shader = glCreateShader(shaderType);
    const GLchar* buffer = code.begin();
    // Length array is nullptr: GL reads `buffer` up to its '\0'.
    glShaderSource(shader, 1, &buffer, nullptr);
    glCompileShader(shader);
    if(GL::checkAndPrintError("compile error")) {
        return true;
    }
    // Initialized so that a query which does not write the value is treated
    // as a failed compile rather than reading an indeterminate value.
    GLint compiled = GL_FALSE;
    glGetShaderiv(shader, GL_COMPILE_STATUS, &compiled);
    if(!compiled) {
        ErrorLog log;
        glGetShaderInfoLog(shader, log.getLength(), nullptr, log.begin());
        std::cout << "compiler log: " << log.begin() << "\n";
        return true;
    }
    return false;
}

void Shader::use() const {
    glUseProgram(program);
}

void Shader::setMatrix(const GLchar* name, const GLfloat* data) {
    glUniformMatrix4fv(glGetUniformLocation(program, name), 1, GL_TRUE, data);
}

void Shader::setInt(const GLchar* name, GLint data) {
    glUniform1i(glGetUniformLocation(program, name), data);
}

void Shader::setFloat(const GLchar* name, GLfloat data) {
    glUniform1f(glGetUniformLocation(program, name), data);
}