#include <setjmp.h>
#include <stdarg.h>
#include <stdio.h>
#include <string.h>

#include "Compiler.h"
#include "FunctionMap.h"
#include "Operation.h"
#include "StringIntMap.h"
#include "tokenizer/Tokenizer.h"

/* Capacity of the error message buffer, including the terminating '\0'. */
#define ERROR_LENGTH 256
/* Maximum number of `return` statements per function (checked in cReturn). */
#define RETURN_BUFFER 16
/* Maximum number of pending, not yet patched `break` jumps (checked in cBreak). */
#define BREAK_BUFFER 32

/* longjmp target used by cError; armed by setjmp in cAllocAndCompile. */
static jmp_buf errorJump;
/* Last compile error message; an empty string means the compile succeeded. */
static char error[ERROR_LENGTH] = {'\0'};

/* Bytecode emitted by the current compilation. */
static ByteCode* code;

/* Source line of the most recently read token (updated by cReadTokenAndLine). */
static int16 line = 1;

/* Active variable scope: 0 selects vars[0] (top level), 1 selects vars[1]
 * (inside a function body, see cFunction). */
static int varIndex = 0;
static StringIntMap vars[2];
/* Defined functions plus calls queued until their callee is defined. */
static FunctionMap functions;

/* Positions of the pop-count placeholders emitted by `return` statements in
 * the current function; patched and reset by cLinkReturns. */
static int returns[RETURN_BUFFER];
static int returnIndex = 0;
/* Return form seen in the current function: 0 = none yet,
 * 1 = `return;`, 2 = `return <expr>;` (see cReturn). */
static int returnState = 0;

/* Positions of `break` jump targets awaiting patching by cConsumeBreaks. */
static int breaks[BREAK_BUFFER];
static int breakIndex = 0;
/* Nesting depth of while/for loops; break/continue require it to be nonzero. */
static int forWhileStack = 0;
/* Bytecode address that `continue` jumps to in the innermost loop. */
static int continueAt = 0;

static void cError(const char* format, ...) {
    va_list args;
    va_start(args, format);
    vsnprintf(error, ERROR_LENGTH, format, args);
    va_end(args);
    longjmp(errorJump, 0);
}

static int cAddVar(const char* var) {
    int index = vars[varIndex].entries;
    simAdd(vars + varIndex, var, &index);
    return index;
}

/* Aborts compilation, reporting token `t` as unexpected on the current line. */
static void cUnexpectedToken(Token t) {
    const char* tokenName = tGetName(t);
    cError("unexpected token on line %d: %s", line, tokenName);
}

/* Appends one opcode to the bytecode, truncated to a single byte. */
static void cAddOperation(Operation token) {
    unsigned char opcode = (unsigned char)token;
    bcAddBytes(code, &opcode, sizeof opcode);
}

/* Reserves room for an int that is filled in later via cSetInt.
 * Returns the position to hand to cSetInt. */
static int cReserveInt() {
    int position = bcReserveBytes(code, sizeof(int));
    return position;
}

/* Overwrites the int at bytecode position `p` with `i`. */
static void cSetInt(int p, int i) {
    bcSetBytes(code, p, &i, sizeof i);
}

/* Appends the raw bytes of an int. */
static void cAddInt(int i) {
    bcAddBytes(code, &i, sizeof i);
}

/* Appends the raw bytes of an int16. */
static void cAddInt16(int16 i) {
    bcAddBytes(code, &i, sizeof i);
}

/* Appends the raw bytes of a float. */
static void cAddFloat(float f) {
    bcAddBytes(code, &f, sizeof f);
}

/* Emits `OP_PUSH_VARS <count placeholder> <offset>` and returns the position
 * of the placeholder, which the matching cAddPop fills in. */
static int cAddPush(int offset) {
    cAddOperation(OP_PUSH_VARS);
    int countSlot = cReserveInt();
    cAddInt(offset);
    return countSlot;
}

/* Back-patches the cAddPush placeholder at `p` with `vars`, then emits
 * `OP_POP_VARS <vars>`. */
static void cAddPop(int p, int vars) {
    cSetInt(p, vars);
    cAddOperation(OP_POP_VARS);
    cAddInt(vars);
}

