#include <stdarg.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "utils/Functions.h"
#include "vm/Operation.h"
#include "vm/Script.h"

/* Records a printf-style error message in sc->error, truncated to
 * SCRIPT_ERROR_SIZE bytes. A non-empty sc->error makes sRun stop and report. */
void sError(Script* sc, const char* format, ...) {
    va_list argumentList;
    va_start(argumentList, format);
    (void)vsnprintf(sc->error, SCRIPT_ERROR_SIZE, format, argumentList);
    va_end(argumentList);
}

/* Copies `length` bytes of bytecode starting at the current read index into
 * `buffer` and advances the read index past them.
 * Returns false (and records an error) when the read index is negative or the
 * requested bytes run past the end of the bytecode. */
static bool sRead(Script* sc, void* buffer, int length) {
    // readIndex can be set from bytecode operands (e.g. OP_GOTO), so a
    // negative value must be rejected rather than reading before the code.
    // Comparing against the remaining bytes avoids overflow of
    // readIndex + length.
    if(sc->readIndex < 0 || length > sc->code->length - sc->readIndex) {
        sError(sc, "cannot read expected %d bytes of data from bytecode",
               length);
        return false;
    }
    memcpy(buffer, sc->code->code + sc->readIndex, length);
    sc->readIndex += length;
    return true;
}

/* Reads the next one-byte opcode from the bytecode. Yields OP_NOTHING when
 * the byte cannot be read; sRead has already recorded the error then. */
static Operation sReadOperation(Script* sc) {
    unsigned char opcode;
    if(!sRead(sc, &opcode, 1)) {
        return OP_NOTHING;
    }
    return (Operation)opcode;
}

/* Claims `length` bytes on top of the script stack and returns a pointer to
 * them (their contents are not initialized). Returns NULL and records an
 * error if `length` is negative or the stack would exceed SCRIPT_STACK_SIZE. */
static void* sReserve(Script* sc, int length) {
    // Lengths can come from bytecode operands (OP_RESERVE, OP_LOAD); a
    // negative one would move stackIndex backwards.
    if(length < 0) {
        sError(sc, "cannot reserve %d bytes on the stack", length);
        return NULL;
    }
    // Compare against the free space so stackIndex + length cannot overflow.
    if(length > SCRIPT_STACK_SIZE - sc->stackIndex) {
        sError(sc, "stack overflow");
        return NULL;
    }
    void* p = sc->stack + sc->stackIndex;
    sc->stackIndex += length;
    return p;
}

/* Pushes `length` bytes from `data` onto the stack.
 * Returns false if the space could not be reserved (error already recorded). */
static bool sPush(Script* sc, const void* data, int length) {
    void* slot = sReserve(sc, length);
    if(slot == NULL) {
        return false;
    }
    memcpy(slot, data, length);
    return true;
}

/* Releases the top `length` bytes of the stack and returns a pointer to them.
 * The bytes stay in the stack buffer, so they remain readable until the next
 * push. Returns NULL and records an error if `length` is negative or larger
 * than the current stack. */
static const void* sFree(Script* sc, int length) {
    // OP_RETURN passes a byte count taken from bytecode; a negative value
    // would grow stackIndex past the live contents of the stack.
    if(length < 0) {
        sError(sc, "cannot free %d bytes from the stack", length);
        return NULL;
    }
    if(sc->stackIndex < length) {
        sError(sc, "stack underflow");
        return NULL;
    }
    sc->stackIndex -= length;
    return sc->stack + sc->stackIndex;
}

/* Pops the top `length` bytes of the stack into `data`.
 * Returns false on stack underflow (error already recorded by sFree). */
static bool sPop(Script* sc, void* data, int length) {
    const void* top = sFree(sc, length);
    if(top == NULL) {
        return false;
    }
    memcpy(data, top, length);
    return true;
}

/* Copies the top `length` bytes of the stack into `data` without removing
 * them. Returns false and records "stack underflow" if the stack holds fewer
 * bytes. */
