#include <stdarg.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "Operation.h"
#include "Script.h"

/* Records a printf-style error message in sc->error (truncated to
 * SCRIPT_ERROR_SIZE). sRun treats a non-empty sc->error as "stop". */
static void sError(Script* sc, const char* format, ...) {
    va_list argList;
    va_start(argList, format);
    (void)vsnprintf(sc->error, SCRIPT_ERROR_SIZE, format, argList);
    va_end(argList);
}

/* Built-in object printer: prints OT_INT objects followed by a newline.
 * Returns false when the object was printed and true when its type is not
 * supported (sPrint reports the latter as an error). */
static bool sPrinter(Object* o) {
    if(o->type != OT_INT) {
        return true;
    }
    printf("%d\n", o->data.intValue);
    return false;
}

/* Printer used by sPrint; starts as the built-in sPrinter and can be replaced
 * through sSetPrinter. */
static ObjectPrinter printer = sPrinter;

/* Copies the next `length` bytecode bytes into `buffer` and advances the read
 * cursor. Returns true (and copies nothing) when fewer than `length` bytes
 * remain; returns false on success. */
static bool sRead(Script* sc, void* buffer, int length) {
    bool outOfData = sc->readIndex + length > sc->byteCodeLength;
    if(!outOfData) {
        memcpy(buffer, sc->byteCode + sc->readIndex, length);
        sc->readIndex += length;
    }
    return outOfData;
}

/* Decodes one instruction header: a 1-byte opcode followed by a 2-byte source
 * line number stored into sc->line. Returns OP_NOTHING when the opcode byte is
 * missing, or when the line bytes are missing (in which case an error is set). */
static Operation sReadOperation(Script* sc) {
    unsigned char opCode;
    if(sRead(sc, &opCode, 1)) {
        return OP_NOTHING;
    }
    /* NOTE(review): exactly 2 bytes are copied into sc->line, which presumes
     * sc->line is a 16-bit field (or that its remaining bytes stay zero on a
     * little-endian host) — confirm against Script.h. */
    if(sRead(sc, &sc->line, 2)) {
        sError(sc, "operation without line near line %d", sc->line);
        return OP_NOTHING;
    }
    return opCode;
}

/* Copies *o onto the value stack; sets a "stack overflow" error instead when
 * SCRIPT_STACK_SIZE slots are already in use. */
static void sPush(Script* sc, Object* o) {
    if(sc->stackIndex < SCRIPT_STACK_SIZE) {
        sc->stack[sc->stackIndex] = *o;
        sc->stackIndex++;
        return;
    }
    sError(sc, "stack overflow on line %d", sc->line);
}

/* Removes the top stack value into *o. Returns true (and sets a "stack
 * underflow" error) when the stack is empty, false on success. */
static bool sPop(Script* sc, Object* o) {
    bool empty = sc->stackIndex <= 0;
    if(empty) {
        sError(sc, "stack underflow on line %d", sc->line);
    } else {
        sc->stackIndex--;
        *o = sc->stack[sc->stackIndex];
    }
    return empty;
}

/* Pushes an OT_INT object holding `value`. */
static void sPushInt(Script* sc, int value) {
    sPush(sc, &(Object){.type = OT_INT, .data.intValue = value});
}

/* Pushes an OT_FLOAT object holding `value`. */
static void sPushFloat(Script* sc, float value) {
    sPush(sc, &(Object){.type = OT_FLOAT, .data.floatValue = value});
}

/* Pushes an OT_NULL object (payload zero-initialized). */
static void sPushNull(Script* sc) {
    sPush(sc, &(Object){.type = OT_NULL});
}

/* Pushes an OT_BOOL object; the truth value is stored in data.intValue (0/1). */
static void sPushBool(Script* sc, bool value) {
    sPush(sc, &(Object){.type = OT_BOOL, .data.intValue = value});
}

/* Reads a raw int immediate (sizeof(int) bytes, host byte order) from the
 * bytecode and pushes it; sets an error if the bytecode ends early. */
static void sPushCodeInt(Script* sc) {
    int immediate = 0;
    bool truncated = sRead(sc, &immediate, sizeof(int));
    if(truncated) {
        sError(sc, "cannot read an int from the bytecode on line %d", sc->line);
    } else {
        sPushInt(sc, immediate);
    }
}

/* Reads a raw float immediate (sizeof(float) bytes, host representation) from
 * the bytecode and pushes it; sets an error if the bytecode ends early. */
static void sPushCodeFloat(Script* sc) {
    float immediate = 0;
    bool truncated = sRead(sc, &immediate, sizeof(float));
    if(truncated) {
        sError(sc, "cannot read a float from the bytecode on line %d", sc->line);
    } else {
        sPushFloat(sc, immediate);
    }
}

/* Converts a numeric object (OT_FLOAT or OT_INT) to float in *r.
 * Unlike sRead/sPop, returns true on SUCCESS; returns false and sets an
 * error when the object is not a number. */
static bool sToFloat(Script* sc, Object* o, float* r) {
    switch(o->type) {
        case OT_FLOAT:
            *r = o->data.floatValue;
            return true;
        case OT_INT:
            *r = (float)o->data.intValue;
            return true;
        default:
            break;
    }
    sError(sc, "object is not a number on line %d", sc->line);
    return false;
}