/* Reads the next token followed by its source line (stored into `line`).
 * Yields T_END when the line number cannot be read. */
static Token cReadTokenAndLine() {
    Token token = tReadToken();
    return tReadInt16(&line) ? token : T_END;
}

/* Reads the next token and aborts unless it is `wanted`. */
static void cConsumeToken(Token wanted) {
    Token actual = cReadTokenAndLine();
    if(actual == wanted) {
        return;
    }
    cError("unexpected token on line %d: expected '%s' got '%s'", line,
           tGetName(wanted), tGetName(actual));
}

/* Consumes the next token only if it is `t`; reports whether it did. */
static bool cConsumeTokenIf(Token t) {
    bool matches = tPeekToken() == t;
    if(matches) {
        cReadTokenAndLine();
    }
    return matches;
}

/* Emits OP_PUSH_INT with the int payload that follows an int token. */
static void cConstantInt() {
    int value;
    if(tReadInt(&value)) {
        cAddOperation(OP_PUSH_INT);
        cAddInt(value);
        return;
    }
    cError("int token without an int on line %d", line);
}

/* Emits OP_PUSH_FLOAT with the float payload that follows a float token. */
static void cConstantFloat() {
    float value;
    if(tReadFloat(&value)) {
        cAddOperation(OP_PUSH_FLOAT);
        cAddFloat(value);
        return;
    }
    cError("float token without a float on line %d", line);
}

/* Emits OP_PUSH_CONST_STRING followed by the byte length and the raw bytes
 * of a text token. */
static void cConstantString() {
    int length;
    const char* text = tReadString(&length);
    if(text != NULL) {
        cAddOperation(OP_PUSH_CONST_STRING);
        cAddInt(length);
        bcAddBytes(code, text, length);
        return;
    }
    cError("text without string on line %d", line);
}

/* Returns the string payload of the current token (its length is ignored).
 * Aborts compilation when no string is available. */
static const char* cReadString() {
    int ignoredLength;
    const char* literal = tReadString(&ignoredLength);
    if(literal != NULL) {
        return literal;
    }
    cError("literal without string on line %d", line);
    return NULL;
}

static void cExpression();

/* Compiles the argument list of a call. The opening '(' has already been
 * consumed. Argument expressions are emitted in source order and the closing
 * ')' is consumed.
 * Returns the number of arguments.
 * Arguments must be separated by exactly one comma. A missing comma
 * (`f(a b)`) or a trailing comma (`f(a,)`) is reported as an unexpected
 * token. */
static int cCallFunctionArguments() {
    int arguments = 0;
    if(cConsumeTokenIf(T_CLOSE_BRACKET)) {
        return arguments;
    }
    for(;;) {
        arguments++;
        cExpression();
        if(cConsumeTokenIf(T_CLOSE_BRACKET)) {
            return arguments;
        }
        // A comma is mandatory between arguments; without this check
        // `f(a b)` would silently compile as `f(a, b)`. The comma must also
        // be followed by another argument, not by ')'.
        if(!cConsumeTokenIf(T_COMMA) || tPeekToken() == T_CLOSE_BRACKET) {
            cUnexpectedToken(tPeekToken());
        }
    }
}

/* Compiles a call `literal(args...)` whose '(' has already been consumed.
 * Emits: OP_PUSH_INT 0, <args>, OP_GOSUB <address> <argc>, then an optional
 * trailing OP_POP.
 * noReturn is true when the call is used as a statement (its result is
 * discarded) and false when it is used inside an expression.
 * If the callee is not defined yet, the call is queued with fmEnqueue. Its
 * address, and the OP_NOTHING placeholder (possibly turned into OP_POP), are
 * patched later by cLinkQueuedFunctions.
 * NOTE(review): the leading PUSH_INT 0 is presumably the return-value slot
 * written by OP_SET_RETURN. Confirm in the VM. */
static void cCallFunction(const char* literal, bool noReturn) {
    cAddOperation(OP_PUSH_INT);
    cAddInt(0);
    int arguments = cCallFunctionArguments();
    Function* target = fmSearch(&functions, literal, arguments);
    cAddOperation(OP_GOSUB);
    if(target != NULL) {
        if(!noReturn && !target->returns) {
            cError("function '%s' needs a return value on line %d",
                   target->name, line);
        }
        cAddInt(target->address);
        cAddInt(arguments);
        if(noReturn && target->returns) {
            cAddOperation(OP_POP);
        }
        return;
    }
    int addressSlot = cReserveInt();
    fmEnqueue(&functions, literal, arguments, line, addressSlot, noReturn);
    cAddInt(arguments);
    cAddOperation(OP_NOTHING);
}