static bool sPeek(Script* sc, void* data, int length) {
    if(sc->stackIndex >= length) {
        const void* top = sc->stack + (sc->stackIndex - length);
        memcpy(data, top, length);
        return true;
    }
    sError(sc, "stack underflow");
    return false;
}

/* Defines sPop<Type> and sPush<Type>, which move one `type` value between the
 * script stack and a C variable. The generated functions are not static.
 * NOTE(review): presumably so other translation units (e.g. utils/Functions)
 * can call them; confirm. */
#define POP_PUSH(type, Type)                                                   \
    bool sPop##Type(Script* sc, type* value) {                                 \
        return sPop(sc, value, sizeof(type));                                  \
    }                                                                          \
    bool sPush##Type(Script* sc, type value) {                                 \
        return sPush(sc, &value, sizeof(type));                                \
    }

/* Same as POP_PUSH, plus a static sRead<Type> that reads one `type` operand
 * from the bytecode stream via sRead. */
#define READ_POP_PUSH(type, Type)                                              \
    static bool sRead##Type(Script* sc, type* i) {                             \
        return sRead(sc, i, sizeof(type));                                     \
    }                                                                          \
    POP_PUSH(type, Type)

/* Reads a `type` constant from the bytecode and pushes it onto the stack. */
#define PUSH_CONSTANT(type, Type)                                              \
    {                                                                          \
        type value;                                                            \
        if(sRead##Type(sc, &value)) {                                          \
            sPush##Type(sc, value);                                            \
        }                                                                      \
    }

/* Used as the `check` argument of OP_BASE: when the right-hand operand
 * (values[0]) is 0 it records "<name> by 0" and returns from the enclosing
 * function before the operator is applied. */
#define ZERO_CHECK(name)                                                       \
    if(values[0] == 0) {                                                       \
        sError(sc, name " by 0");                                              \
        return;                                                                \
    }

/* Generic binary operator: pops the right operand into values[0] and the left
 * operand into values[1], runs `check`, then pushes `values[1] op values[0]`
 * with sPush<RType>.
 * NOTE(review): no overflow handling - signed overflow, INT_MIN / -1 and
 * negative or too-large shift counts are undefined behavior for int/long. */
#define OP_BASE(type, Type, RType, op, check)                                  \
    {                                                                          \
        type values[2];                                                        \
        if(sPop##Type(sc, values) && sPop##Type(sc, values + 1)) {             \
            check;                                                             \
            sPush##RType(sc, values[1] op values[0]);                          \
        }                                                                      \
    }

/* Operator wrappers: NUMBER_OP pushes a result of the operand type, BOOL_OP
 * pushes a bool, and DIVISION / MODULE reject a zero right-hand operand
 * (DIVISION also does so for floats). */
#define CHECKED_NUMBER_OP(type, Type, op, check)                               \
    OP_BASE(type, Type, Type, op, check)
#define NUMBER_OP(type, Type, op) CHECKED_NUMBER_OP(type, Type, op, )
#define BOOL_OP(type, Type, op) OP_BASE(type, Type, Bool, op, )
#define DIVISION(type, Type)                                                   \
    CHECKED_NUMBER_OP(type, Type, /, ZERO_CHECK("division"));
#define MODULE(type, Type)                                                     \
    CHECKED_NUMBER_OP(type, Type, %, ZERO_CHECK("module"));

/* Stack/bytecode accessors for the primitive VM types. Bool only gets the
 * pop/push pair; booleans enter the stack through OP_PUSH_TRUE and
 * OP_PUSH_FALSE rather than through a bytecode constant. */
READ_POP_PUSH(int, Int)
READ_POP_PUSH(long, Long)
READ_POP_PUSH(float, Float)
POP_PUSH(bool, Bool)

/* Pops one VM pointer from the stack into *value. */
static bool sPopPointer(Script* sc, Pointer* value) {
    return sPop(sc, value, sizeof *value);
}

/* Pushes a copy of *value onto the stack. */
static bool sPushPointer(Script* sc, Pointer* value) {
    return sPush(sc, value, sizeof *value);
}

/* Pushes the null pointer value {-1, -1}. */
static void sPushNullPointer(Script* sc) {
    sPushPointer(sc, &(Pointer){-1, -1});
}

