#include <cassert>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <x86intrin.h>

#include "client/math/Matrix.h"

// A default-constructed matrix starts out as the identity transform.
Matrix::Matrix() {
    this->setToIdentity();
}

// Overwrites this matrix with a copy of other (via copy assignment) and
// returns *this so calls can be chained.
Matrix& Matrix::set(const Matrix& other) {
    this->operator=(other);
    return *this;
}

// Resets every element: 1 where row == column, 0 everywhere else.
// Element (row, col) is stored at index col * 4 + row.
// Returns *this for chaining.
Matrix& Matrix::setToIdentity() {
    for(int index = 0; index < 16; index++) {
        const int row = index % 4;
        const int col = index / 4;
        data[index] = (row == col) ? 1.0f : 0.0f;
    }
    return *this;
}

// Writes a single raw element and returns *this for chaining.
// Storage is column-major, so element (row, col) lives at index col * 4 + row.
// index must be in [0, 16); anything larger would write past the end of the
// element array, so that is caught in debug builds.
Matrix& Matrix::set(uint index, float f) {
    assert(index < 16);
    data[index] = f;
    return *this;
}

// Read-only pointer to the 16 raw elements in column-major order
// (column c occupies indices 4c .. 4c+3). The pointer refers to this
// object's own storage.
const float* Matrix::getValues() const {
    const float* values = &data[0];
    return values;
}

// Post-multiplies in place: this = this * m, with both operands column-major
// (element (row, col) at index col * 4 + row). The product is built in a
// scratch buffer before being copied back, so passing *this as m is safe.
// Returns *this for chaining.
Matrix& Matrix::mul(const Matrix& m) {
    float product[16];
    for(int col = 0; col < 4; col++) {
        const float* rhsColumn = m.data + col * 4;
        for(int row = 0; row < 4; row++) {
            // Seed with the first term (not 0.0f) so the summation order and
            // signed-zero results match a plain a*b + c*d + e*f + g*h.
            float sum = data[row] * rhsColumn[0];
            sum += data[row + 4] * rhsColumn[1];
            sum += data[row + 8] * rhsColumn[2];
            sum += data[row + 12] * rhsColumn[3];
            product[col * 4 + row] = sum;
        }
    }
    std::memcpy(data, product, sizeof(product));
    return *this;
}

// Post-multiplies by a diagonal scale matrix: the X, Y and Z basis columns
// (indices 0-3, 4-7, 8-11) are scaled by sx, sy and sz respectively. The
// translation column is left untouched. Returns *this for chaining.
Matrix& Matrix::scale(float sx, float sy, float sz) {
    for(int row = 0; row < 4; row++) {
        data[row] *= sx;
        data[row + 4] *= sy;
        data[row + 8] *= sz;
    }
    return *this;
}

// Uniform scale: shorthand for scale(s, s, s).
Matrix& Matrix::scale(float s) {
    return this->scale(s, s, s);
}

// Post-multiplies by a translation of (tx, ty, tz) in the matrix's local
// frame. It applies the X, Y and Z components in that order.
Matrix& Matrix::translate(float tx, float ty, float tz) {
    translateX(tx);
    translateY(ty);
    translateZ(tz);
    return *this;
}

// Translates along the local X axis: the translation column (indices 12-15)
// gains tx times the X basis column (indices 0-3).
Matrix& Matrix::translateX(float tx) {
    for(int row = 0; row < 4; row++) {
        data[12 + row] += data[row] * tx;
    }
    return *this;
}

// Translates along the local Y axis: the translation column (indices 12-15)
// gains ty times the Y basis column (indices 4-7).
Matrix& Matrix::translateY(float ty) {
    for(int row = 0; row < 4; row++) {
        data[12 + row] += data[4 + row] * ty;
    }
    return *this;
}

// Translates along the local Z axis: the translation column (indices 12-15)
// gains tz times the Z basis column (indices 8-11).
Matrix& Matrix::translateZ(float tz) {
    for(int row = 0; row < 4; row++) {
        data[12 + row] += data[8 + row] * tz;
    }
    return *this;
}

// Replaces the whole matrix with a pure translation: the identity matrix
// with (tx, ty, tz) written into the translation column. Any previous
// contents (rotation, scale) are discarded.
Matrix& Matrix::translateTo(float tx, float ty, float tz) {
    setToIdentity();
    data[12] = tx;
    data[13] = ty;
    data[14] = tz;
    return *this;
}

