#include "data/BitArray.h"

#include "math/Math.h"
#include "utils/Utility.h"

// Integer division of `a` by `b` that rounds up whenever there is a remainder
// (ceil(a / b) for non-negative operands).
static int roundUpDivide(int a, int b) {
    int quotient = a / b;
    return (a % b != 0) ? quotient + 1 : quotient;
}

static constexpr int INT_BITS = sizeof(int) * 8;
static constexpr int DIVIDE_BITS = Core::Math::roundUpLog2(INT_BITS);

// Reads element `index` from the packed storage `data`, where every element is
// `bits` bits wide. Element i occupies the global bit positions
// [i * bits, i * bits + bits - 1], counted from the least significant bit of
// data[0], so an element may straddle two consecutive words.
// Assumes 0 < bits <= INT_BITS and that the element lies inside the allocated words.
static int readBits(const int* data, int index, int bits) {
    // 64-bit bit positions so that index * bits cannot overflow for large arrays.
    long long firstBit = static_cast<long long>(index) * bits;
    long long lastBit = firstBit + bits - 1;
    int wordA = static_cast<int>(firstBit >> DIVIDE_BITS);
    // Word holding the element's LAST bit. Using the bit just past the element
    // would touch one word beyond the buffer when an element ends exactly on a
    // word boundary.
    int wordB = static_cast<int>(lastBit >> DIVIDE_BITS);
    int shifts = static_cast<int>(firstBit & (INT_BITS - 1));
    // Unsigned arithmetic avoids the undefined (1 << INT_BITS) and signed overflow.
    unsigned int mask = bits >= INT_BITS ? ~0u : (1u << bits) - 1u;
    unsigned int value = static_cast<unsigned int>(data[wordA]) >> shifts;
    if(wordA != wordB) {
        // The element crosses into the next word, so shifts > 0 and
        // INT_BITS - shifts is a valid shift count in [1, INT_BITS - 1].
        value |= static_cast<unsigned int>(data[wordB]) << (INT_BITS - shifts);
    }
    return static_cast<int>(value & mask);
}

// Stores the low `bits` bits of `value` as element `index` in the packed storage
// `data`, leaving every other bit untouched. Element i occupies the global bit
// positions [i * bits, i * bits + bits - 1], counted from the least significant
// bit of data[0], so an element may straddle two consecutive words.
// Assumes 0 < bits <= INT_BITS and that the element lies inside the allocated words.
static void setBits(int* data, int index, int bits, int value) {
    // Unsigned arithmetic avoids the undefined (1 << INT_BITS) and signed shift overflow.
    unsigned int mask = bits >= INT_BITS ? ~0u : (1u << bits) - 1u;
    unsigned int field = static_cast<unsigned int>(value) & mask;
    // 64-bit bit positions so that index * bits cannot overflow for large arrays.
    long long firstBit = static_cast<long long>(index) * bits;
    long long lastBit = firstBit + bits - 1;
    int wordA = static_cast<int>(firstBit >> DIVIDE_BITS);
    // Word holding the element's LAST bit. Using the bit just past the element
    // would read-modify-write one word beyond the buffer when an element ends
    // exactly on a word boundary.
    int wordB = static_cast<int>(lastBit >> DIVIDE_BITS);
    int shifts = static_cast<int>(firstBit & (INT_BITS - 1));

    // Low part: clear the element's bits in word A, then insert the new ones.
    // Bits shifted past the top of the word are discarded (well-defined for unsigned).
    unsigned int low = static_cast<unsigned int>(data[wordA]);
    low = (low & ~(mask << shifts)) | (field << shifts);
    data[wordA] = static_cast<int>(low);
    if(wordA != wordB) {
        // High part: the bits that did not fit into word A go to the bottom of word B.
        // Here shifts > 0, so bitsInA is in [1, INT_BITS - 1].
        int bitsInA = INT_BITS - shifts;
        unsigned int high = static_cast<unsigned int>(data[wordB]);
        high = (high & ~(mask >> bitsInA)) | (field >> bitsInA);
        data[wordB] = static_cast<int>(high);
    }
}

static int getArrayLength(int length, int bits) {
    return roundUpDivide(length * bits, sizeof(int) * 8);
}

Core::BitArray::BitArray() : length(0), bits(0), data(nullptr) {
}

