#include "Matrix3D.h"
#include <cstring>
#include <cmath>

// Default constructor: a freshly constructed matrix is the identity transform.
Matrix3D::Matrix3D()
{
    this->setToIdentity();
}

// Copy constructor: the new matrix receives all 16 elements of `orig`.
// Without the copy below, `data` would be left uninitialized and the
// "copy" would contain indeterminate values.
//
// @param orig  matrix whose elements are duplicated into this one
Matrix3D::Matrix3D(const Matrix3D& orig)
{
    memcpy(data, orig.data, sizeof(float) * 16);
}

// Destructor: nothing to release here; no constructor in this file
// allocates anything for the matrix elements.
Matrix3D::~Matrix3D()
{
}

// Overwrites every element so the matrix becomes the 4x4 identity.
// With the flat index row + col * 4, the diagonal entries (r, r) land on
// indices 0, 5, 10 and 15, i.e. exactly the multiples of 5 below 16.
void Matrix3D::setToIdentity()
{
    for (int idx = 0; idx < 16; ++idx)
    {
        data[idx] = (idx % 5 == 0) ? 1.0f : 0.0f;
    }
}

void Matrix3D::set(const Matrix3D& m)
{
    memcpy(data, m.data, sizeof(float) * 16);
}

// Stores `value` at (row, col). Elements are laid out column-major: each
// column occupies four consecutive floats. No bounds checking is performed;
// callers must pass row and col in [0, 3].
void Matrix3D::set(int row, int col, float value)
{
    const int flatIndex = col * 4 + row;
    data[flatIndex] = value;
}

// Returns the element at (row, col) from the column-major storage
// (flat index = col * 4 + row). No bounds checking is performed;
// callers must pass row and col in [0, 3].
float Matrix3D::get(int row, int col) const
{
    const int flatIndex = col * 4 + row;
    return data[flatIndex];
}

// Exposes the 16 elements in their column-major storage order (for example,
// for handing to an API that expects a flat float array). The pointer refers
// to this object's own storage and is valid only while the matrix is alive.
const float* Matrix3D::getValues() const
{
    return &data[0];
}

// Right-multiplies in place: *this = (*this) * m.
// The product is built in a scratch buffer, so `m` may be *this itself.
void Matrix3D::mul(const Matrix3D& m)
{
    float product[16];
    for (int col = 0; col < 4; ++col)
    {
        for (int row = 0; row < 4; ++row)
        {
            // Dot product of this matrix's row `row` with m's column `col`,
            // summed in the order k = 0, 1, 2, 3.
            float acc = data[row] * m.data[col * 4];
            for (int k = 1; k < 4; ++k)
            {
                acc += data[row + k * 4] * m.data[k + col * 4];
            }
            product[row + col * 4] = acc;
        }
    }
    memcpy(data, product, sizeof(float) * 16);
}

// Right-multiplies by diag(sx, sy, sz, 1): columns 0, 1 and 2 are scaled
// by sx, sy and sz respectively. Column 3 is left untouched.
void Matrix3D::scale(float sx, float sy, float sz)
{
    for (int row = 0; row < 4; ++row)
    {
        data[row] *= sx;
        data[row + 4] *= sy;
        data[row + 8] *= sz;
    }
}

// Right-multiplies by a translation of (tx, ty, tz): column 3 gains
// tx * column0 + ty * column1 + tz * column2.
void Matrix3D::translate(float tx, float ty, float tz)
{
    for (int row = 0; row < 4; ++row)
    {
        data[12 + row] += data[row] * tx + data[4 + row] * ty + data[8 + row] * tz;
    }
}

// Right-multiplies by a translation of (tx, 0, 0): column 3 gains
// tx * column 0.
void Matrix3D::translateX(float tx)
{
    for (int row = 0; row < 4; ++row)
    {
        data[12 + row] += data[row] * tx;
    }
}

// Right-multiplies by a translation of (0, ty, 0): column 3 gains
// ty * column 1.
void Matrix3D::translateY(float ty)
{
    for (int row = 0; row < 4; ++row)
    {
        data[12 + row] += data[4 + row] * ty;
    }
}

// Right-multiplies by a translation of (0, 0, tz): column 3 gains
// tz * column 2.
void Matrix3D::translateZ(float tz)
{
    for (int row = 0; row < 4; ++row)
    {
        data[12 + row] += data[8 + row] * tz;
    }
}

