#include "core/data/BitArray.hpp"

#include <string.h>

#include "core/math/Math.hpp"
#include "core/utils/New.hpp"

// Integer division a / b rounded towards positive infinity; b must be non-zero.
static u64 roundUpDivide(u64 a, u64 b) {
    const u64 quotient = a / b;
    const bool hasRemainder = (a % b) != 0;
    return hasRemainder ? quotient + 1 : quotient;
}

static constexpr u64 U64_BITS = 64;
static constexpr u64 DIVIDE_BITS = Core::Math::roundUpLog2(U64_BITS);
static constexpr u64 LENGTH_MASK = 0x01FF'FFFF'FFFF'FFFF;
static constexpr u64 LENGTH_BITS = Core::Math::roundUpLog2(LENGTH_MASK);
static_assert(LENGTH_BITS == 57, "bit array calculation error");

// Reads element `index` of a packed array in `data` whose elements are `bits`
// wide (must be 1..64) and stored back to back, least significant bit first.
// An element may straddle two consecutive words: its low part then comes from
// the top of word A and its high part from the bottom of word B.
static u64 readBits(const u64* data, size_t index, u64 bits) {
    const u64 bitIndex = index * bits;
    const u64 dataIndexA = bitIndex >> DIVIDE_BITS;
    const u64 dataIndexB = (bitIndex + (bits - u64{1})) >> DIVIDE_BITS;
    const u64 shifts = bitIndex & (U64_BITS - u64{1});
    if(dataIndexA == dataIndexB) {
        // All-ones shifted right gives a `bits`-wide mask and stays defined
        // for bits == 64, where (1 << bits) - 1 was undefined behavior.
        return (data[dataIndexA] >> shifts) & (~u64{0} >> (U64_BITS - bits));
    }
    // Straddling implies shifts > 0: word A supplies its top bitsInA (< 64)
    // bits, word B the remaining bitsInB (1..63) bits.
    const u64 bitsInA = U64_BITS - shifts;
    const u64 bitsInB = bits - bitsInA;
    // The right shift already clears everything above bitsInA, so no mask.
    u64 r = data[dataIndexA] >> shifts;
    r |= (data[dataIndexB] & (~u64{0} >> (U64_BITS - bitsInB))) << bitsInA;
    return r;
}

// Writes `value` (truncated to `bits` bits, bits must be 1..64) into element
// `index` of the packed array in `data`, leaving all other elements intact.
// Mirrors the layout used by readBits: LSB first, possibly straddling two words.
static void setBits(u64* data, size_t index, size_t bits, u64 value) {
    // All-ones shifted right: defined for bits == 64, unlike (1 << bits) - 1,
    // which previously produced a zero mask and dropped 64-bit values.
    const u64 mask = ~u64{0} >> (U64_BITS - bits);
    value &= mask;
    const u64 bitIndex = index * bits;
    const u64 dataIndexA = bitIndex >> DIVIDE_BITS;
    const u64 dataIndexB = (bitIndex + (bits - u64{1})) >> DIVIDE_BITS;
    const u64 shifts = bitIndex & (U64_BITS - u64{1});
    // Low part: bits shifted past the top of word A are simply discarded.
    data[dataIndexA] &= ~(mask << shifts);
    data[dataIndexA] |= (value << shifts);
    if(dataIndexA != dataIndexB) {
        // Straddling implies shifts > 0, so both shift counts below are 1..63.
        const u64 bitsInA = U64_BITS - shifts;
        const u64 bitsInB = bits - bitsInA;
        data[dataIndexB] &= ~(~u64{0} >> (U64_BITS - bitsInB));
        data[dataIndexB] |= (value >> bitsInA);
    }
}

// Number of 64-bit storage words needed for `length` elements of `bits` bits.
static size_t getArrayLength(size_t length, size_t bits) {
    const size_t totalBits = length * bits;
    const size_t words = roundUpDivide(totalBits, U64_BITS);
    return words;
}

Core::BitArray::BitArray() : lengthBits(0), data(nullptr) {
}

// Copy constructor: allocates storage of the same length and width as `other`
// and copies every element. For an empty source, resize rejects the zero
// arguments and this array stays empty.
Core::BitArray::BitArray(const BitArray& other) : BitArray() {
    (void)resize(other.getLength(), other.getBits());
    for(size_t i = 0, count = getLength(); i != count; ++i) {
        set(i, other.get(i));
    }
}

// Move constructor: starts out empty and exchanges contents with `other`,
// which is left holding the empty state.
Core::BitArray::BitArray(BitArray&& other) : BitArray() {
    this->swap(other);
}

