#include <stdarg.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "vm/Operation.h"
#include "vm/Script.h"

// Formats a printf-style message into sc->error, truncated to
// SCRIPT_ERROR_SIZE bytes. A non-empty sc->error is what makes sRun stop
// and report the failure.
static void sError(Script* sc, const char* format, ...) {
    va_list arguments;
    va_start(arguments, format);
    (void)vsnprintf(sc->error, SCRIPT_ERROR_SIZE, format, arguments);
    va_end(arguments);
}

// Default printer for PRINT_INT: the decimal value and a newline on stdout.
static void sIntPrinter(int i) {
    fprintf(stdout, "%d\n", i);
}

// Default printer for PRINT_FLOAT: the value with two decimals and a newline
// on stdout.
static void sFloatPrinter(float f) {
    double promoted = f;
    fprintf(stdout, "%.2f\n", promoted);
}

// Default printer for PRINT_BOOL: "true" or "false" and a newline on stdout.
static void sBoolPrinter(bool b) {
    const char* text = "false";
    if(b) {
        text = "true";
    }
    puts(text);
}

// Printers invoked by the PRINT_INT / PRINT_FLOAT / PRINT_BOOL instructions.
// They start as the stdout defaults above and can be replaced through
// sSetIntPrinter, sSetFloatPrinter and sSetBoolPrinter.
static IntPrinter intPrinter = &sIntPrinter;
static FloatPrinter floatPrinter = &sFloatPrinter;
static BoolPrinter boolPrinter = &sBoolPrinter;

// Copies `length` bytes from the bytecode at the current read position into
// `buffer` and advances the read position. On failure it sets an error,
// leaves the read position untouched and returns false.
static bool sRead(Script* sc, void* buffer, int length) {
    // Jump targets are taken verbatim from the bytecode (OP_GOTO & co.), so
    // readIndex can be negative; reading there would access memory before
    // the start of the code buffer.
    if(sc->readIndex < 0) {
        sError(sc, "read index %d is out of bytecode bounds", sc->readIndex);
        return false;
    }
    if(sc->readIndex + length > sc->code->length) {
        // The format string needs the requested byte count as its argument.
        sError(sc, "cannot read expected %d bytes of data from bytecode",
               length);
        return false;
    }
    memcpy(buffer, sc->code->code + sc->readIndex, length);
    sc->readIndex += length;
    return true;
}

// Reads one opcode byte from the bytecode. When sRead fails (which also sets
// the error), OP_NOTHING is returned so the dispatcher does nothing.
static Operation sReadOperation(Script* sc) {
    unsigned char opcode = 0;
    if(!sRead(sc, &opcode, 1)) {
        return OP_NOTHING;
    }
    return (Operation)opcode;
}

// Reserves `length` uninitialized bytes on top of the stack and returns a
// pointer to them. If they cannot be reserved, it sets an error and returns
// NULL.
static void* sReserve(Script* sc, int length) {
    // Lengths can come straight from bytecode operands (OP_RESERVE, OP_LOAD).
    // A negative value would move stackIndex downwards, possibly below zero.
    if(length < 0) {
        sError(sc, "invalid stack reservation of %d bytes", length);
        return NULL;
    }
    // Written as a subtraction so that stackIndex + length cannot overflow.
    if(length > SCRIPT_STACK_SIZE - sc->stackIndex) {
        sError(sc, "stack overflow");
        return NULL;
    }
    void* p = sc->stack + sc->stackIndex;
    sc->stackIndex += length;
    return p;
}

// Copies `length` bytes from `data` onto the top of the stack. Returns false
// (with the error set by sReserve) when the space cannot be reserved.
static bool sPush(Script* sc, const void* data, int length) {
    void* slot = sReserve(sc, length);
    if(slot == NULL) {
        return false;
    }
    memcpy(slot, data, length);
    return true;
}

