#include <threads.h>

#include "../Tests.hpp"
#include "core/thread/Mutex.hpp"
#include "core/thread/SpinLock.hpp"
#include "core/thread/Thread.hpp"
#include "core/utils/Clock.hpp"

// Set to 1 by run(); testStart() resets it and checks it after join() to
// confirm that the thread body actually executed.
static int runDone{0};

// Plain aggregate handed to a thread through its void* argument in
// testLambda(); the thread writes `value`, the test reads it after join().
struct IntHolder {
    int value; // written by the worker thread, checked by the test
};

// Thread entry used by testStart(): flags that it ran and exits with 7 so the
// test can verify that join() reports the thread's return value.
static int run(void* /*unused*/) {
    runDone = 1;
    constexpr int exitValue = 7;
    return exitValue;
}

static void testStart() {
    runDone = 0;
    Core::Thread t;
    CORE_TEST_ERROR(t.start(run, nullptr));
    int returnValue = 0;
    CORE_TEST_ERROR(t.join(&returnValue));
    CORE_TEST_EQUAL(1, runDone);
    CORE_TEST_EQUAL(7, returnValue);
}

// Runs a captureless lambda as the thread body. The shared state travels
// through the void* argument, and the test checks the lambda's write after join().
static void testLambda() {
    // Brace-initialize the aggregate. Parenthesized aggregate initialization
    // (`IntHolder i(0)`) is only valid from C++20 on and is rejected by older
    // compilers, whereas `{0}` works in every standard.
    IntHolder i{0};
    Core::Thread t;
    CORE_TEST_ERROR(t.start(
        [](void* p) {
            IntHolder* ip = static_cast<IntHolder*>(p);
            ip->value = 2;
            return 0;
        },
        &i));
    // Read i.value only after join() so the worker's write has completed.
    CORE_TEST_ERROR(t.join(nullptr));
    CORE_TEST_EQUAL(2, i.value);
}

static void testJoinWithoutStart() {
    Core::Thread t;
    CORE_TEST_EQUAL(Core::ErrorCode::THREAD_ERROR, t.join(nullptr));
}

static void testAutoJoin() {
    Core::Thread t;
    CORE_TEST_ERROR(t.start([](void*) { return 0; }, nullptr));
}

static void testMove() {
    Core::Thread t;
    CORE_TEST_ERROR(t.start([](void*) { return 0; }, nullptr));
    Core::Thread m = Core::move(t);
    CORE_TEST_ERROR(m.join());
}

static void testMoveAssignment() {
    Core::Thread t;
    CORE_TEST_ERROR(t.start([](void*) { return 0; }, nullptr));
    Core::Thread m;
    m = Core::move(t);
    CORE_TEST_ERROR(m.join());
}

static void testMoveIntoActive() {
    Core::Thread t;
    CORE_TEST_ERROR(t.start([](void*) { return 0; }, nullptr));
    Core::Thread m;
    t = Core::move(m);
}

static void testDoubleJoin() {
    Core::Thread t;
    CORE_TEST_ERROR(t.start([](void*) { return 0; }, nullptr));
    CORE_TEST_ERROR(t.join(nullptr));
    CORE_TEST_EQUAL(Core::ErrorCode::THREAD_ERROR, t.join(nullptr));
}

// State shared by the incrementMutexCounter() threads in testMutex(): the
// counter they bump and the Core::Mutex taken around each increment.
struct MutexCounter {
    Core::Mutex m{};
    int counter{0};
};

static int incrementMutexCounter(void* p) {
    MutexCounter* mcp = static_cast<MutexCounter*>(p);
    for(int i = 0; i < 10000; i++) {
        (void)mcp->m.lock();
        mcp->counter++;
        (void)mcp->m.unlock();
    }
    return 0;
}

static void testMutex() {
    Core::Clock::Nanos n;
    CORE_TEST_ERROR(Core::Clock::getNanos(n));

    MutexCounter mc;
    CORE_TEST_ERROR(mc.m.init());
    Core::Thread t[2];
    CORE_TEST_ERROR(t[0].start(incrementMutexCounter, &mc));
    CORE_TEST_ERROR(t[1].start(incrementMutexCounter, &mc));
    CORE_TEST_ERROR(t[0].join(nullptr));
    CORE_TEST_ERROR(t[1].join(nullptr));
    CORE_TEST_EQUAL(20000, mc.counter);

    Core::Clock::Nanos n2;
    CORE_TEST_ERROR(Core::Clock::getNanos(n2));
    Core::ArrayString<64> s;
    s.append(n2 - n).append("ns Mutex").printLine();
}

// State shared by the incrementSpinLockCounter() threads in testSpinLock():
// the counter they bump and the Core::SpinLock held around each increment.
struct SpinLockCounter {
    Core::SpinLock s{};
    int counter{0};
};

static int incrementSpinLockCounter(void* p) {
    SpinLockCounter* mcp = static_cast<SpinLockCounter*>(p);
    for(int i = 0; i < 10000; i++) {
        mcp->s.lock();
        mcp->counter++;
        mcp->s.unlock();
    }
    return 0;
}

static void testSpinLock() {
    Core::Clock::Nanos n;
    CORE_TEST_ERROR(Core::Clock::getNanos(n));

    SpinLockCounter sc;
    Core::Thread t[2];
    CORE_TEST_ERROR(t[0].start(incrementSpinLockCounter, &sc));
    CORE_TEST_ERROR(t[1].start(incrementSpinLockCounter, &sc));
    CORE_TEST_ERROR(t[0].join(nullptr));
    CORE_TEST_ERROR(t[1].join(nullptr));
    CORE_TEST_EQUAL(20000, sc.counter);

    Core::Clock::Nanos n2;
    CORE_TEST_ERROR(Core::Clock::getNanos(n2));
    Core::ArrayString<64> s;
    s.append(n2 - n).append("ns SpinLock").printLine();
}

// Entry point for this file's thread tests. It runs each test once, in a
// fixed order, so the Mutex and SpinLock timing lines are printed last.
void Core::testThread() {
    using TestFunction = void (*)();
    const TestFunction tests[] = {
        testStart,          testLambda,         testJoinWithoutStart,
        testAutoJoin,       testMove,           testMoveAssignment,
        testMoveIntoActive, testDoubleJoin,     testMutex,
        testSpinLock,
    };
    for(TestFunction test : tests) {
        test();
    }
}