/* Emits a reference to variable `var`, or to the element `var[<expr>]` when
 * the name is followed by '['. The index expression is compiled before the
 * variable's slot is looked up. */
static void cAddReference(const char* var) {
    bool indexed = cConsumeTokenIf(T_OPEN_SQUARE_BRACKET);
    if(!indexed) {
        cAddOperation(OP_REFERENCE_FROM_VAR);
        cAddInt(cAddVar(var));
        return;
    }
    cExpression();
    cAddOperation(OP_REFERENCE_FROM_ARRAY);
    cAddInt(cAddVar(var));
    cConsumeToken(T_CLOSE_SQUARE_BRACKET);
}

/* Compiles an identifier in expression position. It may be a call, a
 * post-increment/decrement, a `.length` access, or a plain read
 * (OP_DEREFERENCE). */
static void cLiteral() {
    const char* literal = cReadString();
    if(cConsumeTokenIf(T_OPEN_BRACKET)) {
        cCallFunction(literal, false);
        return;
    }
    cAddReference(literal);
    if(cConsumeTokenIf(T_INCREMENT)) {
        cAddOperation(OP_POST_INCREMENT);
        return;
    }
    if(cConsumeTokenIf(T_DECREMENT)) {
        cAddOperation(OP_POST_DECREMENT);
        return;
    }
    if(!cConsumeTokenIf(T_POINT)) {
        cAddOperation(OP_DEREFERENCE);
        return;
    }
    // member access: only `.length` is supported
    cConsumeToken(T_LITERAL);
    const char* member = cReadString();
    if(strcmp(member, "length") != 0) {
        cError("'%s' not supported after . on line %d", member, line);
    }
    cAddOperation(OP_ARRAY_LENGTH);
}

/* Compiles `array[<size>]`: the size expression, then OP_ALLOCATE_ARRAY. */
static void cArray() {
    cConsumeToken(T_OPEN_SQUARE_BRACKET);
    cExpression();
    cConsumeToken(T_CLOSE_SQUARE_BRACKET);
    cAddOperation(OP_ALLOCATE_ARRAY);
}

/* Compiles a primary expression. This covers identifiers, parenthesised
 * expressions, literals and array allocations. */
static void cPrimary() {
    Token t = cReadTokenAndLine();
    switch(t) {
        case T_LITERAL: cLiteral(); return;
        case T_OPEN_BRACKET:
            cExpression();
            cConsumeToken(T_CLOSE_BRACKET);
            return;
        case T_ARRAY: cArray(); return;
        case T_INT: cConstantInt(); return;
        case T_FLOAT: cConstantFloat(); return;
        case T_TEXT: cConstantString(); return;
        case T_TRUE: cAddOperation(OP_PUSH_TRUE); return;
        case T_FALSE: cAddOperation(OP_PUSH_FALSE); return;
        case T_NULL: cAddOperation(OP_PUSH_NULL); return;
        default: cUnexpectedToken(t); return;
    }
}

/* Compiles the variable operand of a prefix ++/-- and applies `op` to its
 * reference. */
static void cPreChange(Operation op) {
    cConsumeToken(T_LITERAL);
    const char* name = cReadString();
    cAddReference(name);
    cAddOperation(op);
}

/* Parses prefix operators and their operand. ++/-- apply to a variable.
 * '-', '~' and '!' apply to a primary; any run of '!' collapses to one or two
 * OP_NOTs. With no prefix operator, this parses a plain primary. */