// Releases the top `length` bytes of the stack and returns a pointer to them.
// The bytes are not cleared, so they stay readable until the next push
// overwrites them. Returns NULL with an error set on failure.
static const void* sFree(Script* sc, int length) {
    // OP_RETURN passes a byte count read from the bytecode. A negative value
    // would grow stackIndex past the data that was actually pushed.
    if(length < 0) {
        sError(sc, "invalid stack release of %d bytes", length);
        return NULL;
    }
    if(sc->stackIndex < length) {
        sError(sc, "stack underflow");
        return NULL;
    }
    sc->stackIndex -= length;
    return sc->stack + sc->stackIndex;
}

// Removes the top `length` bytes of the stack and copies them into `data`.
// Returns false (with the error set by sFree) if that is not possible.
static bool sPop(Script* sc, void* data, int length) {
    const void* top = sFree(sc, length);
    if(top == NULL) {
        return false;
    }
    memcpy(data, top, length);
    return true;
}

// Copies the top `length` bytes of the stack into `data` without removing
// them. Sets a "stack underflow" error and returns false if fewer bytes are
// on the stack.
static bool sPeek(Script* sc, void* data, int length) {
    if(length <= sc->stackIndex) {
        memcpy(data, sc->stack + (sc->stackIndex - length), length);
        return true;
    }
    sError(sc, "stack underflow");
    return false;
}

// Generates typed stack helpers for `type`:
//   sPush<Type>(sc, value)  - pushes one value
//   sPop<Type>(sc, &value)  - pops one value
// Both return false (error already set) on overflow/underflow.
#define POP_PUSH(type, Type)                                                   \
    static bool sPush##Type(Script* sc, type value) {                          \
        return sPush(sc, &value, sizeof(value));                               \
    }                                                                          \
    static bool sPop##Type(Script* sc, type* value) {                          \
        return sPop(sc, value, sizeof(*value));                                \
    }

// Like POP_PUSH, and additionally generates sRead<Type>(sc, &value), which
// copies one raw `type` operand from the bytecode.
#define READ_POP_PUSH(type, Type)                                              \
    POP_PUSH(type, Type)                                                       \
    static bool sRead##Type(Script* sc, type* out) {                           \
        return sRead(sc, out, sizeof(*out));                                   \
    }

