#include <cmath>

#include "math/Matrix.h"

// Default-constructs the 4x4 identity matrix: every row is zeroed and then
// its diagonal entry is set to 1.
Matrix::Matrix() {
    for (int row = 0; row < 4; row++) {
        data[row] = Vector4(0.0f, 0.0f, 0.0f, 0.0f);
        data[row][row] = 1.0f;
    }
}

// Replaces row `index` with `v` and returns *this so calls can be chained.
// No bounds check: `index` must be in [0, 3].
Matrix& Matrix::set(int index, const Vector4& v) {
    Vector4& row = data[index];
    row = v;
    return *this;
}

// Returns a new matrix whose row i is column i of this matrix.
// *this is not modified.
Matrix Matrix::transpose() {
    Matrix result;
    for (int row = 0; row < 4; row++) {
        result.set(row, Vector4(data[0][row], data[1][row], data[2][row], data[3][row]));
    }
    return result;
}

const float* Matrix::getValues() const {
    return &(data[0][0]);
}

// In-place product: *this = *this * m (row-major, so row i of the result is
// the combination of m's rows weighted by the entries of row i of *this).
// Safe for self-multiplication (a *= a).
Matrix& Matrix::operator*=(const Matrix& m) {
    if (&m == this) {
        // Each row below reads every row of m, but the loop overwrites rows of
        // *this as it goes. When m aliases *this, later rows would read
        // already-updated rows of m. Multiply by an unchanging snapshot instead.
        const Matrix snapshot = m;
        return *this *= snapshot;
    }
    for (int row = 0; row < 4; row++) {
        // The whole right-hand side is evaluated before the assignment, so
        // reading data[row] on the right while assigning data[row] is fine.
        data[row] = data[row][0] * m.data[0] + data[row][1] * m.data[1] + data[row][2] * m.data[2] + data[row][3] * m.data[3];
    }
    return *this;
}

// Returns the product (*this) * other without modifying either operand.
// Copies *this and delegates the arithmetic to the in-place operator*=.
Matrix Matrix::operator*(const Matrix& other) const {
    Matrix product(*this);
    product *= other;
    return product;
}

// Transforms point v as a homogeneous vector (x, y, z, 1), then divides the
// resulting x, y, z by the resulting w. There is no guard against w == 0.
Vector3 Matrix::operator*(const Vector3& v) const {
    Vector4 homogeneous(v[0], v[1], v[2], 1.0f);
    const float invW = 1.0f / data[3].dot(homogeneous);
    Vector3 transformed(data[0].dot(homogeneous), data[1].dot(homogeneous), data[2].dot(homogeneous));
    return transformed * invW;
}

// Multiplies rows 0..2 by v[0..2], i.e. pre-multiplies by
// diag(v[0], v[1], v[2], 1). Row 3 is left untouched. Returns *this.
Matrix& Matrix::scale(const Vector3& v) {
    for (int axis = 0; axis < 3; axis++) {
        data[axis] *= v[axis];
    }
    return *this;
}

// Uniform scale: scales rows 0..2 by the same factor s.
Matrix& Matrix::scale(float s) {
    const Vector3 uniform(s, s, s);
    return scale(uniform);
}

// Applies a translation by v, one axis at a time (x, then y, then z).
// Returns *this.
Matrix& Matrix::translate(const Vector3& v) {
    translateX(v[0]);
    translateY(v[1]);
    translateZ(v[2]);
    return *this;
}

// Adds tx times row 3 to row 0, which pre-multiplies by a translation of
// tx along x. Returns *this.
Matrix& Matrix::translateX(float tx) {
    const Vector4 shift = tx * data[3];
    data[0] += shift;
    return *this;
}

// Adds ty times row 3 to row 1, which pre-multiplies by a translation of
// ty along y. Returns *this.
Matrix& Matrix::translateY(float ty) {
    const Vector4 shift = ty * data[3];
    data[1] += shift;
    return *this;
}

// Adds tz times row 3 to row 2, which pre-multiplies by a translation of
// tz along z. Returns *this.
Matrix& Matrix::translateZ(float tz) {
    const Vector4 shift = tz * data[3];
    data[2] += shift;
    return *this;
}

// Overwrites this matrix with a pure translation by v, discarding its previous
// contents: identity with v in the last column of rows 0..2. Returns *this.
Matrix& Matrix::translateTo(const Vector3& v) {
    *this = Matrix();
    for (int axis = 0; axis < 3; axis++) {
        data[axis][3] = v[axis];
    }
    return *this;
}

// Pre-multiplies by a rotation of `degrees` in the plane spanned by axes a and b:
//   row a <- cos * row a - sin * row b
//   row b <- sin * row a + cos * row b   (using the old row a)
// Returns *this.
Matrix& Matrix::rotate(float degrees, int a, int b) {
    // Use only ISO C++: sincosf is a GNU extension and M_PI is not standard
    // (MSVC only defines it under _USE_MATH_DEFINES). The radians value is
    // computed in double exactly as before, then narrowed to float.
    constexpr double kPi = 3.14159265358979323846;
    const float radians = static_cast<float>(degrees * (kPi / 180.0));
    const float sinA = std::sin(radians);
    const float cosA = std::cos(radians);
    // Keep the old row a; it is needed after row a has been overwritten.
    Vector4 oldRowA = data[a];
    data[a] = cosA * data[a] - sinA * data[b];
    data[b] = sinA * oldRowA + cosA * data[b];
    return *this;
}

// Rotates by `degrees` about the x axis (mixes rows 1 and 2).
Matrix& Matrix::rotateX(float degrees) {
    const int yRow = 1;
    const int zRow = 2;
    return rotate(degrees, yRow, zRow);
}

// Rotates by `degrees` about the y axis (mixes rows 0 and 2).
// The angle is negated because the usual y-rotation puts +sin at [0][2] and
// -sin at [2][0], the mirror of the pattern rotate(deg, 0, 2) produces.
Matrix& Matrix::rotateY(float degrees) {
    const float flipped = -degrees;
    const int xRow = 0;
    const int zRow = 2;
    return rotate(flipped, xRow, zRow);
}

// Rotates by `degrees` about the z axis (mixes rows 0 and 1).
Matrix& Matrix::rotateZ(float degrees) {
    const int xRow = 0;
    const int yRow = 1;
    return rotate(degrees, xRow, yRow);
}

// Applies quaternion q to the xyz part of each of the four columns, leaving
// row 3 unchanged. Returns *this.
// NOTE(review): this equals pre-multiplying by q's rotation matrix, provided
// `q * Vector3` rotates the vector by q — confirm in Quaternion.
Matrix& Matrix::rotate(const Quaternion& q) {
    auto rotatedColumn = [&](int col) {
        return q * Vector3(data[0][col], data[1][col], data[2][col]);
    };
    // Rotate all columns before writing back; the write-back changes the
    // rows that the column reads depend on.
    Vector3 col0 = rotatedColumn(0);
    Vector3 col1 = rotatedColumn(1);
    Vector3 col2 = rotatedColumn(2);
    Vector3 col3 = rotatedColumn(3);
    for (int row = 0; row < 3; row++) {
        set(row, Vector4(col0[row], col1[row], col2[row], col3[row]));
    }
    return *this;
}