static void cPreUnary() {
    if(cConsumeTokenIf(T_INCREMENT)) {
        cPreChange(OP_PRE_INCREMENT);
        return;
    }
    if(cConsumeTokenIf(T_DECREMENT)) {
        cPreChange(OP_PRE_DECREMENT);
        return;
    }
    if(cConsumeTokenIf(T_NOT)) {
        // an odd count emits one NOT; an even count emits two
        // (presumably to normalise the value to a boolean)
        bool oddCount = true;
        while(cConsumeTokenIf(T_NOT)) {
            oddCount = !oddCount;
        }
        cPrimary();
        cAddOperation(OP_NOT);
        if(!oddCount) {
            cAddOperation(OP_NOT);
        }
        return;
    }
    if(cConsumeTokenIf(T_SUB)) {
        cPrimary();
        cAddOperation(OP_INVERT_SIGN);
        return;
    }
    if(cConsumeTokenIf(T_BIT_NOT)) {
        cPrimary();
        cAddOperation(OP_BIT_NOT);
        return;
    }
    cPrimary();
}

static void cMul() {
    cPreUnary();
    while(true) {
        if(cConsumeTokenIf(T_MUL)) {
            cPreUnary();
            cAddOperation(OP_MUL);
        } else if(cConsumeTokenIf(T_DIV)) {
            cPreUnary();
            cAddOperation(OP_DIV);
        } else if(cConsumeTokenIf(T_MOD)) {
            cPreUnary();
            cAddOperation(OP_MOD);
        } else {
            break;
        }
    }
}

static void cAdd() {
    cMul();
    while(true) {
        if(cConsumeTokenIf(T_ADD)) {
            cMul();
            cAddOperation(OP_ADD);
        } else if(cConsumeTokenIf(T_SUB)) {
            cMul();
            cAddOperation(OP_SUB);
        } else {
            break;
        }
    }
}

static void cShift() {
    cAdd();
    while(true) {
        if(cConsumeTokenIf(T_LEFT_SHIFT)) {
            cAdd();
            cAddOperation(OP_LEFT_SHIFT);
        } else if(cConsumeTokenIf(T_RIGHT_SHIFT)) {
            cAdd();
            cAddOperation(OP_RIGHT_SHIFT);
        } else {
            break;
        }
    }
}

static void cComparison() {
    cShift();
    while(true) {
        if(cConsumeTokenIf(T_LESS)) {
            cShift();
            cAddOperation(OP_LESS);
        } else if(cConsumeTokenIf(T_LESS_EQUAL)) {
            cShift();
            cAddOperation(OP_GREATER);
            cAddOperation(OP_NOT);
        } else if(cConsumeTokenIf(T_GREATER)) {
            cShift();
            cAddOperation(OP_GREATER);
        } else if(cConsumeTokenIf(T_GREATER_EQUAL)) {
            cShift();
            cAddOperation(OP_LESS);
            cAddOperation(OP_NOT);
        } else {
            break;
        }
    }
}

static void cEqual() {
    cComparison();
    while(true) {
        if(cConsumeTokenIf(T_EQUAL)) {
            cComparison();
            cAddOperation(OP_EQUAL);
        } else if(cConsumeTokenIf(T_NOT_EQUAL)) {
            cComparison();
            cAddOperation(OP_EQUAL);
            cAddOperation(OP_NOT);
        } else {
            break;
        }
    }
}

/* Left-associative '&' over equality operands. */
static void cBitAnd() {
    cEqual();
    while(true) {
        if(!cConsumeTokenIf(T_BIT_AND)) {
            return;
        }
        cEqual();
        cAddOperation(OP_BIT_AND);
    }
}

/* Left-associative '^' over '&' operands. */
static void cBitXor() {
    cBitAnd();
    while(true) {
        if(!cConsumeTokenIf(T_BIT_XOR)) {
            return;
        }
        cBitAnd();
        cAddOperation(OP_BIT_XOR);
    }
}

/* Left-associative '|' over '^' operands. */
static void cBitOr() {
    cBitXor();
    while(true) {
        if(!cConsumeTokenIf(T_BIT_OR)) {
            return;
        }
        cBitXor();
        cAddOperation(OP_BIT_OR);
    }
}

/* Parses `a && b && ...` with short-circuit evaluation.
 * Emitted per `&&`:  DUPLICATE IF_GOTO <end> <b> AND  end:
 * The duplicated `a` feeds OP_IF_GOTO. When the jump is taken, `b` is
 * skipped and the remaining copy of `a` is the result.
 * NOTE(review): assumes OP_IF_GOTO pops its operand and jumps when it is
 * false (cIf uses it to skip the then-branch). Confirm in the VM. */