// Makes this array an exact copy of `other`: same length, element width and contents.
// Returns true if resizing the storage failed (the error comes from resize()),
// false on success.
check_return bool Core::BitArray::copyFrom(const BitArray& other) {
    if(other.length <= 0 || other.bits <= 0) {
        // `other` is empty (e.g. default-constructed or moved-from). resize() rejects
        // a zero length, so become empty directly instead of reporting an error.
        delete[] data;
        data = nullptr;
        length = 0;
        bits = 0;
        return false;
    }
    if(resize(other.length, other.bits)) {
        return true;
    }
    // Same length and element width imply the same word layout, so the storage can be
    // copied word by word instead of element by element.
    int words = getArrayLength(length, bits);
    for(int i = 0; i < words; i++) {
        data[i] = other.data[i];
    }
    return false;
}

// Move constructor: start out empty via the default constructor, then trade state
// with `other`, which is left empty.
Core::BitArray::BitArray(BitArray&& other) : BitArray() {
    this->swap(other);
}

// Releases the word buffer. In this file `data` is only ever nullptr or a buffer
// obtained from new[] in resize(); delete[] handles both.
Core::BitArray::~BitArray() {
    delete[] this->data;
}

// Move assignment by exchange: `other` receives this array's previous contents
// and releases them when it is destroyed.
Core::BitArray& Core::BitArray::operator=(BitArray&& other) {
    this->swap(other);
    return *this;
}

// Stores `value` (truncated to the element width) at `index`. Indices outside
// [0, length) and arrays without storage are ignored. Returns *this for chaining.
Core::BitArray& Core::BitArray::set(int index, int value) {
    bool inRange = data != nullptr && index >= 0 && index < length;
    if(inRange) {
        setBits(data, index, bits, value);
    }
    return *this;
}

// Returns the element at `index`, or 0 when the index is outside [0, length)
// or the array has no storage.
int Core::BitArray::get(int index) const {
    bool inRange = data != nullptr && index >= 0 && index < length;
    return inRange ? readBits(data, index, bits) : 0;
}

int Core::BitArray::getLength() const {
    return length;
}

int Core::BitArray::getBits() const {
    return bits;
}

// Size in bytes of the backing word buffer; 0 when the array has no elements
// or no element width.
int Core::BitArray::getInternalByteSize() const {
    if(bits > 0 && length > 0) {
        return getArrayLength(length, bits) * CORE_SIZE(int);
    }
    return 0;
}

int Core::BitArray::select(int index) const {
    if(index <= 0) {
        return -1;
    }
    int found = 0;
    int end = getArrayLength(length, bits);
    for(int i = 0; i < end; i++) {
        int ones = Core::popCount(data[i]);
        found += ones;
        if(found >= index) {
            found -= ones;
            int a = i * 32 - 1;
            int d = data[i];
            while(found < index) {
                found += d & 1;
                d >>= 1;
                a++;
            }
            return a;
        }
    }
    return -1;
}

// Sets every element to `value`; set() truncates it to the element width.
void Core::BitArray::fill(int value) {
    int index = 0;
    while(index < length) {
        set(index, value);
        index++;
    }
}

// Reallocates storage for `newLength` elements of `newBits` bits each.
// The first min(length, newLength) elements are carried over, truncated to
// `newBits` bits; all remaining elements read as 0.
// Returns false on success; otherwise returns the CORE_ERROR result and leaves
// the array unchanged.
check_return bool Core::BitArray::resize(int newLength, int newBits) {
    if(newLength <= 0 || newBits <= 0) {
        return CORE_ERROR(Error::NEGATIVE_ARGUMENT);
    }
    int arrayLength = getArrayLength(newLength, newBits);
    int* newData = new int[arrayLength];
    // NOTE(review): a standard throwing `new` reports failure with std::bad_alloc
    // rather than nullptr, so this branch only triggers with a non-throwing
    // allocator. Confirm which one the build uses.
    if(newData == nullptr) {
        return CORE_ERROR(Error::OUT_OF_MEMORY);
    }
    // Zero-fill the whole buffer (presumably memset-like). This already makes every
    // element past the copied range, and the unused tail bits of the last word, read
    // as 0, so those elements need no separate write.
    Core::memorySet(newData, 0, arrayLength * CORE_SIZE(int));

    int end = Math::min(length, newLength);
    for(int i = 0; i < end; i++) {
        setBits(newData, i, newBits, get(i));
    }
    delete[] data;
    data = newData;
    bits = newBits;
    length = newLength;
    return false;
}

// Exchanges the complete state (storage pointer, element width, length) with `other`.
void Core::BitArray::swap(BitArray& other) {
    Core::swap(data, other.data);
    Core::swap(bits, other.bits);
    Core::swap(length, other.length);
}