#include <fstream>

#include "rendering/Shader.h"
#include "utils/Logger.h"
#include "wrapper/GL.h"

// Creates a shader without any GL objects: all four handles start at 0 until
// compile() assigns them.
Shader::Shader()
    : vertex(0),
      geometry(0),
      fragment(0),
      program(0) {
}

// Hands every handle held by this object back to the GL wrapper: the three
// shader stages first, then the program. Handles that were never assigned by
// compile() still hold the 0 set by the constructor and are passed along too.
Shader::~Shader() {
    GL::deleteShader(this->vertex);
    GL::deleteShader(this->geometry);
    GL::deleteShader(this->fragment);
    GL::deleteProgram(this->program);
}

// Compiles the requested shader stages from files and links them into a
// program.
//
// vertexPath / geometryPath / fragmentPath: path of the source file for the
// respective stage, or nullptr to leave that stage out. Stages are compiled
// in the order vertex, geometry, fragment.
//
// Returns the first error encountered (reading or compiling a stage, the GL
// error state after linking, or the linker result). If a stage fails, the
// function returns before a new program is created, so the program of an
// earlier successful call is left untouched.
Error Shader::compile(const char* vertexPath, const char* geometryPath,
                      const char* fragmentPath) {
    if(vertexPath != nullptr) {
        Error error = compile(vertexPath, vertex, GL::VERTEX_SHADER);
        if(error.has()) {
            return error;
        }
    }
    if(geometryPath != nullptr) {
        Error error = compile(geometryPath, geometry, GL::GEOMETRY_SHADER);
        if(error.has()) {
            return error;
        }
    }
    if(fragmentPath != nullptr) {
        Error error = compile(fragmentPath, fragment, GL::FRAGMENT_SHADER);
        if(error.has()) {
            return error;
        }
    }
    // Release the program of a previous compile() call before its handle is
    // overwritten, otherwise recompiling leaks it. On the first call the
    // handle is still the 0 set by the constructor; the destructor already
    // passes such a handle to GL::deleteProgram.
    // NOTE(review): assumes GL::deleteProgram ignores 0 like glDeleteProgram.
    GL::deleteProgram(program);
    program = GL::createProgram();
    // Only attach the stages that were requested in this call.
    if(vertexPath != nullptr) {
        GL::attachShader(program, vertex);
    }
    if(geometryPath != nullptr) {
        GL::attachShader(program, geometry);
    }
    if(fragmentPath != nullptr) {
        GL::attachShader(program, fragment);
    }
    GL::linkProgram(program);
    Error error = GL::getError("cannot link");
    if(error.has()) {
        return error;
    }
    return GL::getLinkerError(program);
}

// Loads the shader source stored at `path` and compiles it as a shader of
// type `st`, storing the resulting handle in `s`.
//
// Returns the error from reading the file if there is one (in which case no
// compilation is attempted), otherwise the result of the compilation.
Error Shader::compile(const char* path, GL::Shader& s, GL::ShaderType st) {
    List<char> source;
    Error readResult = readFile(source, path);
    if(!readResult.has()) {
        return compile(s, source, st);
    }
    return readResult;
}

// Reads the whole file at `path` character by character and appends it to
// `code`, followed by a terminating '\0'.
//
// code: receives the file content; existing elements are kept and new ones
//       are appended. On failure it may contain a partial read.
// path: file to open (opened with the default std::ifstream mode).
//
// Returns an error with the message "cannot read file" if the file cannot be
// opened or if the stream reports an I/O failure while reading; otherwise an
// empty Error.
Error Shader::readFile(List<char>& code, const char* path) const {
    typedef std::ifstream::traits_type Traits;

    std::ifstream in;
    in.open(path);
    if(!in.good()) {
        return {"cannot read file"};
    }
    while(true) {
        // Compare against the stream's own end-of-file value instead of the
        // C `EOF` macro, which this file does not include a header for.
        const Traits::int_type c = in.get();
        if(Traits::eq_int_type(c, Traits::eof())) {
            break;
        }
        code.add(Traits::to_char_type(c));
    }
    // get() also yields eof() when the underlying read fails; badbit tells a
    // real I/O failure apart from a regular end of file, so a truncated
    // source is not silently handed to the compiler.
    if(in.bad()) {
        return {"cannot read file"};
    }
    code.add('\0');
    return {};
}

// Creates a shader object of type `st`, compiles `code` into it and stores
// the handle in `s`.
//
// s:    handle slot to fill; whatever handle it held before is deleted first.
// code: shader source; its begin() is passed to the GL wrapper as the source
//       text.
// st:   shader type forwarded to GL::createShader.
//
// Returns the pending GL error (tagged "compile error") if there is one,
// otherwise the result of GL::getCompileError for the new shader.
Error Shader::compile(GL::Shader& s, const List<char>& code,
                      GL::ShaderType st) {
    // Delete the shader of a previous compilation before overwriting its
    // handle, otherwise recompiling leaks it. On the first call the handle is
    // the 0 set by the constructor; the destructor already passes such
    // handles to GL::deleteShader.
    // NOTE(review): assumes GL::deleteShader ignores 0 like glDeleteShader.
    GL::deleteShader(s);
    s = GL::createShader(st);
    GL::compileShader(s, code.begin());
    Error error = GL::getError("compile error");
    if(error.has()) {
        return error;
    }
    return GL::getCompileError(s);
}

void Shader::use() const {
    GL::useProgram(program);
}

void Shader::setMatrix(const char* name, const float* data) {
    GL::setMatrix(program, name, data);
}

void Shader::setInt(const char* name, int data) {
    GL::setInt(program, name, data);
}

void Shader::setFloat(const char* name, float data) {
    GL::setFloat(program, name, data);
}

void Shader::setVector(const char* name, const Vector2& v) {
    GL::set2Float(program, name, &(v[0]));
}

void Shader::setVector(const char* name, const Vector3& v) {
    GL::set3Float(program, name, &(v[0]));
}

void Shader::setVector(const char* name, const Vector4& v) {
    GL::set4Float(program, name, &(v[0]));
}