// Releases the word storage; data may be nullptr, for which delete[] is a no-op.
Core::BitArray::~BitArray() {
    delete[] this->data;
}

// Copy-and-swap assignment: `other` is already a private copy (or a moved-in
// value). Trading contents with it hands our old storage to its destructor.
Core::BitArray& Core::BitArray::operator=(BitArray other) {
    this->swap(other);
    return *this;
}

// Stores `value` (truncated to the element width) at `index`.
// Calls on an unallocated array or with an out-of-range index are ignored.
// Returns *this for chaining.
Core::BitArray& Core::BitArray::set(size_t index, u64 value) {
    const bool writable = data != nullptr && index < getLength();
    if(writable) {
        setBits(data, index, getBits(), value);
    }
    return *this;
}

// Returns the element at `index`, or 0 when the array has no storage or the
// index is out of range.
u64 Core::BitArray::get(size_t index) const {
    const bool readable = data != nullptr && index < getLength();
    return readable ? readBits(data, index, getBits()) : 0;
}

size_t Core::BitArray::getLength() const {
    return lengthBits & LENGTH_MASK;
}

size_t Core::BitArray::getBits() const {
    return (lengthBits & ~LENGTH_MASK) >> LENGTH_BITS;
}

size_t Core::BitArray::getInternalByteSize() const {
    if(getLength() <= 0 || getBits() <= 0) {
        return 0;
    }
    return getArrayLength(getLength(), getBits()) * sizeof(u64);
}

// Finds the `index`-th set bit (1-based rank) in the raw word storage.
// Returns its bit position, counted from bit 0 of word 0 across all words,
// or -1 when index is 0 or fewer than `index` bits are set.
i64 Core::BitArray::select(u64 index) const {
    if(index == 0) {
        return -1;
    }
    // Rank still to be found, relative to the current word.
    u64 remaining = index;
    const size_t wordCount = getArrayLength(getLength(), getBits());
    for(size_t w = 0; w < wordCount; w++) {
        u64 word = data[w];
        const u64 ones = Core::popCount<u64, u64>(word);
        if(ones < remaining) {
            // The wanted bit lies in a later word.
            remaining -= ones;
            continue;
        }
        // The wanted bit is in this word: scan upwards from its lowest bit.
        u64 bit = 0;
        while(true) {
            if((word & 1) != 0) {
                remaining--;
                if(remaining == 0) {
                    break;
                }
            }
            word >>= 1;
            bit++;
        }
        return static_cast<i64>(w * U64_BITS + bit);
    }
    return -1;
}

// Sets every element to `value` (truncated to the element width).
void Core::BitArray::fill(u64 value) {
    size_t i = 0;
    const size_t count = getLength();
    while(i < count) {
        set(i, value);
        ++i;
    }
}

// Reallocates the array to hold `newLength` elements of `newBits` bits each.
// Existing elements up to min(old, new) length are kept (truncated to the new
// width); any additional elements start out as 0.
// Returns INVALID_ARGUMENT for a zero length, a width outside 1..64, or a
// length that does not fit the packed length field. Returns OUT_OF_MEMORY if
// allocation fails. In both error cases the array is left unchanged.
CError Core::BitArray::resize(size_t newLength, size_t newBits) {
    // newLength above LENGTH_MASK would overflow into the width field of
    // lengthBits and corrupt getBits().
    if(newLength == 0 || newBits == 0 || newBits > U64_BITS ||
       newLength > LENGTH_MASK) {
        return ErrorCode::INVALID_ARGUMENT;
    }
    size_t arrayLength = getArrayLength(newLength, newBits);
    u64* newData = new(noThrow) u64[arrayLength];
    // The non-throwing new may return nullptr; memset on it would be UB.
    if(newData == nullptr) {
        return ErrorCode::OUT_OF_MEMORY;
    }
    // Zero-filling also initializes every element past the old length to 0.
    memset(newData, 0, arrayLength * sizeof(u64));

    size_t end = Math::min(getLength(), newLength);
    for(size_t i = 0; i < end; i++) {
        // get() reads with the old width; setBits truncates to the new one.
        setBits(newData, i, newBits, get(i));
    }
    delete[] data;
    data = newData;
    lengthBits = newLength | (newBits << LENGTH_BITS);
    return ErrorCode::NONE;
}

// Exchanges storage pointer and packed length/width with `other`.
void Core::BitArray::swap(BitArray& other) {
    Core::swap(data, other.data);
    Core::swap(lengthBits, other.lengthBits);
}