static void cAnd() {
    cBitOr();
    while(cConsumeTokenIf(T_AND)) {
        cAddOperation(OP_DUPLICATE);
        cAddOperation(OP_IF_GOTO);
        int p = cReserveInt();
        cBitOr();
        cAddOperation(OP_AND);
        // jump target: just past OP_AND
        cSetInt(p, code->length);
    }
}

/* Parses `a || b || ...` with short-circuit evaluation.
 * Emitted per `||`:  DUPLICATE NOT IF_GOTO <end> <b> OR  end:
 * The negated copy of `a` feeds OP_IF_GOTO, so the jump (skipping `b`) is
 * taken when `a` is true, leaving `a` as the result.
 * NOTE(review): same OP_IF_GOTO assumption as cAnd above. Confirm in the VM. */
static void cOr() {
    cAnd();
    while(cConsumeTokenIf(T_OR)) {
        cAddOperation(OP_DUPLICATE);
        cAddOperation(OP_NOT);
        cAddOperation(OP_IF_GOTO);
        int p = cReserveInt();
        cAnd();
        cAddOperation(OP_OR);
        // jump target: just past OP_OR
        cSetInt(p, code->length);
    }
}

/* Entry point of the expression grammar (lowest precedence: '||'). */
static void cExpression() {
    cOr();
}

/* Compound assignment on the reference on top of the stack. The reference
 * is duplicated, its current value is loaded, the right-hand side is
 * combined with `op`, and the result is stored with OP_SET. */
static void cOperationSet(Operation op) {
    cAddOperation(OP_DUPLICATE);
    cAddOperation(OP_DEREFERENCE);
    cExpression();
    cAddOperation(op);
    cAddOperation(OP_SET);
}

/* Compiles a statement that starts with an identifier. It may be a call
 * whose result is discarded, a plain or compound assignment, or a postfix
 * ++/-- whose value is popped. */
static void cLineLiteral() {
    const char* literal = cReadString();
    if(cConsumeTokenIf(T_OPEN_BRACKET)) {
        cCallFunction(literal, true);
        return;
    }
    cAddReference(literal);
    Token t = cReadTokenAndLine();
    Operation combine;
    switch(t) {
        case T_SET:
            cExpression();
            cAddOperation(OP_SET);
            return;
        case T_INCREMENT:
            cAddOperation(OP_POST_INCREMENT);
            cAddOperation(OP_POP);
            return;
        case T_DECREMENT:
            cAddOperation(OP_POST_DECREMENT);
            cAddOperation(OP_POP);
            return;
        case T_ADD_SET: combine = OP_ADD; break;
        case T_SUB_SET: combine = OP_SUB; break;
        case T_MUL_SET: combine = OP_MUL; break;
        case T_DIV_SET: combine = OP_DIV; break;
        case T_MOD_SET: combine = OP_MOD; break;
        case T_BIT_AND_SET: combine = OP_BIT_AND; break;
        case T_BIT_OR_SET: combine = OP_BIT_OR; break;
        case T_BIT_XOR_SET: combine = OP_BIT_XOR; break;
        case T_LEFT_SHIFT_SET: combine = OP_LEFT_SHIFT; break;
        case T_RIGHT_SHIFT_SET: combine = OP_RIGHT_SHIFT; break;
        default: cUnexpectedToken(t); return;
    }
    cOperationSet(combine);
}

/* Parses the parameter list of a function definition. The '(' has already
 * been consumed and the closing ')' is consumed here. Each parameter name is
 * registered as a variable in the active scope.
 * Returns the number of parameters.
 * Parameters must be identifiers separated by exactly one comma. A missing
 * comma (`function f(a b)`) or a trailing comma (`function f(a,)`) is
 * reported as an unexpected token. */
static int cFunctionArguments() {
    int arguments = 0;
    if(cConsumeTokenIf(T_CLOSE_BRACKET)) {
        return arguments;
    }
    for(;;) {
        cConsumeToken(T_LITERAL);
        arguments++;
        cAddVar(cReadString());
        if(cConsumeTokenIf(T_CLOSE_BRACKET)) {
            return arguments;
        }
        // A comma is mandatory between parameters (otherwise `f(a b)` would
        // be accepted as `f(a, b)`), and it must be followed by a name.
        if(!cConsumeTokenIf(T_COMMA) || tPeekToken() != T_LITERAL) {
            cUnexpectedToken(tPeekToken());
        }
    }
}