// Reads a `type` operand from the bytecode and pushes it onto the stack.
// Nothing is pushed if the read fails. The && guarantees `operand` is only
// used after it has been read.
#define PUSH_CONSTANT(type, Type)                                              \
    {                                                                          \
        type operand;                                                          \
        (void)(sRead##Type(sc, &operand) && sPush##Type(sc, operand));         \
    }

// `check` statement for CHECKED_NUMBER_OP that rejects a zero divisor
// (values[0], the right operand popped first by OP_BASE).
// It must abort before the operator runs: integer division or modulo by zero
// is undefined behavior (typically a SIGFPE crash), so setting the error
// alone is not enough. The `break` leaves the innermost enclosing switch (the
// opcode dispatch in sConsumeInstruction), which skips the computation and
// the push.
#define ZERO_CHECK(name)                                                       \
    if(values[0] == 0) {                                                       \
        sError(sc, name " by 0");                                              \
        break;                                                                 \
    }

// Binary operation skeleton: pops the right operand into values[0], then the
// left operand into values[1], runs the `check` statement (which may inspect
// `values`), and pushes `left op right` with sPush<RType>.
#define OP_BASE(type, Type, RType, op, check)                                  \
    {                                                                          \
        type values[2];                                                        \
        if(sPop##Type(sc, &values[0]) && sPop##Type(sc, &values[1])) {         \
            check;                                                             \
            sPush##RType(sc, (values[1]) op (values[0]));                      \
        }                                                                      \
    }

// Arithmetic op whose result keeps the operand type; `check` runs between
// popping the operands and computing the result.
#define CHECKED_NUMBER_OP(type, Type, op, check)                               \
    OP_BASE(type, Type, Type, op, check)
// Arithmetic/bitwise op without an extra check.
#define NUMBER_OP(type, Type, op) OP_BASE(type, Type, Type, op, )
// Comparison/logical op: the result is pushed as a bool.
#define BOOL_OP(type, Type, op) OP_BASE(type, Type, Bool, op, )
// Division and modulo reject a zero divisor through ZERO_CHECK.
#define DIVISION(type, Type)                                                   \
    OP_BASE(type, Type, Type, /, ZERO_CHECK("division"))
#define MODULE(type, Type)                                                     \
    OP_BASE(type, Type, Type, %, ZERO_CHECK("module"))

READ_POP_PUSH(int, Int)
READ_POP_PUSH(float, Float)
POP_PUSH(bool, Bool)

// Pops one `type` value and hands it to `printer` (intPrinter, floatPrinter
// or boolPrinter). Nothing is printed if the pop fails.
#define PRINT(type, Type, printer)                                             \
    {                                                                          \
        type printed;                                                          \
        if(sPop##Type(sc, &printed)) {                                         \
            (printer)(printed);                                                \
        }                                                                      \
    }

// Replaces the `type` value on top of the stack with its negation.
// NOTE(review): for int, negating INT_MIN overflows (undefined behavior);
// consider rejecting it explicitly.
#define INVERT_SIGN(type, Type)                                                \
    {                                                                          \
        type operand = 0;                                                      \
        (void)(sPop##Type(sc, &operand) && sPush##Type(sc, -operand));         \
    }

// OP_RESERVE: sets up a new stack frame. It takes two int operands: the total
// frame size in bytes, and how many of those bytes are already on top of the
// stack (presumably arguments pushed by the caller; confirm against the
// compiler). The frame base (stackVarIndex) is moved to the start of those
// bytes and the rest of the frame is reserved. The previous frame base is
// then pushed so that OP_RETURN can restore it.
static void sReserveBytes(Script* sc) {
    int frameBytes = 0;
    int presentBytes = 0;
    if(!sReadInt(sc, &frameBytes) || !sReadInt(sc, &presentBytes)) {
        return;
    }
    int callerVarIndex = sc->stackVarIndex;
    sc->stackVarIndex = sc->stackIndex - presentBytes;
    sReserve(sc, frameBytes - presentBytes);
    sPushInt(sc, callerVarIndex);
}

// Returns true if the `length` bytes starting at stack offset `address` lie
// entirely inside the used part of the stack [0, stackIndex). Otherwise it
// sets an error and returns false.
static bool sCheckAddress(Script* sc, int address, int length) {
    // Address and length can come from bytecode operands or from stack data
    // (e.g. OP_LOAD). Negative lengths are rejected, and the bound is checked
    // without computing address + length: that sum could overflow (undefined
    // behavior) and wrap into the accepted range. With address >= 0 the
    // subtraction below cannot overflow.
    if(address < 0 || length < 0 || length > sc->stackIndex - address) {
        sError(sc, "address is out of stack bounds");
        return false;
    }
    return true;
}

// Writes the top `length` bytes of the stack to the stack address stored just
// below them, then pops both operands.
// Stack layout before: [..., address (int), value (length bytes)]
static void sStore(Script* sc, int length) {
    // Keep this arithmetic signed. Subtracting the unsigned sizeof(int)
    // directly would make a short stack wrap around instead of going
    // negative.
    int index = sc->stackIndex - length - (int)sizeof(int);
    if(index < 0) {
        sError(sc, "stack underflow");
        return;
    }
    int address = -1;
    memcpy(&address, sc->stack + index, sizeof(int));
    if(!sCheckAddress(sc, address, length)) {
        return;
    }
    const void* value = sFree(sc, length);
    if(value == NULL) {
        return;
    }
    // The target may overlap the value being popped when `address` points
    // into the top of the stack. memcpy on overlapping regions is undefined,
    // so memmove is used.
    memmove(sc->stack + address, value, length);
    sc->stackIndex -= (int)sizeof(int);
}

// OP_NOT: replaces the bool on top of the stack with its logical negation.
static void sNot(Script* sc) {
    bool operand = false;
    if(!sPopBool(sc, &operand)) {
        return;
    }
    sPushBool(sc, !operand);
}

// OP_BIT_NOT: replaces the int on top of the stack with its bitwise
// complement.
static void sBitNot(Script* sc) {
    int operand = 0;
    if(!sPopInt(sc, &operand)) {
        return;
    }
    sPushInt(sc, ~operand);
}

// OP_LINE: copies a 2-byte operand from the bytecode into sc->line, which
// sRun prints when reporting an error.
// NOTE(review): exactly 2 raw bytes are copied into sc->line. That is only
// correct if `line` is a 16-bit field, or a wider field on a little-endian
// host (sInit zeroes it). Confirm against Script.h.
static void sLine(Script* sc) {
    const int lineOperandBytes = 2;
    (void)sRead(sc, &sc->line, lineOperandBytes);
}

// OP_GOTO: jumps unconditionally to the absolute bytecode offset given as an
// int operand.
static void sGoTo(Script* sc) {
    int target = 0;
    if(!sReadInt(sc, &target)) {
        return;
    }
    sc->readIndex = target;
}

// OP_GOSUB: function call. It takes two int operands: the target bytecode
// offset, and `offset`, the number of bytes above the return-address slot.
// The slot is the int directly below the top `offset` bytes of the stack.
// The current read position (just past this instruction) is written into that
// slot, then execution jumps to the target.
static void sGoSub(Script* sc) {
    int gotoIndex = 0;
    int offset = 0;
    if(!sReadInt(sc, &gotoIndex) || !sReadInt(sc, &offset)) {
        return;
    }
    // Keep the arithmetic signed. Subtracting the unsigned sizeof(int)
    // directly would wrap a too-small result around instead of letting
    // sCheckAddress see a negative address.
    int address = sc->stackIndex - offset - (int)sizeof(int);
    if(sCheckAddress(sc, address, (int)sizeof(int))) {
        memcpy(sc->stack + address, &sc->readIndex, sizeof(int));
        sc->readIndex = gotoIndex;
    }
}

// OP_RETURN: leaves the current function. The steps are:
//   1. pop the caller's frame base (pushed by OP_RESERVE) into stackVarIndex;
//   2. discard `bytes` bytes of the frame (int operand);
//   3. pop the return address into readIndex.
static void sReturn(Script* sc) {
    int bytes = 0;
    int varIndex = 0;
    if(!sReadInt(sc, &bytes) || !sPopInt(sc, &varIndex)) {
        return;
    }
    sc->stackVarIndex = varIndex;
    // Stop if the frame cannot be released. Otherwise the "stack underflow"
    // message would be overwritten and a return address would be popped from
    // the wrong place.
    if(sFree(sc, bytes) == NULL) {
        return;
    }
    if(!sPopInt(sc, &sc->readIndex) || sc->readIndex < 0) {
        sError(sc, "read index is corrupt");
    }
}

// Return with a value: pops the `type` result, performs OP_RETURN via
// sReturn, then pushes the result again so the caller finds it on top of the
// stack.
#define RETURN(type, Type)                                                     \
    {                                                                          \
        type result;                                                           \
        if(sPop##Type(sc, &result)) {                                          \
            sReturn(sc);                                                       \
            sPush##Type(sc, result);                                           \
        }                                                                      \
    }

// OP_IF_GOTO: reads an int jump target, pops a bool, and jumps to the target
// when the bool is false.
static void sIfGoTo(Script* sc) {
    int target = 0;
    bool condition = false;
    if(!sReadInt(sc, &target) || !sPopBool(sc, &condition)) {
        return;
    }
    if(!condition) {
        sc->readIndex = target;
    }
}

// Shared body of the PEEK_*_GOTO instructions. It reads an int jump target
// and jumps there if the bool on top of the stack is `jumpWhen`. The bool is
// only peeked, so it stays on the stack either way.
static void sPeekGoTo(Script* sc, bool jumpWhen) {
    int target = 0;
    bool top = false;
    if(!sReadInt(sc, &target) || !sPeek(sc, &top, sizeof(bool))) {
        return;
    }
    bool matches = jumpWhen ? top : !top;
    if(matches) {
        sc->readIndex = target;
    }
}

// OP_PEEK_FALSE_GOTO: jump if the (kept) bool on top of the stack is false.
static void sPeekFalseGoTo(Script* sc) {
    sPeekGoTo(sc, false);
}

// OP_PEEK_TRUE_GOTO: jump if the (kept) bool on top of the stack is true.
static void sPeekTrueGoTo(Script* sc) {
    sPeekGoTo(sc, true);
}

// OP_INT_ARRAY: currently a no-op. It consumes no operands and has no stack
// effect.
// NOTE(review): looks unfinished (the allocator/GC code in this file is
// commented out). Confirm whether the compiler emits this opcode yet.
static void sIntArray(Script* sc) {
    (void)sc; // parameter intentionally unused
}

// OP_DEREFERENCE_VAR: turns a frame-relative variable offset (int operand)
// into an absolute stack address by adding the current frame base, and
// pushes that address.
static void sVarRef(Script* sc) {
    int relative = 0;
    if(!sReadInt(sc, &relative)) {
        return;
    }
    sPushInt(sc, sc->stackVarIndex + relative);
}

// OP_REFERENCE: pops a stack address and pushes the int stored at that
// address, i.e. it follows one level of indirection.
static void sReference(Script* sc) {
    int reference = 0;
    bool valid = sPopInt(sc, &reference) &&
                 sCheckAddress(sc, reference, sizeof(int));
    if(valid) {
        sPush(sc, sc->stack + reference, sizeof(int));
    }
}

// OP_DUPLICATE_REFERENCE: pushes a copy of the int on top of the stack.
static void sDuplicateReference(Script* sc) {
    int top = 0;
    if(!sPeek(sc, &top, sizeof(int))) {
        return;
    }
    sPushInt(sc, top);
}

// OP_LOAD: reads a byte count `size` (int operand), pops a stack address, and
// pushes a copy of the `size` bytes stored at that address.
static void sRefLoad(Script* sc) {
    int size = 0;
    int address = 0;
    if(!sReadInt(sc, &size) || !sPopInt(sc, &address)) {
        return;
    }
    if(!sCheckAddress(sc, address, size)) {
        return;
    }
    sPush(sc, sc->stack + address, size);
}

// LOAD_<TYPE>: pops a stack address and pushes a copy of the `length` bytes
// stored at that address.
static void sLoad(Script* sc, int length) {
    int address = 0;
    if(!sPopInt(sc, &address)) {
        return;
    }
    if(!sCheckAddress(sc, address, length)) {
        return;
    }
    sPush(sc, sc->stack + address, length);
}

// Case labels for an arithmetic op that exists for int and float
// (OP_<name>_INT / OP_<name>_FLOAT); the result keeps the operand type.
#define CASE_NUMBER_OP(name, op)                                               \
    case OP_##name##_INT:                                                      \
        NUMBER_OP(int, Int, op);                                               \
        break;                                                                 \
    case OP_##name##_FLOAT:                                                    \
        NUMBER_OP(float, Float, op);                                           \
        break;
// Case labels for a comparison that exists for int and float; the result is
// a bool.
#define CASE_BOOL_OP(name, op)                                                 \
    case OP_##name##_INT:                                                      \
        BOOL_OP(int, Int, op);                                                 \
        break;                                                                 \
    case OP_##name##_FLOAT:                                                    \
        BOOL_OP(float, Float, op);                                             \
        break;
// Per-type case labels: load, store, equality, print and return-with-value.
#define CASE_TYPE(TYPE, Type, type)                                            \
    case OP_LOAD_##TYPE:                                                       \
        sLoad(sc, sizeof(type));                                               \
        break;                                                                 \
    case OP_STORE_##TYPE:                                                      \
        sStore(sc, sizeof(type));                                              \
        break;                                                                 \
    case OP_EQUAL_##TYPE:                                                      \
        BOOL_OP(type, Type, ==);                                               \
        break;                                                                 \
    case OP_PRINT_##TYPE:                                                      \
        PRINT(type, Type, type##Printer);                                      \
        break;                                                                 \
    case OP_RETURN_##TYPE:                                                     \
        RETURN(type, Type);                                                    \
        break;

// Fetches one opcode from the bytecode and executes it. Failures are reported
// only through sc->error, which sRun checks after every instruction. A failed
// opcode read yields OP_NOTHING, which does nothing here.
// NOTE(review): there is no `default:` label, so an opcode byte that matches
// none of the cases below is silently ignored. Consider reporting it as an
// error once it is confirmed that every Operation value is handled here.
static void sConsumeInstruction(Script* sc) {
    switch(sReadOperation(sc)) {
        // NOTE(review): int +, -, * below can overflow (undefined behavior
        // for signed int in C); the VM does not check for it.
        CASE_NUMBER_OP(ADD, +);
        CASE_NUMBER_OP(SUB, -);
        CASE_NUMBER_OP(MUL, *);
        CASE_BOOL_OP(LESS, <);
        CASE_BOOL_OP(GREATER, >);
        CASE_TYPE(INT, Int, int);
        CASE_TYPE(BOOL, Bool, bool);
        CASE_TYPE(FLOAT, Float, float);
        case OP_NOTHING: break;
        case OP_PUSH_INT: PUSH_CONSTANT(int, Int); break;
        case OP_PUSH_FLOAT: PUSH_CONSTANT(float, Float); break;
        case OP_PUSH_TRUE: sPushBool(sc, true); break;
        case OP_PUSH_FALSE: sPushBool(sc, false); break;
        case OP_DIV_INT: DIVISION(int, Int); break;
        case OP_DIV_FLOAT: DIVISION(float, Float); break;
        case OP_MOD_INT: MODULE(int, Int); break;
        case OP_INVERT_SIGN_INT: INVERT_SIGN(int, Int); break;
        case OP_INVERT_SIGN_FLOAT: INVERT_SIGN(float, Float); break;
        case OP_NOT: sNot(sc); break;
        // Both operands are popped before the operator is applied, so
        // && and || do not short-circuit at this level.
        case OP_AND: BOOL_OP(bool, Bool, &&); break;
        case OP_OR: BOOL_OP(bool, Bool, ||); break;
        case OP_BIT_NOT: sBitNot(sc); break;
        case OP_BIT_AND: NUMBER_OP(int, Int, &); break;
        case OP_BIT_OR: NUMBER_OP(int, Int, |); break;
        case OP_BIT_XOR: NUMBER_OP(int, Int, ^); break;
        // NOTE(review): shift counts outside [0, bit width) and left shifts
        // of negative values are undefined behavior in C and are not checked.
        case OP_LEFT_SHIFT: NUMBER_OP(int, Int, <<); break;
        case OP_RIGHT_SHIFT: NUMBER_OP(int, Int, >>); break;
        case OP_LINE: sLine(sc); break;
        case OP_GOTO: sGoTo(sc); break;
        case OP_IF_GOTO: sIfGoTo(sc); break;
        case OP_PEEK_FALSE_GOTO: sPeekFalseGoTo(sc); break;
        case OP_PEEK_TRUE_GOTO: sPeekTrueGoTo(sc); break;
        case OP_GOSUB: sGoSub(sc); break;
        case OP_RETURN: sReturn(sc); break;
        case OP_RESERVE: sReserveBytes(sc); break;
        case OP_DEREFERENCE_VAR: sVarRef(sc); break;
        case OP_REFERENCE: sReference(sc); break;
        case OP_DUPLICATE_REFERENCE: sDuplicateReference(sc); break;
        case OP_LOAD: sRefLoad(sc); break;
        case OP_INT_ARRAY: sIntArray(sc); break;
        // Array element stores reuse the generic store with an int payload.
        case OP_STORE_ARRAY: sStore(sc, sizeof(int)); break;
    }
    // Garbage collection is currently disabled; its implementation at the
    // end of this file is commented out.
    // sCollectGarbage(sc);
}

// True while the read position lies before the end of the bytecode.
static bool sHasData(Script* sc) {
    const ByteCode* code = sc->code;
    return sc->readIndex < code->length;
}

// Creates a script that executes `code` from offset 0 with an empty stack, no
// error, and line 0. On success the script owns `code`: sDelete releases it
// with bcDelete.
// Returns NULL if the Script cannot be allocated. In that case `code` is left
// untouched and still belongs to the caller.
Script* sInit(ByteCode* code) {
    Script* sc = malloc(sizeof(*sc));
    if(sc == NULL) {
        return NULL;
    }
    sc->error[0] = '\0';
    sc->code = code;
    sc->readIndex = 0;
    sc->stackIndex = 0;
    sc->stackVarIndex = 0;
    sc->line = 0;
    // aInit(&sc->allocator);
    return sc;
}

// Destroys a script, including the ByteCode it owns (released with
// bcDelete). Passing NULL is a no-op, mirroring free().
void sDelete(Script* sc) {
    if(sc == NULL) {
        return;
    }
    bcDelete(sc->code);
    // aDelete(&sc->allocator);
    free(sc);
}

// Executes instructions until the bytecode is exhausted or an instruction
// sets an error. On error, the message and the last line recorded by OP_LINE
// are printed to stdout and execution stops.
void sRun(Script* sc) {
    while(sHasData(sc)) {
        sConsumeInstruction(sc);
        if(sc->error[0] == '\0') {
            continue;
        }
        puts("error:");
        printf(" - info: %s\n", sc->error);
        printf(" - line: %d\n", sc->line);
        break;
    }
}

// Replaces the function used by PRINT_INT. Passing NULL restores the default
// stdout printer; otherwise the next PRINT_INT would call a null function
// pointer.
void sSetIntPrinter(IntPrinter p) {
    intPrinter = (p != NULL) ? p : sIntPrinter;
}

// Replaces the function used by PRINT_FLOAT. Passing NULL restores the
// default stdout printer; otherwise the next PRINT_FLOAT would call a null
// function pointer.
void sSetFloatPrinter(FloatPrinter p) {
    floatPrinter = (p != NULL) ? p : sFloatPrinter;
}

// Replaces the function used by PRINT_BOOL. Passing NULL restores the default
// stdout printer; otherwise the next PRINT_BOOL would call a null function
// pointer.
void sSetBoolPrinter(BoolPrinter p) {
    boolPrinter = (p != NULL) ? p : sBoolPrinter;
}

/*static void sMark(Script* sc, Object* o) {
    if(o->type == OT_ARRAY) {
        Array* a = sc->allocator.data + o->as.intValue;
        a->marked = true;
        for(int i = 0; i < a->length; i++) {
            sMark(sc, a->data + i);
        }
    }
}

void sCollectGarbage(Script* sc) {
    aClearMarker(&sc->allocator);
    for(int i = 0; i < sc->stackIndex; i++) {
        sMark(sc, sc->stack + i);
    }
    aRemoveUnmarked(&sc->allocator);
}*/