/* Pops a `type` value and passes it to `printer`.
 * NOTE(review): not referenced in this part of the file. */
#define PRINT(type, Type, printer)                                             \
    {                                                                          \
        type value;                                                            \
        if(sPop##Type(sc, &value)) {                                           \
            printer(value);                                                    \
        }                                                                      \
    }

/* Replaces the top `type` value with its negation.
 * NOTE(review): negating INT_MIN / LONG_MIN overflows (undefined behavior). */
#define INVERT_SIGN(type, Type)                                                \
    {                                                                          \
        type value = 0;                                                        \
        if(sPop##Type(sc, &value)) {                                           \
            sPush##Type(sc, -value);                                           \
        }                                                                      \
    }

/* Sets up a function frame (OP_RESERVE). Operands: `bytes`, the total frame
 * size, and `offset`, the number of bytes of that frame already on the stack
 * (the arguments). The variable base moves to the start of those bytes, the
 * remaining `bytes - offset` bytes are reserved, and the previous variable
 * base is pushed so sReturn can restore it. */
static void sReserveBytes(Script* sc) {
    int bytes = 0;
    int offset = 0;
    if(!sReadInt(sc, &bytes) || !sReadInt(sc, &offset)) {
        return;
    }
    // Both values come from bytecode: reject frames that would place the
    // variable base outside the stack or request a negative reservation.
    if(offset < 0 || offset > sc->stackIndex || bytes < offset) {
        sError(sc, "invalid reservation of %d bytes with offset %d", bytes,
               offset);
        return;
    }
    int oldIndex = sc->stackVarIndex;
    sc->stackVarIndex = sc->stackIndex - offset;
    if(sReserve(sc, bytes - offset) != NULL) {
        sPushInt(sc, oldIndex);
    }
}

/* Resolves a VM pointer to a host address and checks that `length` bytes
 * starting there are accessible.
 *
 * A pointer with array >= 0 refers to the heap array with that id and its
 * offset is a byte offset into the array data. Any other pointer addresses
 * the live part of the script stack (below stackIndex).
 * Returns NULL and records an error if the length is negative, the array id
 * is invalid, or the access leaves the target region. */
static void* sCheckAddress(Script* sc, Pointer* p, int length) {
    if(length < 0) {
        sError(sc, "invalid access of %d bytes", length);
        return NULL;
    }
    if(p->array >= 0) {
        Array* a = asGet(&sc->arrays, p->array);
        if(a == NULL) {
            sError(sc, "invalid heap pointer");
            return NULL;
        }
        // The whole access [offset, offset + length) must fit. Checking only
        // the first byte let multi-byte loads/stores run past the allocation.
        // NOTE(review): assumes a->size is the array's size in bytes (byte
        // offsets are compared against it); confirm in the Arrays module.
        if(p->offset < 0 || length > a->size - p->offset) {
            sError(sc, "address %d is out of array bounds", p->offset);
            return NULL;
        }
        return ((char*)a->data) + p->offset;
    }
    // Written as a subtraction so offset + length cannot overflow.
    if(p->offset < 0 || length > sc->stackIndex - p->offset) {
        sError(sc, "address %d is out of stack bounds", p->offset);
        return NULL;
    }
    return sc->stack + p->offset;
}

/* Replaces the top bool with its logical negation. */
static void sNot(Script* sc) {
    bool operand = false;
    if(!sPopBool(sc, &operand)) {
        return;
    }
    sPushBool(sc, !operand);
}

/* Replaces the top int with its bitwise complement. */
static void sBitNotInt(Script* sc) {
    int operand = 0;
    if(!sPopInt(sc, &operand)) {
        return;
    }
    sPushInt(sc, ~operand);
}

/* Replaces the top long with its bitwise complement. */
static void sBitNotLong(Script* sc) {
    long operand = 0;
    if(!sPopLong(sc, &operand)) {
        return;
    }
    sPushLong(sc, ~operand);
}

/* Updates the current source line (used in error reports) from a 2-byte
 * bytecode operand.
 * NOTE(review): assumes sc->line is exactly 2 bytes wide; confirm in
 * vm/Script.h. */