static void cLine(Token t);

/* Compiles a `{ ... }` block statement by statement. If the input ends
 * before the closing brace, the error names the line of the opening brace. */
static void cConsumeBody() {
    cConsumeToken(T_OPEN_CURVED_BRACKET);
    int openLine = line;
    while(true) {
        if(cConsumeTokenIf(T_CLOSE_CURVED_BRACKET)) {
            return;
        }
        Token next = cReadTokenAndLine();
        if(next == T_END) {
            cError(
                "unexpected end of file: non closed curved bracket on line %d",
                openLine);
        }
        cLine(next);
    }
}

/* Fills every pending `return` pop-count placeholder with the number of
 * variables in the function scope (vars[1]) and empties the list. */
static void cLinkReturns() {
    int popCount = vars[1].entries;
    while(returnIndex > 0) {
        returnIndex--;
        cSetInt(returns[returnIndex], popCount);
    }
}

/* Compiles a function body (after its parameter list) and registers the
 * function under `name`/`arguments` via fmAdd. It is flagged as returning a
 * value iff a `return <expr>;` was seen.
 * Emitted layout:
 *   OP_GOTO <after> | address: OP_PUSH_VARS ... | body | OP_POP_VARS <n> |
 *   OP_RETURN | after:
 * The leading GOTO skips the body during straight-line execution. `address`
 * is what call sites pass after OP_GOSUB. */
static void cFunctionBody(const char* name, int arguments) {
    // remember the definition line for the duplicate-registration error
    int oldLine = line;
    cAddOperation(OP_GOTO);
    int gotoIndex = cReserveInt();

    int address = code->length;
    returnState = 0;

    // the pop count is only known after the body; cAddPop back-patches it
    int p = cAddPush(arguments);
    cConsumeBody();
    cAddPop(p, vars[1].entries);

    // each `return` emitted its own OP_POP_VARS placeholder; patch them now
    cLinkReturns();

    // NOTE(review): fmAdd presumably fails when name/arity is already defined
    if(!fmAdd(&functions, name, arguments, address, returnState == 2)) {
        cError("function registered twice on line %d", oldLine);
    }

    cAddOperation(OP_RETURN);
    cSetInt(gotoIndex, code->length);
}

static void cFunction() {
    if(varIndex == 1) {
        cError("function inside function on line %d", line);
    }
    cConsumeToken(T_LITERAL);
    const char* name = cReadString();
    cConsumeToken(T_OPEN_BRACKET);
    varIndex = 1;
    vars[1].entries = 0;
    cFunctionBody(name, cFunctionArguments());
    varIndex = 0;
}

/* Emits the epilogue of a `return`: OP_POP_VARS with a placeholder pop
 * count, followed by OP_RETURN. The placeholder is recorded in `returns` and
 * patched by cLinkReturns once the function's variable count is known.
 * The caller (cReturn) has already checked that `returns` has room. */
static void cAddReturn() {
    cAddOperation(OP_POP_VARS);
    // cReserveInt takes no parameters. Passing an argument to it is
    // undefined behaviour for a non-prototyped function (C11 6.5.2.2p6) and
    // a compile error in C23.
    returns[returnIndex++] = cReserveInt();
    cAddOperation(OP_RETURN);
}

/* Compiles `return;` or `return <expr>;` inside a function. All returns of
 * one function must use the same form (returnState: 1 = bare, 2 = value). */
static void cReturn() {
    if(varIndex == 0) {
        cError("return without a function on line %d", line);
    }
    if(returnIndex >= RETURN_BUFFER) {
        cError("too much returns in function around line %d", line);
    }
    bool bare = cConsumeTokenIf(T_SEMICOLON);
    int conflicting = bare ? 2 : 1;
    if(returnState == conflicting) {
        cError("mixed return type on line %d", line);
    }
    returnState = bare ? 1 : 2;
    if(bare) {
        cAddReturn();
        return;
    }
    cExpression();
    cAddOperation(OP_SET_RETURN);
    cAddReturn();
    cConsumeToken(T_SEMICOLON);
}