// Discards the current contents and makes this a pure translation matrix:
// the identity with (tx, ty, tz, 1) as column 3.
void Matrix3D::translateTo(float tx, float ty, float tz)
{
    // Identity first: diagonal entries are at flat indices 0, 5, 10, 15.
    for (int idx = 0; idx < 16; ++idx)
    {
        data[idx] = (idx % 5 == 0) ? 1.0f : 0.0f;
    }

    // Then overwrite the top three entries of the last column.
    data[12] = tx;
    data[13] = ty;
    data[14] = tz;
}

void Matrix3D::rotate(float xDegrees, float yDegrees, float zDegrees) 
{
    rotateX(xDegrees);
    rotateY(yDegrees);
    rotateZ(zDegrees);
}

// Right-multiplies by a rotation of `degrees` about the X axis:
//   *this = (*this) * Rx,  Rx = [1 0 0; 0 c -s; 0 s c]  (upper-left 3x3),
// with c = cos(angle) and s = sin(angle). Only columns 1 and 2 change.
//
// @param degrees  rotation angle in degrees
void Matrix3D::rotateX(float degrees)
{
    // M_PI is not part of ISO C++ (MSVC only provides it when
    // _USE_MATH_DEFINES is defined), so the constant is spelled out here.
    // The conversion is done in double and narrowed to float once.
    const double degToRad = 3.14159265358979323846 / 180.0;
    const float radians = static_cast<float>(degrees * degToRad);
    const float s = std::sin(radians);
    const float c = std::cos(radians);

    for (int row = 0; row < 4; ++row)
    {
        const float col1 = data[4 + row];
        const float col2 = data[8 + row];
        data[4 + row] = col1 * c + col2 * s;
        data[8 + row] = col1 * -s + col2 * c;
    }
}

// Right-multiplies by a rotation of `degrees` about the Y axis:
//   *this = (*this) * Ry,  Ry = [c 0 s; 0 1 0; -s 0 c]  (upper-left 3x3),
// with c = cos(angle) and s = sin(angle). Only columns 0 and 2 change.
//
// @param degrees  rotation angle in degrees
void Matrix3D::rotateY(float degrees)
{
    // M_PI is not part of ISO C++ (MSVC only provides it when
    // _USE_MATH_DEFINES is defined), so the constant is spelled out here.
    // The conversion is done in double and narrowed to float once.
    const double degToRad = 3.14159265358979323846 / 180.0;
    const float radians = static_cast<float>(degrees * degToRad);
    const float s = std::sin(radians);
    const float c = std::cos(radians);

    for (int row = 0; row < 4; ++row)
    {
        const float col0 = data[row];
        const float col2 = data[8 + row];
        data[row] = col0 * c + col2 * -s;
        data[8 + row] = col0 * s + col2 * c;
    }
}

// Right-multiplies by a rotation of `degrees` about the Z axis:
//   *this = (*this) * Rz,  Rz = [c -s 0; s c 0; 0 0 1]  (upper-left 3x3),
// with c = cos(angle) and s = sin(angle). Only columns 0 and 1 change.
//
// @param degrees  rotation angle in degrees
void Matrix3D::rotateZ(float degrees)
{
    // M_PI is not part of ISO C++ (MSVC only provides it when
    // _USE_MATH_DEFINES is defined), so the constant is spelled out here.
    // The conversion is done in double and narrowed to float once.
    const double degToRad = 3.14159265358979323846 / 180.0;
    const float radians = static_cast<float>(degrees * degToRad);
    const float s = std::sin(radians);
    const float c = std::cos(radians);

    for (int row = 0; row < 4; ++row)
    {
        const float col0 = data[row];
        const float col1 = data[4 + row];
        data[row] = col0 * c + col1 * s;
        data[4 + row] = col0 * -s + col1 * c;
    }
}

// Streams the matrix in row-major reading order: a "Matrix3D" header line,
// an opening "(" line, one comma-separated line per row, then ")" with no
// trailing newline. Returns the stream so insertions can be chained.
std::ostream& operator<<(std::ostream& os, const Matrix3D& m)
{
    os << "Matrix3D\n(\n";
    for (int row = 0; row < 4; ++row)
    {
        for (int col = 0; col < 4; ++col)
        {
            if (col != 0)
            {
                os << ", ";
            }
            os << m.get(row, col);
        }
        os << "\n";
    }
    os << ")";
    return os;
}