static void sLine(Script* sc) {
    (void)sRead(sc, &sc->line, 2);
}

/* Unconditional jump to the bytecode index given by an int operand. */
static void sGoTo(Script* sc) {
    int target;
    if(!sReadInt(sc, &target)) {
        return;
    }
    sc->readIndex = target;
}

/* Calls a subroutine (OP_GOSUB). Operands: the target bytecode index and
 * `offset`, the number of argument bytes currently on top of the stack.
 * The read index of the next instruction is written into the int slot just
 * below those arguments, and execution jumps to the target.
 * NOTE(review): the slot is presumably reserved by the caller's bytecode
 * before the arguments are pushed; confirm against the compiler. */
static void sGoSub(Script* sc) {
    int gotoIndex = 0;
    int offset = 0;
    if(!sReadInt(sc, &gotoIndex) || !sReadInt(sc, &offset)) {
        return;
    }
    // Cast sizeof so the arithmetic stays signed; mixing in size_t would wrap
    // instead of going negative and rely on an implementation-defined
    // conversion back to int.
    Pointer p = {.array = -1,
                 .offset = sc->stackIndex - offset - (int)sizeof(int)};
    void* dest = sCheckAddress(sc, &p, sizeof(int));
    if(dest != NULL) {
        memcpy(dest, &sc->readIndex, sizeof(int));
        sc->readIndex = gotoIndex;
    }
}

/* Tears down the current function frame (OP_RETURN). Pops the saved variable
 * base pushed by sReserveBytes, frees `bytes` frame bytes (operand), and pops
 * the return address into the read index. */
static void sReturn(Script* sc) {
    int bytes = 0;
    int varIndex = 0;
    if(!sReadInt(sc, &bytes) || !sPopInt(sc, &varIndex)) {
        return;
    }
    sc->stackVarIndex = varIndex;
    // If the frame cannot be released the stack is inconsistent; stop here
    // instead of popping a bogus return address over the recorded error.
    if(sFree(sc, bytes) == NULL) {
        return;
    }
    if(!sPopInt(sc, &sc->readIndex) || sc->readIndex < 0) {
        sError(sc, "read index is corrupt");
    }
}

/* Returns a pointer value: it is popped first, the frame is released by
 * sReturn, and the pointer is then pushed for the caller. */
static void sReturnPointer(Script* sc) {
    Pointer result;
    if(!sPopPointer(sc, &result)) {
        return;
    }
    sReturn(sc);
    sPushPointer(sc, &result);
}

/* Returns a `type` value from a function: the value is popped first, the
 * frame is released by sReturn, and the value is then pushed again so it ends
 * up on top of the caller's stack. */
#define RETURN(type, Type)                                                     \
    {                                                                          \
        type value;                                                            \
        if(sPop##Type(sc, &value)) {                                           \
            sReturn(sc);                                                       \
            sPush##Type(sc, value);                                            \
        }                                                                      \
    }

/* Conditional jump: pops a bool and jumps to the operand index when it is
 * false. */
static void sIfGoTo(Script* sc) {
    int target = 0;
    bool condition = false;
    if(!sReadInt(sc, &target) || !sPopBool(sc, &condition)) {
        return;
    }
    if(!condition) {
        sc->readIndex = target;
    }
}

/* Jumps to the operand index when the bool on top of the stack is false,
 * leaving that bool on the stack. */
static void sPeekFalseGoTo(Script* sc) {
    int target = 0;
    bool top = false;
    if(!sReadInt(sc, &target) || !sPeek(sc, &top, sizeof(bool))) {
        return;
    }
    if(!top) {
        sc->readIndex = target;
    }
}

/* Jumps to the operand index when the bool on top of the stack is true,
 * leaving that bool on the stack. */
static void sPeekTrueGoTo(Script* sc) {
    int target = 0;
    bool top = false;
    if(!sReadInt(sc, &target) || !sPeek(sc, &top, sizeof(bool))) {
        return;
    }
    if(top) {
        sc->readIndex = target;
    }
}