/* Compiles `print <expr>;`. */
static void cPrint() {
    cExpression();
    cAddOperation(OP_PRINT);
    cConsumeToken(T_SEMICOLON);
}

/* Compiles `if(<cond>) {...}`, optionally followed by `else {...}` or
 * `else if ...`.
 * Layout without else: <cond> IF_GOTO <end> body end:
 * Layout with else:    <cond> IF_GOTO <else> body GOTO <end> else: ... end: */
static void cIf() {
    cConsumeToken(T_OPEN_BRACKET);
    cExpression();
    cConsumeToken(T_CLOSE_BRACKET);
    cAddOperation(OP_IF_GOTO);
    int ifP = cReserveInt();
    cConsumeBody();
    // target right after the body; re-patched below if an else follows
    cSetInt(ifP, code->length);

    if(cConsumeTokenIf(T_ELSE)) {
        // the then-branch jumps over the else part
        cAddOperation(OP_GOTO);
        int elseP = cReserveInt();
        // the failed condition now lands after that GOTO, at the else part
        cSetInt(ifP, code->length);
        if(cConsumeTokenIf(T_IF)) {
            // `else if` recurses directly (no braces, no OP_LINE emitted)
            cIf();
        } else {
            cConsumeBody();
        }
        cSetInt(elseP, code->length);
    }
}

/* Patches every `break` recorded since index `start` to jump to `address`
 * and drops them from the pending list. */
static void cConsumeBreaks(int start, int address) {
    while(breakIndex > start) {
        breakIndex--;
        cSetInt(breaks[breakIndex], address);
    }
}

/* Compiles `while(<cond>) {...}`.
 * Layout: start: <cond> IF_GOTO <end> body GOTO start end:
 * `continue` jumps to `start` (re-evaluating the condition). `break`
 * placeholders recorded during the body are patched to `end`. */
static void cWhile() {
    int start = code->length;
    cConsumeToken(T_OPEN_BRACKET);
    cExpression();
    cConsumeToken(T_CLOSE_BRACKET);
    cAddOperation(OP_IF_GOTO);
    int ifP = cReserveInt();
    // breaks of this loop are those recorded from here on
    int breakStart = breakIndex;
    forWhileStack++;
    // save/restore so an enclosing loop keeps its own continue target
    int oldContinue = continueAt;
    continueAt = start;
    cConsumeBody();
    continueAt = oldContinue;
    forWhileStack--;
    cAddOperation(OP_GOTO);
    cAddInt(start);
    cSetInt(ifP, code->length);
    cConsumeBreaks(breakStart, code->length);
}

/* Compiles an expression statement (without its ';'). It is either an
 * identifier statement or a prefix ++/-- whose value is popped. */
static void cLineExpression(Token t) {
    if(t == T_LITERAL) {
        cLineLiteral();
        return;
    }
    Operation change;
    if(t == T_INCREMENT) {
        change = OP_PRE_INCREMENT;
    } else if(t == T_DECREMENT) {
        change = OP_PRE_DECREMENT;
    } else {
        cUnexpectedToken(t);
        return;
    }
    cPreChange(change);
    cAddOperation(OP_POP);
}

/* Compiles `for(<init>; <cond>; <step>) {...}`. The step is emitted before
 * the body in the bytecode, so jumps route around it:
 *   <init>
 *   startCheck:   <cond> IF_GOTO <end> GOTO <beginBody>
 *   startPerLoop: <step> GOTO startCheck
 *   beginBody:    body GOTO startPerLoop
 *   end:
 * `continue` jumps to startPerLoop; `break` placeholders are patched to end.
 * All three header parts are required. */