// Post-multiplies by a rotation in the plane spanned by the two basis columns
// that start at indexA and indexB. Each must be 0, 4, 8 or 12, and they must
// differ. With c = cos(angle) and s = sin(angle), angle given in degrees:
//   column A <- A * c + B * s
//   column B <- B * c - A * s
// Returns *this for chaining.
Matrix& Matrix::rotate(float degrees, uint indexA, uint indexB) {
    // Each index must address the start of a whole 4-float column inside the
    // 16-element array; anything else reads/writes out of bounds.
    assert(indexA % 4 == 0 && indexA <= 12);
    assert(indexB % 4 == 0 && indexB <= 12);
    assert(indexA != indexB);

    const float radians = degrees * (M_PIf32 / 180.0f);
    // Named so they don't shadow the C library sin()/cos() functions.
    float sinAngle;
    float cosAngle;
    sincosf(radians, &sinAngle, &cosAngle);

    // Unaligned loads/stores: _mm_load_ps/_mm_store_ps fault unless the
    // address is 16-byte aligned, and nothing here guarantees that for
    // `data`. On aligned addresses the unaligned forms are generally as fast
    // on current x86 cores.
    const __m128 colA = _mm_loadu_ps(data + indexA);
    const __m128 colB = _mm_loadu_ps(data + indexB);

    const __m128 vcos = _mm_set1_ps(cosAngle);
    const __m128 vsin = _mm_set1_ps(sinAngle);
    const __m128 vnegSin = _mm_set1_ps(-sinAngle);

    const __m128 newA = _mm_add_ps(_mm_mul_ps(colA, vcos), _mm_mul_ps(colB, vsin));
    const __m128 newB = _mm_add_ps(_mm_mul_ps(colA, vnegSin), _mm_mul_ps(colB, vcos));

    _mm_storeu_ps(data + indexA, newA);
    _mm_storeu_ps(data + indexB, newB);
    return *this;
}

// Rotation about the local X axis: mixes the Y basis column (index 4) with
// the Z basis column (index 8). Angle in degrees.
Matrix& Matrix::rotateX(float degrees) {
    constexpr unsigned int yColumn = 4;
    constexpr unsigned int zColumn = 8;
    return this->rotate(degrees, yColumn, zColumn);
}

// Rotation about the local Y axis: mixes the X basis column (index 0) with
// the Z basis column (index 8). Angle in degrees.
// NOTE(review): with (A = X, B = Z) this produces R_y(-degrees) under the
// right-hand rule, whereas rotateX/rotateZ produce +degrees. Confirm whether
// callers rely on this sign before "fixing" it (swapping the arguments to
// (zColumn, xColumn) would flip it).
Matrix& Matrix::rotateY(float degrees) {
    constexpr unsigned int xColumn = 0;
    constexpr unsigned int zColumn = 8;
    return this->rotate(degrees, xColumn, zColumn);
}

// Rotation about the local Z axis: mixes the X basis column (index 0) with
// the Y basis column (index 4). Angle in degrees.
Matrix& Matrix::rotateZ(float degrees) {
    constexpr unsigned int xColumn = 0;
    constexpr unsigned int yColumn = 4;
    return this->rotate(degrees, xColumn, yColumn);
}

// Pretty-prints the matrix row by row. Storage is column-major, so row i is
// data[i], data[i + 4], data[i + 8], data[i + 12]. Values are written in
// fixed notation with 5 decimals, each right-aligned in a 15-character field.
// The caller's float-format flags and precision are restored afterwards.
// Previously precision 5 leaked out and std::defaultfloat clobbered any
// fixed/scientific mode the caller had set.
std::ostream& operator<<(std::ostream& os, const Matrix& m) {
    const float* data = m.getValues();
    const std::ios_base::fmtflags savedFlags = os.flags();
    const std::streamsize savedPrecision = os.precision();

    os << "Matrix\n(\n";
    os << std::fixed << std::setprecision(5);
    for(int row = 0; row < 4; row++) {
        // std::setw only applies to the next insertion, so it is repeated.
        os << std::setw(15) << data[row] << ", ";
        os << std::setw(15) << data[row + 4] << ", ";
        os << std::setw(15) << data[row + 8] << ", ";
        os << std::setw(15) << data[row + 12] << "\n";
    }

    os.flags(savedFlags);
    os.precision(savedPrecision);
    os << ")";
    return os;
}