/* Allocates a heap array (OP_NEW): the element size is a bytecode operand and
 * the element count is popped from the stack. Pushes a pointer to the start
 * of the array, or records an error when asAllocate reports -1 / -2. */
static void sNewArray(Script* sc) {
    int length = 0;
    int size = 0;
    if(!sReadInt(sc, &size) || !sPopInt(sc, &length)) {
        return;
    }
    Pointer p = {.array = asAllocate(&sc->arrays, size, length), .offset = 0};
    switch(p.array) {
        case -1: sError(sc, "out of memory"); break;
        case -2: sError(sc, "bad allocation"); break;
        default: sPushPointer(sc, &p); break;
    }
}

/* Frees the heap array referenced by the popped pointer (OP_DELETE). Only a
 * pointer to the start of a valid array (offset 0) is accepted. */
static void sDeleteArray(Script* sc) {
    Pointer target;
    if(!sPopPointer(sc, &target)) {
        return;
    }
    if(target.offset != 0) {
        sError(sc, "delete of array with offset: %d", target.offset);
        return;
    }
    Array* array = asGet(&sc->arrays, target.array);
    if(array == NULL) {
        sError(sc, "delete of invalid array");
        return;
    }
    asDeleteArray(&sc->arrays, array, target.array);
}

/* Pushes the element count of the array behind the popped pointer
 * (OP_LENGTH). A pointer with array == -1 yields 1 if its offset is
 * non-negative and 0 otherwise (e.g. for the null pointer {-1, -1}). */
static void sLength(Script* sc) {
    Pointer target;
    if(!sPopPointer(sc, &target)) {
        return;
    }
    if(target.array != -1) {
        Array* array = asGet(&sc->arrays, target.array);
        if(array == NULL) {
            sError(sc, "invalid heap pointer");
            return;
        }
        sPushInt(sc, array->length);
        return;
    }
    sPushInt(sc, target.offset >= 0);
}

/* Pushes a stack pointer to a local variable (OP_DEREFERENCE_VAR). The operand
 * is an address relative to the current variable base (stackVarIndex). */
static void sDereference(Script* sc) {
    int address = 0;
    if(!sReadInt(sc, &address)) {
        return;
    }
    Pointer local = {.array = -1, .offset = sc->stackVarIndex + address};
    sPushPointer(sc, &local);
}

/* Pops a pointer and pushes the `length` bytes it points at, after
 * sCheckAddress validates the access. */
static void sLoad(Script* sc, int length) {
    Pointer p;
    if(!sPopPointer(sc, &p)) {
        return;
    }
    const void* source = sCheckAddress(sc, &p, length);
    if(source != NULL) {
        sPush(sc, source, length);
    }
}

/* Pushes a second copy of the pointer currently on top of the stack. */
static void sDuplicateReference(Script* sc) {
    Pointer top;
    if(!sPeek(sc, &top, sizeof(Pointer))) {
        return;
    }
    sPushPointer(sc, &top);
}

/* Pointer arithmetic (OP_ADD_REFERENCE): pops an element count and a pointer,
 * then pushes the pointer advanced by count * element size bytes (the size is
 * a bytecode operand). */
static void sAddReference(Script* sc) {
    int elementSize = 0;
    int count = 0;
    Pointer base;
    if(!sReadInt(sc, &elementSize) || !sPopInt(sc, &count) ||
       !sPopPointer(sc, &base)) {
        return;
    }
    base.offset += count * elementSize;
    sPushPointer(sc, &base);
}

/* Like sLoad, but the number of bytes to copy is a bytecode operand
 * (OP_LOAD). */
static void sLoadSize(Script* sc) {
    int size = 0;
    Pointer p;
    if(!sReadInt(sc, &size) || !sPopPointer(sc, &p)) {
        return;
    }
    const void* source = sCheckAddress(sc, &p, size);
    if(source != NULL) {
        sPush(sc, source, size);
    }
}

/* Stores a value through a pointer. The stack holds
 * [..., pointer, value(length bytes)]. The value is copied to the address the
 * pointer refers to, then both the value and the pointer are removed. */