/* Pops two operands and pushes the result of a binary arithmetic operator.
 * The value pushed first is the LEFT operand and the value on top of the
 * stack is the RIGHT operand, so f(left, right) is evaluated. This matters
 * for non-commutative operators (subtraction, division, ...).
 * If both operands are OT_INT, fInt is used. Otherwise both are converted
 * with sToFloat and fFloat is used; a non-numeric operand sets an error. */
static void sIntBinary(Script* sc, int (*fInt)(int, int), float (*fFloat)(float, float)) {
    Object left;
    Object right;
    /* top of stack is the right-hand operand */
    if(sPop(sc, &right) || sPop(sc, &left)) {
        return;
    }
    if(left.type == OT_INT && right.type == OT_INT) {
        sPushInt(sc, fInt(left.data.intValue, right.data.intValue));
        return;
    }
    float leftValue;
    float rightValue;
    if(sToFloat(sc, &left, &leftValue) && sToFloat(sc, &right, &rightValue)) {
        sPushFloat(sc, fFloat(leftValue, rightValue));
    }
}

/* Integer addition for script values. The sum is computed in unsigned
 * arithmetic so that overflow wraps instead of invoking undefined behavior
 * (signed overflow is UB in C). Converting back to int is
 * implementation-defined and wraps on two's-complement targets. */
static int sIntAdd(int a, int b) {
    return (int)((unsigned)a + (unsigned)b);
}

/* Integer multiplication for script values. The product is computed in
 * unsigned arithmetic so that overflow wraps instead of invoking undefined
 * behavior (signed overflow is UB in C). Converting back to int is
 * implementation-defined and wraps on two's-complement targets. */
static int sIntMul(int a, int b) {
    return (int)((unsigned)a * (unsigned)b);
}

/* Float addition for script values. */
static float sFloatAdd(float a, float b) {
    float sum = a + b;
    return sum;
}

/* Float multiplication for script values. */
static float sFloatMul(float a, float b) {
    float product = a * b;
    return product;
}

/* Pops the top value and hands it to the active printer. Sets an error if
 * the printer reports the object as unprintable (returns true). A pop
 * failure already sets its own error. */
static void sPrint(Script* sc) {
    Object top;
    if(!sPop(sc, &top) && printer(&top)) {
        sError(sc, "cannot print given object on line %d", sc->line);
    }
}

/* Decodes one instruction and executes it. An opcode that this interpreter
 * does not know sets an error. Skipping it silently would make the
 * interpreter decode that instruction's operand bytes as further opcodes. */
static void sConsumeInstruction(Script* sc) {
    Operation op = sReadOperation(sc);
    switch(op) {
        case OP_NOTHING: break;
        case OP_PUSH_INT: sPushCodeInt(sc); break;
        case OP_PUSH_FLOAT: sPushCodeFloat(sc); break;
        case OP_PUSH_NULL: sPushNull(sc); break;
        case OP_PUSH_TRUE: sPushBool(sc, true); break;
        case OP_PUSH_FALSE: sPushBool(sc, false); break;
        case OP_ADD: sIntBinary(sc, sIntAdd, sFloatAdd); break;
        case OP_MUL: sIntBinary(sc, sIntMul, sFloatMul); break;
        case OP_PRINT: sPrint(sc); break;
        default:
            sError(sc, "unknown operation %d on line %d", (int)op, sc->line);
            break;
    }
}

/* True while unread bytecode remains. */
static bool sHasData(Script* sc) {
    return sc->byteCodeLength > sc->readIndex;
}

/* Allocates a Script that executes `byteCode` (`codeLength` bytes), with an
 * empty stack, no error, and the read cursor at the start.
 * On success the Script takes ownership of `byteCode` (sDelete frees it).
 * Returns NULL if allocation fails; the caller then still owns `byteCode`. */
Script* sInit(unsigned char* byteCode, int codeLength) {
    Script* sc = malloc(sizeof *sc);
    if(sc == NULL) {
        return NULL;
    }
    sc->error[0] = '\0';
    sc->byteCode = byteCode;
    sc->byteCodeLength = codeLength;
    sc->readIndex = 0;
    sc->stackIndex = 0;
    sc->line = 0;
    return sc;
}

/* Releases a Script created by sInit: frees the owned bytecode buffer and
 * the Script itself. Previously the Script allocation was leaked.
 * Passing NULL is a no-op. */
void sDelete(Script* sc) {
    if(sc == NULL) {
        return;
    }
    free(sc->byteCode);
    free(sc);
}

/* Executes instructions until the bytecode is exhausted or an error is set.
 * On error, the message is written to stdout via puts and execution stops. */
void sRun(Script* sc) {
    while(sHasData(sc)) {
        sConsumeInstruction(sc);
        bool failed = sc->error[0] != '\0';
        if(failed) {
            puts(sc->error);
            break;
        }
    }
}

/* Replaces the printer used by the print instruction. Passing NULL restores
 * the built-in sPrinter instead of storing a null function pointer, which
 * sPrint would otherwise call. */
void sSetPrinter(ObjectPrinter p) {
    printer = (p != NULL) ? p : sPrinter;
}