static void cFor() {
    cConsumeToken(T_OPEN_BRACKET);
    cLineExpression(cReadTokenAndLine());
    cConsumeToken(T_SEMICOLON);
    int startCheck = code->length;
    cExpression();
    cConsumeToken(T_SEMICOLON);
    cAddOperation(OP_IF_GOTO);
    int end = cReserveInt();
    cAddOperation(OP_GOTO);
    int beginBody = cReserveInt();
    int startPerLoop = code->length;
    cLineExpression(cReadTokenAndLine());
    cAddOperation(OP_GOTO);
    cAddInt(startCheck);
    cConsumeToken(T_CLOSE_BRACKET);
    cSetInt(beginBody, code->length);
    // breaks of this loop are those recorded from here on
    int breakStart = breakIndex;
    forWhileStack++;
    // save/restore so an enclosing loop keeps its own continue target
    int oldContinue = continueAt;
    continueAt = startPerLoop;
    cConsumeBody();
    continueAt = oldContinue;
    forWhileStack--;
    cAddOperation(OP_GOTO);
    cAddInt(startPerLoop);
    cSetInt(end, code->length);
    cConsumeBreaks(breakStart, code->length);
}

/* Compiles `break;`: a GOTO whose target is patched by the enclosing loop
 * through cConsumeBreaks. */
static void cBreak() {
    if(forWhileStack == 0) {
        cError("break without for or while on line %d", line);
    }
    if(breakIndex >= BREAK_BUFFER) {
        cError("too much breaks around line %d", line);
    }
    cAddOperation(OP_GOTO);
    int target = cReserveInt();
    breaks[breakIndex] = target;
    breakIndex++;
    cConsumeToken(T_SEMICOLON);
}

/* Compiles `continue;`: a GOTO to the innermost loop's continue address. */
static void cContinue() {
    if(!forWhileStack) {
        cError("continue without for or while on line %d", line);
    }
    cAddOperation(OP_GOTO);
    cAddInt(continueAt);
    cConsumeToken(T_SEMICOLON);
}

static void cLine(Token t) {
    cAddOperation(OP_LINE);
    cAddInt16(line);
    switch(t) {
        case T_PRINT: cPrint(); break;
        case T_FUNCTION: cFunction(); break;
        case T_RETURN: cReturn(); break;
        case T_IF: cIf(); break;
        case T_WHILE: cWhile(); break;
        case T_FOR: cFor(); break;
        case T_BREAK: cBreak(); break;
        case T_CONTINUE: cContinue(); break;
        default: cLineExpression(t); cConsumeToken(T_SEMICOLON);
    }
}

static void cForEachLine() {
    Token t = cReadTokenAndLine();
    while(t != T_END) {
        cLine(t);
        t = cReadTokenAndLine();
    }
}

/* Resolves calls that cCallFunction queued because the callee was not yet
 * defined. For each one it patches the GOSUB address. When the call site
 * discards a returned value, it turns the OP_NOTHING placeholder into OP_POP.
 * That placeholder sits two ints (address, argument count) after the
 * reserved slot. */
static void cLinkQueuedFunctions() {
    for(int q = 0; q < functions.queueEntries; q++) {
        Function* target = fmSearch(&functions, functions.queue[q].name,
                                    functions.queue[q].arguments);
        if(target == NULL) {
            cError("unknown function on line %d", functions.queue[q].line);
        }
        if(!functions.queue[q].noReturn && !target->returns) {
            cError("function '%s' needs a return value on line %d",
                   target->name, functions.queue[q].line);
        }
        int slot = functions.queue[q].reserved;
        cSetInt(slot, target->address);
        if(functions.queue[q].noReturn && target->returns) {
            code->code[slot + sizeof(int) * 2] = OP_POP;
        }
    }
}

/* Resets compiler state, initialises the symbol tables and compiles the
 * whole token stream inside a top-level variable push/pop, then links
 * forward calls. cError longjmps back here, and the tables are released on
 * both paths. */
static void cAllocAndCompile() {
    breakIndex = 0;
    forWhileStack = 0;
    returnState = 0;
    returnIndex = 0;
    varIndex = 0;
    simInit(vars);
    simInit(vars + 1);
    fmInit(&functions);
    if(setjmp(errorJump) == 0) {
        int globalPush = cAddPush(0);
        cForEachLine();
        cAddPop(globalPush, vars[varIndex].entries);
        cLinkQueuedFunctions();
    }
    fmDelete(&functions);
    simDelete(vars + 1);
    simDelete(vars);
}

ByteCode* cCompile() {
    error[0] = '\0';
    code = bcInit();
    cAllocAndCompile();
    if(error[0] != '\0') {
        bcDelete(code);
        return NULL;
    }
    return code;
}

const char* cGetError() {
    return error;
}