static void sStore(Script* sc, int length) {
    // Cast sizeof so the subtraction is signed; in size_t it would wrap
    // instead of going negative on a short stack.
    int index = sc->stackIndex - (int)sizeof(Pointer) - length;
    if(index < 0) {
        sError(sc, "stack underflow");
        return;
    }
    Pointer p;
    memcpy(&p, sc->stack + index, sizeof(Pointer));
    void* dest = sCheckAddress(sc, &p, length);
    if(dest == NULL) {
        return;
    }
    // A stack pointer may address bytes overlapping the value being popped,
    // which memcpy does not permit, so the copy uses memmove.
    const void* src = sFree(sc, length);
    if(src != NULL) {
        memmove(dest, src, length);
        sc->stackIndex -= (int)sizeof(Pointer);
    }
}

/* Pops two pointers and pushes whether both their array id and their offset
 * match. */
static void sEqualPointer(Script* sc) {
    Pointer first;
    Pointer second;
    if(!sPopPointer(sc, &first) || !sPopPointer(sc, &second)) {
        return;
    }
    bool sameArray = first.array == second.array;
    sPushBool(sc, sameArray && first.offset == second.offset);
}

/* Invokes a native function (OP_CALL) whose id is an int operand.
 * NOTE(review): "invalid function call" is recorded when gfsCall returns true,
 * so true presumably signals failure; confirm in utils/Functions. */
static void sCall(Script* sc) {
    int function = 0;
    if(!sReadInt(sc, &function)) {
        return;
    }
    if(gfsCall(sc, function)) {
        sError(sc, "invalid function call");
    }
}

/* Shared body of the increment/decrement instructions: reads a one-byte delta
 * `c` from the bytecode, pops a pointer, copies the `type` value it addresses
 * into `current` and then runs `op` (which may use `c`, `current`, `data`).
 * NOTE(review): `c` is a plain char, whose signedness is
 * implementation-defined; negative deltas rely on char being signed. */
#define CHANGE_OP(type, op)                                                    \
    {                                                                          \
        char c = 0;                                                            \
        Pointer p;                                                             \
        if(sRead(sc, &c, sizeof(char)) && sPopPointer(sc, &p)) {               \
            void* data = sCheckAddress(sc, &p, sizeof(type));                  \
            if(data != NULL) {                                                 \
                type current;                                                  \
                memcpy(&current, data, sizeof(type));                          \
                op                                                             \
            }                                                                  \
        }                                                                      \
    }
/* Prefix form: applies the delta, pushes the new value, writes it back. */
#define PUSH_PRE_CHANGE(Type, type)                                            \
    CHANGE_OP(type, current += c; sPush##Type(sc, current);                    \
              memcpy(data, &current, sizeof(type));)
/* Postfix form: pushes the old value, then applies the delta and writes it
 * back. */
#define PUSH_POST_CHANGE(Type, type)                                           \
    CHANGE_OP(type, sPush##Type(sc, current); current += c;                    \
              memcpy(data, &current, sizeof(type));)
/* Statement form: applies the delta and writes it back without pushing. */
#define CHANGE(type)                                                           \
    CHANGE_OP(type, current += c; memcpy(data, &current, sizeof(type));)
/* Pops a `from` value and pushes it converted to the To type via C's implicit
 * conversion. NOTE(review): float -> int/long is undefined behavior when the
 * value is out of range of the target type. */
#define CAST(From, from, To)                                                   \
    {                                                                          \
        from value;                                                            \
        if(sPop##From(sc, &value)) {                                           \
            sPush##To(sc, value);                                              \
        }                                                                      \
    }
/* The following macros expand to `case` labels for sConsumeInstruction. */
#define CASE_CHANGE(TYPE, Type, type)                                          \
    case OP_PUSH_PRE_CHANGE_##TYPE: PUSH_PRE_CHANGE(Type, type); break;        \
    case OP_PUSH_POST_CHANGE_##TYPE: PUSH_POST_CHANGE(Type, type); break;      \
    case OP_CHANGE_##TYPE:                                                     \
        CHANGE(type);                                                          \
        break;
