#include "tests/MatrixTests.hpp"

#include "math/Matrix.hpp"
#include "test/Test.hpp"

using V3 = Core::Vector3;

static void testInit() {
    Core::Matrix m;
    const float* data = m.getValues();
    for(int i = 0; i < 16; i++) {
        int x = i % 4;
        int y = i / 4;
        CORE_TEST_FLOAT(x == y, data[i], 0.0f);
    }
}

static void testTranspose() {
    Core::Matrix m;
    m.set(0, Core::Vector4(1.0f, 2.0f, 3.0f, 4.0f));
    m.set(1, Core::Vector4(5.0f, 6.0f, 7.0f, 8.0f));
    m.set(2, Core::Vector4(9.0f, 10.0f, 11.0f, 12.0f));
    m.set(3, Core::Vector4(13.0f, 14.0f, 15.0f, 16.0f));
    Core::Matrix t = m.transpose();
    Core::Matrix m2 = t.transpose();

    const float* mp = m.getValues();
    const float* tp = t.getValues();
    for(int x = 0; x < 4; x++) {
        for(int y = 0; y < 4; y++) {
            CORE_TEST_FLOAT(mp[y * 4 + x], tp[x * 4 + y], 0.0f);
        }
    }
    const float* mp2 = m2.getValues();
    for(int i = 0; i < 16; i++) {
        CORE_TEST_FLOAT(mp[i], mp2[i], 0.0f);
    }
}

static void testScale() {
    Core::Matrix m;
    m.scale(V3(2.0f, 3.0f, 4.0f));
    CORE_TEST_VECTOR(V3(-8.0f, 18.0f, 28.0f), m * V3(-4.0f, 6.0f, 7.0f));
}

static void testUniformScale() {
    Core::Matrix m;
    m.scale(2.0f);
    CORE_TEST_VECTOR(V3(-8.0f, 12.0f, 14.0f), m * V3(-4.0f, 6.0f, 7.0f));
}

static void testTranslateX() {
    Core::Matrix m;
    m.translateX(5.0f);
    CORE_TEST_VECTOR(V3(1.0f, 6.0f, 7.0f), m * V3(-4.0f, 6.0f, 7.0f));
}

static void testTranslateY() {
    Core::Matrix m;
    m.translateY(6.0f);
    CORE_TEST_VECTOR(V3(-4.0f, 12.0f, 7.0f), m * V3(-4.0f, 6.0f, 7.0f));
}

static void testTranslateZ() {
    Core::Matrix m;
    m.translateZ(7.0f);
    CORE_TEST_VECTOR(V3(-4.0f, 6.0f, 14.0f), m * V3(-4.0f, 6.0f, 7.0f));
}

static void testTranslate() {
    Core::Matrix m;
    m.translate(V3(1.0f, 2.0f, 3.0f));
    CORE_TEST_VECTOR(V3(-3.0f, 8.0f, 10.0f), m * V3(-4.0f, 6.0f, 7.0f));
}

static void testCombination() {
    Core::Matrix m;
    m.scale(2.0f);
    m.translateX(1.0f);
    m.translateY(2.0f);
    m.translateZ(3.0f);
    m.translate(V3(-4.0f, 2.0f, 3.0f));
    m.scale(V3(2.0f, 3.0f, 4.0f));
    m.scale(0.5f);
    CORE_TEST_VECTOR(V3(-1.0f, 9.0f, 16.0f), m * V3(1.0f, 1.0f, 1.0f));
}

static void testMatrixCombination() {
    Core::Matrix a;
    a.scale(2.0f);
    a.translate(V3(1.0f, 2.0f, 3.0f));

    Core::Matrix b;
    b.scale(3.0f);
    b.translate(V3(1.0f, 1.0f, 1.0f));

    Core::Matrix c;
    c.translate(V3(-1.0f, -2.0f, -3.0f));
    c *= b * a;

    CORE_TEST_VECTOR(V3(9.0f, 11.0f, 13.0f), c * V3(1.0f, 1.0f, 1.0f));
}

