Browse Source

bit array is now on the heap

Kajetan Johannes Hammerle 3 năm trước cách đây
mục cha
commit
c7c902a09f
4 tập tin đã thay đổi với 178 bổ sung148 xóa
  1. 1 0
      meson.build
  2. 32 46
      tests/BitArrayTests.cpp
  3. 123 0
      utils/BitArray.cpp
  4. 22 102
      utils/BitArray.h

+ 1 - 0
meson.build

@@ -6,6 +6,7 @@ sources = ['Main.cpp',
     'tests/HeapArrayTests.cpp',
     'tests/HashMapTests.cpp',
     'tests/ListTests.cpp',
+    'utils/BitArray.cpp',
     'tests/BitArrayTests.cpp',
     'tests/StringBufferTests.cpp',
     'tests/RandomTests.cpp',

+ 32 - 46
tests/BitArrayTests.cpp

@@ -6,89 +6,75 @@
 typedef StringBuffer<50> String;
 
 static void testSetRead(Test& test) {
-    BitArray<4, 3> bits;
-    bits[0] = 1;
-    bits[1] = 2;
-    bits[2] = 3;
-    bits[3] = 4;
-    test.checkEqual(1, static_cast<int> (bits[0]), "set and read correct value 1");
-    test.checkEqual(2, static_cast<int> (bits[1]), "set and read correct value 2");
-    test.checkEqual(3, static_cast<int> (bits[2]), "set and read correct value 3");
-    test.checkEqual(4, static_cast<int> (bits[3]), "set and read correct value 4");
+    BitArray bits(4, 3);
+    bits.set(0, 1).set(1, 2).set(2, 3).set(3, 4);
+    test.checkEqual(1, bits.get(0), "set and read correct value 1");
+    test.checkEqual(2, bits.get(1), "set and read correct value 2");
+    test.checkEqual(3, bits.get(2), "set and read correct value 3");
+    test.checkEqual(4, bits.get(3), "set and read correct value 4");
 }
 
 static void testBigSetRead(Test& test) {
-    BitArray<100, 13> bits;
+    BitArray bits(100, 13);
     for(int i = 0; i < bits.getLength(); i++) {
-        bits[i] = i;
+        bits.set(i, i);
     }
     for(int i = 0; i < bits.getLength(); i++) {
-        test.checkEqual(i, static_cast<int> (bits[i]), "set and read correct value over long array");
+        test.checkEqual(i, bits.get(i), "set and read correct value over long array");
     }
 }
 
-static void testRandomSetRead(Test& test) {
+static void testRandomSetReadResize(Test& test) {
     const int length = 100;
     int data[length];
-    BitArray<100, 13> bits;
+    BitArray bits(100, 13);
     int seed = 534;
     for(int k = 0; k < 20; k++) {
         for(int i = 0; i < bits.getLength(); i++) {
             seed = seed * 636455 + 53453;
-            bits[i] = seed & (0x1FFF);
+            bits.set(i, seed & (0x1FFF));
             data[i] = seed & (0x1FFF);
         }
     }
     for(int i = 0; i < bits.getLength(); i++) {
-        test.checkEqual(data[i], static_cast<int> (bits[i]), "set and read correct value with random input");
+        test.checkEqual(data[i], bits.get(i), "set and read correct value with random input");
+    }
+    bits.resize(bits.getBits() + 1);
+    test.checkEqual(14, bits.getBits(), "corrects bits after resize");
+    test.checkEqual(100, bits.getLength(), "correct length after resize");
+    for(int i = 0; i < bits.getLength(); i++) {
+        test.checkEqual(data[i], bits.get(i), "set and read correct value with random input after resize");
     }
 }
 
 static void testReadOnly(Test& test) {
-    BitArray<4, 3> bits;
-    bits[0] = 1;
-    bits[1] = 2;
-    bits[2] = 3;
-    bits[3] = 4;
-    const BitArray<4, 3> bits2 = bits;
-    test.checkEqual(1, bits2[0], "can read from const 1");
-    test.checkEqual(2, bits2[1], "can read from const 2");
-    test.checkEqual(3, bits2[2], "can read from const 3");
-    test.checkEqual(4, bits2[3], "can read from const 4");
-}
-
-static void testChainedSet(Test& test) {
-    BitArray<4, 3> bits;
-    bits[0] = bits[2] = bits[3] = 2;
-    bits[3] = bits[1] = 7;
-    test.checkEqual(2, static_cast<int> (bits[0]), "chained set sets correct value 1");
-    test.checkEqual(7, static_cast<int> (bits[1]), "chained set sets correct value 2");
-    test.checkEqual(2, static_cast<int> (bits[2]), "chained set sets correct value 3");
-    test.checkEqual(7, static_cast<int> (bits[3]), "chained set sets correct value 4");
+    BitArray bits(4, 3);
+    bits.set(0, 1).set(1, 2).set(2, 3).set(3, 4);
+    const BitArray bits2 = bits;
+    test.checkEqual(1, bits2.get(0), "can read from const 1");
+    test.checkEqual(2, bits2.get(1), "can read from const 2");
+    test.checkEqual(3, bits2.get(2), "can read from const 3");
+    test.checkEqual(4, bits2.get(3), "can read from const 4");
 }
 
 static void testToString1(Test& test) {
-    BitArray<4, 3> bits;
-    bits[0] = 1;
-    bits[1] = 2;
-    bits[2] = 3;
-    bits[3] = 4;
+    BitArray bits(4, 3);
+    bits.set(0, 1).set(1, 2).set(2, 3).set(3, 4);
     test.checkEqual(String("[1, 2, 3, 4]"), String(bits), "bit array to string 1");
 }
 
 static void testToString2(Test& test) {
-    BitArray<1, 3> a;
-    a[0] = 1;
-    test.checkEqual(String("[1]"), String(a), "bit array to string 1");
+    BitArray bits(1, 3);
+    bits.set(0, 1);
+    test.checkEqual(String("[1]"), String(bits), "bit array to string 1");
 }
 
 void BitArrayTests::test() {
     Test test("BitArray");
     testSetRead(test);
     testBigSetRead(test);
-    testRandomSetRead(test);
+    testRandomSetReadResize(test);
     testReadOnly(test);
-    testChainedSet(test);
     testToString1(test);
     testToString2(test);
     test.finalize();

+ 123 - 0
utils/BitArray.cpp

@@ -0,0 +1,123 @@
+#include "utils/BitArray.h"
+#include "utils/Utils.h"
+
+static int roundUpDivide(int a, int b) {
+    if(a % b == 0) {
+        return a / b;
+    }
+    return a / b + 1;
+}
+
+static constexpr int INT_BITS = sizeof(int) * 8;
+static constexpr int DIVIDE_BITS = Utils::roundUpLog2(INT_BITS);
+
+static int readBits(const int* data, int index, int bits) {
+    int dataIndexA = (index * bits) >> DIVIDE_BITS;
+    int dataIndexB = ((index + 1) * bits) >> DIVIDE_BITS;
+    int shifts = (index * bits) & (INT_BITS - 1);
+    if(dataIndexA == dataIndexB) {
+        return (data[dataIndexA] >> shifts) & ((1 << bits) - 1);
+    }
+    int bitsInA = INT_BITS - shifts;
+    int r = (data[dataIndexA] >> shifts) & ((1 << bitsInA) - 1);
+    r |= (data[dataIndexB] & ((1 << (bits - bitsInA)) - 1)) << bitsInA;
+    return r;
+}
+
+static void setBits(int* data, int index, int bits, int value) {
+    int mask = (1 << bits) - 1;
+    value &= mask;
+    int dataIndexA = (index * bits) >> DIVIDE_BITS;
+    int dataIndexB = ((index + 1) * bits) >> DIVIDE_BITS;
+    int shifts = (index * bits) & (INT_BITS - 1);
+    data[dataIndexA] &= ~(mask << shifts);
+    data[dataIndexA] |= (value << shifts);
+    if(dataIndexA != dataIndexB) {
+        int leftBits = bits - (INT_BITS - shifts);
+        int mask = (1 << leftBits) - 1;
+        data[dataIndexB] &= ~mask;
+        data[dataIndexB] |= (value >> (INT_BITS - shifts));
+    }
+}
+
+BitArray::BitArray(int length, int bits)
+    : length(length), bits(bits), data(new int[roundUpDivide(length * bits, sizeof(int))]) {
+}
+
+BitArray::~BitArray() {
+    delete[] data;
+}
+
+BitArray::BitArray(const BitArray& other) : BitArray(other.length, other.bits) {
+    copyData(other);
+}
+
+BitArray::BitArray(BitArray&& other) : length(other.length), bits(other.bits), data(other.data) {
+    other.reset();
+}
+
+BitArray& BitArray::operator=(const BitArray& other) {
+    if(this == &other) {
+        return *this;
+    } else if(length == other.length && bits == other.bits) {
+        copyData(other);
+        return *this;
+    }
+    delete[] data;
+    length = other.length;
+    bits = other.bits;
+    data = new int[roundUpDivide(length * bits, sizeof(int))];
+    copyData(other);
+    return *this;
+}
+
+BitArray& BitArray::operator=(BitArray&& other) {
+    if(this == &other) {
+        return *this;
+    }
+    delete[] data;
+    length = other.length;
+    bits = other.bits;
+    data = other.data;
+    other.reset();
+    return *this;
+}
+
+BitArray& BitArray::set(int index, int value) {
+    setBits(data, index, bits, value);
+    return *this;
+}
+
+int BitArray::get(int index) const {
+    return readBits(data, index, bits);
+}
+
+int BitArray::getLength() const {
+    return length;
+}
+
+int BitArray::getBits() const {
+    return bits;
+}
+
+void BitArray::resize(int newBits) {
+    int* newData = new int[roundUpDivide(length * newBits, sizeof(int))];
+    for(int i = 0; i < length; i++) {
+        setBits(newData, i, newBits, get(i));
+    }
+    delete[] data;
+    data = newData;
+    bits = newBits;
+}
+
+void BitArray::copyData(const BitArray& other) {
+    for(int i = 0; i < length; i++) {
+        set(i, other.get(i));
+    }
+}
+
+void BitArray::reset() {
+    length = 0;
+    bits = 0;
+    data = nullptr;
+}

+ 22 - 102
utils/BitArray.h

@@ -1,125 +1,45 @@
 #ifndef BITARRAY_H
 #define BITARRAY_H
 
-#include <iostream>
-
 #include "utils/StringBuffer.h"
 
-template<int N, int BITS>
 class BitArray final {
-    static constexpr int INT_BITS = sizeof(int) * 8;
-
-    static_assert(BITS >= 1, "each bit array element must have at least one bit");
-    static_assert(BITS <= INT_BITS, "each bit array element can have at most as much bits as an int");
-
-    static constexpr int MASK = (1 << BITS) - 1;
-    static constexpr int LENGTH = (N * BITS) / INT_BITS + (((N * BITS) % INT_BITS) > 0);
-    static constexpr bool ALIGNED = (INT_BITS % BITS) == 0;
-
-    constexpr static int getDivideBits() {
-        int c = 0;
-        int i = INT_BITS - 1;
-        while(i > 0) {
-            i >>= 1;
-            c++;
-        }
-        return c;
-    }
-
-    static constexpr int DIVIDE_BITS = getDivideBits();
-
-    int data[LENGTH];
-
-    static int readBits(const int* data, int index) {
-        int dataIndexA = (index * BITS) >> DIVIDE_BITS;
-        int dataIndexB = ((index + 1) * BITS) >> DIVIDE_BITS;
-        int shifts = (index * BITS) & (INT_BITS - 1);
-        if(dataIndexA == dataIndexB || ALIGNED) {
-            return (data[dataIndexA] >> shifts) & MASK;
-        }
-        int bitsInA = INT_BITS - shifts;
-        int r = (data[dataIndexA] >> shifts) & ((1 << bitsInA) - 1);
-        r |= (data[dataIndexB] & ((1 << (BITS - bitsInA)) - 1)) << bitsInA;
-        return r;
-    }
+    int length;
+    int bits;
+    int* data;
 
 public:
-    class BitInt {
-        friend BitArray;
-
-        int* data;
-        int index;
-
-        BitInt(int* data, int index) : data(data), index(index) {
-        }
-
-        BitInt(const BitInt& other) = default;
-        BitInt(BitInt&& other) = default;
-
-    public:
-        BitInt& operator=(const BitInt& other) {
-            return (*this) = static_cast<int>(other);
-        }
-
-        BitInt& operator=(BitInt&& other) {
-            return (*this) = static_cast<int>(other);
-        }
-
-        BitInt& operator=(int i) {
-            i &= MASK;
-            int dataIndexA = (index * BITS) >> DIVIDE_BITS;
-            int dataIndexB = ((index + 1) * BITS) >> DIVIDE_BITS;
-            int shifts = (index * BITS) & (INT_BITS - 1);
-            data[dataIndexA] &= ~(MASK << shifts);
-            data[dataIndexA] |= (i << shifts);
-            if(dataIndexA != dataIndexB && !ALIGNED) {
-                int leftBits = BITS - (INT_BITS - shifts);
-                int mask = (1 << leftBits) - 1;
-                data[dataIndexB] &= ~mask;
-                data[dataIndexB] |= (i >> (INT_BITS - shifts));
-            }
-            return *this;
-        }
-
-        operator int() const {
-            return readBits(data, index);
-        }
-    };
-
-    BitArray(int bits = 0) {
-        fill(bits);
-    }
-
-    void fill(int bits) {
-        for(int i = 0; i < N; i++) {
-            (*this)[i] = bits;
-        }
-    }
+    BitArray(int length, int bits);
+    ~BitArray();
+    BitArray(const BitArray& other);
+    BitArray(BitArray&& other);
+    BitArray& operator=(const BitArray& other);
+    BitArray& operator=(BitArray&& other);
 
-    BitInt operator[](int index) {
-        return BitInt(data, index);
-    }
+    BitArray& set(int index, int value);
+    int get(int index) const;
 
-    int operator[](int index) const {
-        return readBits(data, index);
-    }
+    int getLength() const;
+    int getBits() const;
 
-    constexpr int getLength() const {
-        return N;
-    }
+    void resize(int newBits);
 
     template<int L>
     void toString(StringBuffer<L>& s) const {
         s.append("[");
-        for(int i = 0; i < N - 1; i++) {
-            s.append((*this)[i]);
+        for(int i = 0; i < length - 1; i++) {
+            s.append(get(i));
             s.append(", ");
         }
-        if(N > 0) {
-            s.append((*this)[N - 1]);
+        if(length > 0) {
+            s.append(get(length - 1));
         }
         s.append("]");
     }
+
+private:
+    void copyData(const BitArray& other);
+    void reset();
 };
 
 #endif