| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 | #include "core/BitArray.h"#include <assert.h>#include <inttypes.h>#include <string.h>#include "core/ToString.h"#include "core/Utility.h"static constexpr size_t U64_BITS = 64;static constexpr size_t DIVIDE_BITS = 6;static u64 roundUpDivide(u64 a, u64 b) {    return a / b + ((a % b) != 0);}static u64 readBits(const u64* data, size_t index, u64 bits) {    u64 dataIndexA = (index * bits) >> DIVIDE_BITS;    u64 dataIndexB = ((index * bits) + (bits - 1lu)) >> DIVIDE_BITS;    u64 shifts = (index * bits) & (U64_BITS - 1lu);    if(dataIndexA == dataIndexB) {        return (data[dataIndexA] >> shifts) & ((1lu << bits) - 1lu);    }    u64 bitsInA = U64_BITS - shifts;    u64 r = (data[dataIndexA] >> shifts) & ((1lu << bitsInA) - 1lu);    r |= (data[dataIndexB] & ((1lu << (bits - bitsInA)) - 1lu)) << bitsInA;    return r;}static void writeBits(u64* data, size_t index, size_t bits, u64 value) {    u64 mask = (1lu << bits) - 1lu;    value &= mask;    u64 dataIndexA = (index * bits) >> DIVIDE_BITS;    u64 dataIndexB = ((index * bits) + (bits - 1lu)) >> DIVIDE_BITS;    u64 shifts = (index * bits) & (U64_BITS - 1lu);    data[dataIndexA] &= ~(mask << shifts);    data[dataIndexA] |= (value << shifts);    if(dataIndexA != dataIndexB) {        u64 leftBits = bits - (U64_BITS - shifts);        data[dataIndexB] &= ~((1lu << leftBits) - 1lu);        data[dataIndexB] |= (value >> (U64_BITS - shifts));    }}static size_t getArrayLength(size_t length, size_t bits) {    return roundUpDivide(length * bits, U64_BITS);}void initBitArray(BitArray* a, size_t length, size_t bits) {    *a = (BitArray){0};    if(length > 0 && bits > 0) {        setBitLength(a, length, bits);    }}void destroyBitArray(BitArray* a) {    coreFree(a->data);    *a = (BitArray){0};}void setBits(BitArray* a, size_t index, u64 value) {    assert(a->data != nullptr);    assert(index < a->length);    writeBits(a->data, index, a->bits, value);}void setAllBits(BitArray* a, u64 value) {    size_t length = a->length;    for(size_t i = 0; i < length; i++) {        setBits(a, i, value);    }}u64 getBits(const BitArray* a, size_t index) {    assert(a->data != nullptr);    assert(index < a->length);    return readBits(a->data, index, a->bits);}i64 selectBits(const BitArray* a, size_t index) {    if(index <= 0) {        return -1;    }    u64 found = 0;    size_t end = getArrayLength(a->length, a->bits);    for(size_t i = 0; i < end; i++) {        u64 ones = popCount(a->data[i]);        found += ones;        if(found >= index) {            found -= ones;            u64 c = i * U64_BITS - 1;            u64 d = a->data[i];            while(found < index) {                found += d & 1;                d >>= 1;                c++;            }            return (i64)c;        }    }    return -1;}void setBitLength(BitArray* a, size_t newLength, size_t newBits) {    if(newLength == 0 || newBits == 0) {        destroyBitArray(a);        return;    } else if(newBits > 64) {        newBits = 64;    }    size_t arrayLength = getArrayLength(newLength, newBits);    u64* newData = coreAllocate(sizeof(u64) * arrayLength);    memset(newData, 0, arrayLength * sizeof(u64));    size_t end = minSize(a->length, newLength);    for(size_t i = 0; i < end; i++) {        writeBits(newData, i, newBits, getBits(a, i));    }    for(size_t i = end; i < newLength; i++) {        writeBits(newData, i, newBits, 0);    }    coreFree(a->data);    a->data = newData;    a->length = newLength & 0xFF'FFFF'FFFF'FFFF;    a->bits = newBits & 0xFF;}size_t toStringBitArray(const BitArray* a, char* buffer, size_t n) {    size_t w = 0;    stringAdd(&w, &buffer, &n, toString(buffer, n, "["));    size_t length = a->length;    if(length > 0) {        length--;        for(size_t i = 0; i < length; i++) {            u64 v = getBits(a, i);            stringAdd(&w, &buffer, &n, toString(buffer, n, "%" PRIu64 ", ", v));        }        u64 v = getBits(a, length);        stringAdd(&w, &buffer, &n, toString(buffer, n, "%" PRIu64, v));    }    stringAdd(&w, &buffer, &n, toString(buffer, n, "]"));    return w;}
 |