#define CASE_NUMBER_OP(name, op)                                               \
    case OP_##name##_INT: NUMBER_OP(int, Int, op); break;                      \
    case OP_##name##_LONG: NUMBER_OP(long, Long, op); break;                   \
    case OP_##name##_FLOAT:                                                    \
        NUMBER_OP(float, Float, op);                                           \
        break;
#define CASE_BOOL_OP(name, op)                                                 \
    case OP_##name##_INT: BOOL_OP(int, Int, op); break;                        \
    case OP_##name##_LONG: BOOL_OP(long, Long, op); break;                     \
    case OP_##name##_FLOAT:                                                    \
        BOOL_OP(float, Float, op);                                             \
        break;
#define CASE_TYPE(TYPE, Type, type)                                            \
    case OP_STORE_##TYPE: sStore(sc, sizeof(type)); break;                     \
    case OP_RETURN_##TYPE: RETURN(type, Type); break;                          \
    case OP_EQUAL_##TYPE: BOOL_OP(type, Type, ==); break;                      \
    case OP_LOAD_##TYPE: sLoad(sc, sizeof(type)); break;

/* Decodes and executes one instruction at the current read index. Failures
 * are reported through sc->error, which sRun checks after every instruction.
 * NOTE(review): the switch has no default, so an opcode without a case here
 * is silently skipped; confirm that is intended. */
static void sConsumeInstruction(Script* sc) {
    switch(sReadOperation(sc)) {
        // Arithmetic and comparisons for int/long/float.
        CASE_NUMBER_OP(ADD, +);
        CASE_NUMBER_OP(SUB, -);
        CASE_NUMBER_OP(MUL, *);
        CASE_BOOL_OP(LESS, <);
        CASE_BOOL_OP(GREATER, >);
        // Per-type store/return/equal/load.
        CASE_TYPE(INT, Int, int);
        CASE_TYPE(LONG, Long, long);
        CASE_TYPE(BOOL, Bool, bool);
        CASE_TYPE(FLOAT, Float, float);
        // Increment/decrement through a pointer.
        CASE_CHANGE(INT, Int, int);
        CASE_CHANGE(LONG, Long, long);
        case OP_NOTHING: break;
        // Constants.
        case OP_PUSH_INT: PUSH_CONSTANT(int, Int); break;
        case OP_PUSH_LONG: PUSH_CONSTANT(long, Long); break;
        case OP_PUSH_FLOAT: PUSH_CONSTANT(float, Float); break;
        case OP_PUSH_TRUE: sPushBool(sc, true); break;
        case OP_PUSH_FALSE: sPushBool(sc, false); break;
        case OP_PUSH_NULLPTR: sPushNullPointer(sc); break;
        // Division/modulo with zero checks, sign inversion.
        case OP_DIV_INT: DIVISION(int, Int); break;
        case OP_DIV_LONG: DIVISION(long, Long); break;
        case OP_DIV_FLOAT: DIVISION(float, Float); break;
        case OP_MOD_INT: MODULE(int, Int); break;
        case OP_MOD_LONG: MODULE(long, Long); break;
        case OP_INVERT_SIGN_INT: INVERT_SIGN(int, Int); break;
        case OP_INVERT_SIGN_LONG: INVERT_SIGN(long, Long); break;
        case OP_INVERT_SIGN_FLOAT: INVERT_SIGN(float, Float); break;
        // Logical and bitwise operators.
        case OP_NOT: sNot(sc); break;
        case OP_AND: BOOL_OP(bool, Bool, &&); break;
        case OP_OR: BOOL_OP(bool, Bool, ||); break;
        case OP_BIT_NOT_INT: sBitNotInt(sc); break;
        case OP_BIT_AND_INT: NUMBER_OP(int, Int, &); break;
        case OP_BIT_OR_INT: NUMBER_OP(int, Int, |); break;
        case OP_BIT_XOR_INT: NUMBER_OP(int, Int, ^); break;
        case OP_LEFT_SHIFT_INT: NUMBER_OP(int, Int, <<); break;
        case OP_RIGHT_SHIFT_INT: NUMBER_OP(int, Int, >>); break;
        case OP_BIT_NOT_LONG: sBitNotLong(sc); break;
        case OP_BIT_AND_LONG: NUMBER_OP(long, Long, &); break;
        case OP_BIT_OR_LONG: NUMBER_OP(long, Long, |); break;
        case OP_BIT_XOR_LONG: NUMBER_OP(long, Long, ^); break;
        case OP_LEFT_SHIFT_LONG: NUMBER_OP(long, Long, <<); break;
        case OP_RIGHT_SHIFT_LONG: NUMBER_OP(long, Long, >>); break;
        // Control flow and function frames.
        case OP_LINE: sLine(sc); break;
        case OP_GOTO: sGoTo(sc); break;
        case OP_IF_GOTO: sIfGoTo(sc); break;
        case OP_PEEK_FALSE_GOTO: sPeekFalseGoTo(sc); break;
        case OP_PEEK_TRUE_GOTO: sPeekTrueGoTo(sc); break;
        case OP_GOSUB: sGoSub(sc); break;
        case OP_RETURN: sReturn(sc); break;
        case OP_RETURN_POINTER: sReturnPointer(sc); break;
        case OP_RESERVE: sReserveBytes(sc); break;
        // Pointers and heap arrays.
        case OP_DEREFERENCE_VAR: sDereference(sc); break;
        case OP_REFERENCE: sLoad(sc, sizeof(Pointer)); break;
        case OP_DUPLICATE_REFERENCE: sDuplicateReference(sc); break;
        case OP_ADD_REFERENCE: sAddReference(sc); break;
        case OP_LOAD: sLoadSize(sc); break;
        case OP_NEW: sNewArray(sc); break;
        case OP_DELETE: sDeleteArray(sc); break;
        case OP_LENGTH: sLength(sc); break;
        case OP_STORE_POINTER: sStore(sc, sizeof(Pointer)); break;
        case OP_EQUAL_POINTER: sEqualPointer(sc); break;
        // Numeric conversions.
        case OP_INT_TO_FLOAT: CAST(Int, int, Float); break;
        case OP_FLOAT_TO_INT: CAST(Float, float, Int); break;
        case OP_INT_TO_LONG: CAST(Int, int, Long); break;
        case OP_LONG_TO_INT: CAST(Long, long, Int); break;
        case OP_FLOAT_TO_LONG: CAST(Float, float, Long); break;
        case OP_LONG_TO_FLOAT: CAST(Long, long, Float); break;
        // Native function call.
        case OP_CALL: sCall(sc); break;
    }
}

