#include <iostream>

#include "memory/StackAllocator.h"

// Size of the single fixed arena that backs every StackAllocator allocation (50 MiB).
static constexpr int MAX_BYTES{50 * 1024 * 1024};
// Backing storage. Declared 16-byte aligned so that any offset that is a
// multiple of 16 maps to a 16-byte aligned address.
alignas(16) static char data[MAX_BYTES];
// Offset of the first unused byte in `data`, i.e. the current top of the stack.
// NOTE(review): `index` shares its name with the legacy POSIX `index()` from
// <strings.h>; this TU would likely fail to compile if that header were ever
// pulled in transitively.
static int index{0};
// Start offset of the most recently pushed allocation (restored by free()).
static int lastIndex{0};

// Rounds a non-negative offset or byte count up to the next multiple of 16,
// the alignment declared for `data`.
static int alignUp16(int value) {
    return (value + 15) & ~15;
}

// Pushes a new allocation of `inc` bytes onto the stack and returns its handle.
//
// The allocation starts at the current top rounded up to 16 bytes. This is
// needed because StackAllocator::grow advances `index` by an exact byte count,
// so `index` is not guaranteed to be aligned here. The size is rounded up to a
// multiple of 16 as well, so the region may carry up to 15 bytes of slack.
//
// The handle stores the previous `lastIndex` so that free() can restore it.
// Callers must make sure the rounded allocation fits in `data`;
// StackAllocator::allocate does that.
StackAllocator::Pointer incrementIndex(int inc) {
    const int start = alignUp16(index);
    StackAllocator::Pointer p = {lastIndex, start};
    lastIndex = start;
    index = start + alignUp16(inc);
    return p;
}

// Reserves space for up to `elements` items of `bytesPerElement` bytes each.
//
// On return `elements` holds the number of items actually reserved. That count
// is clamped to the space left after the (aligned) stack top, possibly 0.
// Negative requests are treated as 0. The capacity test divides instead of
// multiplying, so a large request cannot overflow `int`. Assumes
// bytesPerElement >= 0; a zero element size reserves nothing and leaves
// `elements` unchanged.
StackAllocator::Pointer StackAllocator::allocate(int bytesPerElement, int& elements) {
    if(elements < 0) {
        elements = 0;
    }
    if(bytesPerElement > 0) {
        // MAX_BYTES and the aligned start are both multiples of 16, so any
        // byte count <= available still fits after rounding up to 16.
        const int available = MAX_BYTES - alignUp16(index);
        const int capacity = available > 0 ? available / bytesPerElement : 0;
        if(elements > capacity) {
            elements = capacity;
        }
    }
    return incrementIndex(bytesPerElement * elements);
}

// Pops `p` together with every allocation made after it: the stack top goes
// back to where `p` starts, and the "most recent allocation" marker goes back
// to the value saved when `p` was created. A handle that starts above the
// current top (already released by an earlier free) is ignored.
void StackAllocator::free(const Pointer& p) {
    const bool stillOnStack = p.pointer <= index;
    if(stillOnStack) {
        lastIndex = p.lastPointer;
        index = p.pointer;
    }
}

// Extends allocation `p` in place by up to `elements` items of
// `bytesPerElement` bytes and returns how many items were actually added.
//
// Returns 0 and changes nothing unless `p` is the allocation on top of the
// stack (its start equals `lastIndex`). The request is clamped to the space
// left after `index`. The check divides rather than multiplies, so a large
// `elements` cannot overflow `int`, and an exhausted arena yields 0 instead of
// a negative count. Growth is by an exact byte count (no 16-byte rounding).
// A negative `elements` is passed through unchanged and moves `index` back by
// that many bytes, as before.
// NOTE(review): confirm whether callers rely on that to shrink.
int StackAllocator::grow(const Pointer& p, int bytesPerElement, int elements) {
    if(p.pointer != lastIndex) {
        return 0;
    }
    if(bytesPerElement > 0) {
        const int available = MAX_BYTES - index;
        const int capacity = available > 0 ? available / bytesPerElement : 0;
        if(elements > capacity) {
            elements = capacity;
        }
    }
    index += bytesPerElement * elements;
    return elements;
}

// Translates a handle into the address of its first byte inside the arena.
void* StackAllocator::get(const Pointer& p) {
    char* const address = data + p.pointer;
    return address;
}