static void testRotateX() {
    Core::Matrix m;
    m.rotateX(90);
    CORE_TEST_VECTOR(V3(1.0f, 0.0f, 0.0f), m * V3(1.0f, 0.0f, 0.0f));
    CORE_TEST_VECTOR(V3(0.0f, 0.0f, 1.0f), m * V3(0.0f, 1.0f, 0.0f));
    CORE_TEST_VECTOR(V3(0.0f, -1.0f, 0.0f), m * V3(0.0f, 0.0f, 1.0f));
}

static void testRotateY() {
    Core::Matrix m;
    m.rotateY(90);
    CORE_TEST_VECTOR(V3(0.0f, 0.0f, -1.0f), m * V3(1.0f, 0.0f, 0.0f));
    CORE_TEST_VECTOR(V3(0.0f, 1.0f, 0.0f), m * V3(0.0f, 1.0f, 0.0f));
    CORE_TEST_VECTOR(V3(1.0f, 0.0f, 0.0f), m * V3(0.0f, 0.0f, 1.0f));
}

static void testRotateZ() {
    Core::Matrix m;
    m.rotateZ(90);
    CORE_TEST_VECTOR(V3(0.0f, 1.0f, 0.0f), m * V3(1.0f, 0.0f, 0.0f));
    CORE_TEST_VECTOR(V3(-1.0f, 0.0f, 0.0f), m * V3(0.0f, 1.0f, 0.0f));
    CORE_TEST_VECTOR(V3(0.0f, 0.0f, 1.0f), m * V3(0.0f, 0.0f, 1.0f));
}

static void testToString() {
    Core::String32<1024> s;
    Core::Matrix m;
    m.set(0, Core::Vector4(1.0f, 2.0f, 3.0f, 4.0f));
    m.set(1, Core::Vector4(5.0f, 6.0f, 7.0f, 8.0f));
    m.set(2, Core::Vector4(9.0f, 10.0f, 11.0f, 12.0f));
    m.set(3, Core::Vector4(13.0f, 14.0f, 15.0f, 16.0f));
    CORE_TEST_ERROR(s.append(m));
    CORE_TEST_STRING(
        "[[1.00, 2.00, 3.00, 4.00], [5.00, 6.00, 7.00, 8.00], "
        "[9.00, 10.00, 11.00, 12.00], [13.00, 14.00, 15.00, 16.00]]",
        s);
}

static void testQuaternionMatrix() {
    Core::Quaternion q1(V3(1.0f, 0.0f, 0.0f), 48.0f);
    Core::Quaternion q2(V3(0.0f, 1.0f, 0.0f), 52.0f);
    Core::Quaternion q3(V3(0.0f, 0.0f, 1.0f), 60.0f);

    Core::Matrix m;
    m.translate(V3(1.0f, 2.0f, 3.0f));
    m.rotate(q1).rotate(q2).rotate(q3);
    m.translate(V3(1.0f, 2.0f, 3.0f));

    Core::Matrix check;
    check.translate(V3(1.0f, 2.0f, 3.0f));
    check.rotateX(48.0f).rotateY(52.0f).rotateZ(60.0f);
    check.translate(V3(1.0f, 2.0f, 3.0f));

    for(int i = 0; i < 16; i++) {
        CORE_TEST_FLOAT(check.getValues()[i], m.getValues()[i], 0.0001f);
    }
}

void Core::MatrixTests::test() {
    testInit();
    testScale();
    testUniformScale();
    testTranspose();
    testTranslateX();
    testTranslateY();
    testTranslateZ();
    testTranslate();
    testCombination();
    testMatrixCombination();
    testRotateX();
    testRotateY();
    testRotateZ();
    testToString();
    testQuaternionMatrix();
}