/* True while the read index has not reached the end of the bytecode. */
static bool sHasData(Script* sc) {
    const bool exhausted = sc->readIndex >= sc->code->length;
    return !exhausted;
}

/* Creates an interpreter for `code` with an empty stack, no recorded error
 * and no heap arrays. The script takes ownership of `code` (sDelete frees it
 * with bcDelete).
 * Returns NULL if the Script cannot be allocated; `code` is then left
 * untouched and still belongs to the caller. */
Script* sInit(ByteCode* code) {
    Script* sc = malloc(sizeof(*sc));
    if(sc == NULL) {
        return NULL;
    }
    sc->error[0] = '\0';
    sc->code = code;
    sc->readIndex = 0;
    sc->stackIndex = 0;
    sc->stackVarIndex = 0;
    sc->line = 0;
    asInit(&sc->arrays);
    return sc;
}

/* Frees the script together with its bytecode and heap arrays.
 * Passing NULL is a no-op, mirroring free(), so the result of a failed sInit
 * can be handed straight to sDelete. */
void sDelete(Script* sc) {
    if(sc == NULL) {
        return;
    }
    bcDelete(sc->code);
    asDelete(&sc->arrays);
    free(sc);
}

/* Executes instructions until the bytecode is exhausted or an instruction
 * records an error. On error the message and current line are printed to
 * stdout and execution stops. */
void sRun(Script* sc) {
    while(sHasData(sc)) {
        sConsumeInstruction(sc);
        if(sc->error[0] == '\0') {
            continue;
        }
        puts("error:");
        printf(" - info: %s\n", sc->error);
        printf(" - line: %d\n", sc->line);
        return;
    }
}