#ifndef BITARRAY_H
#define BITARRAY_H

#include <iostream>

#include "utils/StringBuffer.h"

/**
 * Fixed-size array of N small unsigned values, each BITS bits wide, packed
 * back to back into machine words without padding between elements.
 *
 * Values are stored modulo 2^BITS: assigning an int keeps only its lowest
 * BITS bits, and reading returns the stored bits as an int. For BITS smaller
 * than the width of an int, read values are therefore in [0, 2^BITS - 1].
 *
 * Indices are not range checked; callers must pass 0 <= index < N.
 *
 * @tparam N number of elements
 * @tparam BITS width of one element in bits (1 up to the width of an int)
 */
template<int N, int BITS>
class BitArray final {
    // Storage is unsigned so that every shift and mask below is well defined;
    // shifting set bits out of a signed int is not.
    using Word = unsigned int;

    static constexpr int INT_BITS = sizeof (int) * 8;

    static_assert(N >= 0, "a bit array cannot have a negative length");
    static_assert(BITS >= 1, "each bit array element must have at least one bit");
    static_assert(BITS <= INT_BITS, "each bit array element can have at most as much bits as an int");
    static_assert((INT_BITS & (INT_BITS - 1)) == 0, "the word size must be a power of two");

    // Lowest BITS bits set. Built with a right shift so that BITS == INT_BITS
    // does not require shifting by the full word width.
    static constexpr Word MASK = ~static_cast<Word> (0) >> (INT_BITS - BITS);
    // Number of words needed for N * BITS bits, rounded up.
    static constexpr int LENGTH = (N * BITS) / INT_BITS + (((N * BITS) % INT_BITS) > 0);
    // Keep at least one word so N == 0 does not declare a zero-length array.
    static constexpr int STORAGE = LENGTH > 0 ? LENGTH : 1;
    // True when elements never straddle a word boundary.
    static constexpr bool ALIGNED = (INT_BITS % BITS) == 0;

    // log2(INT_BITS); lets a division by the word size be written as a shift.
    constexpr static int getDivideBits() {
        int c = 0;
        int i = INT_BITS - 1;
        while(i > 0) {
            i >>= 1;
            c++;
        }
        return c;
    }

    static constexpr int DIVIDE_BITS = getDivideBits();

    Word data[STORAGE];

    // Index of the word holding the first (lowest) bit of the element.
    static constexpr int firstWord(int index) {
        return (index * BITS) >> DIVIDE_BITS;
    }

    // Index of the word holding the last (highest) bit of the element. This
    // is derived from the element's own last bit, not from the first bit of
    // the following element, so an element ending exactly on a word boundary
    // never refers to a word it does not occupy (which, for the last element,
    // would be one past the end of the storage).
    static constexpr int lastWord(int index) {
        return ((index + 1) * BITS - 1) >> DIVIDE_BITS;
    }

    // Position of the element's first bit inside its first word.
    static constexpr int bitOffset(int index) {
        return (index * BITS) & (INT_BITS - 1);
    }

    // Extracts the element at the given index from the packed words.
    static int readBits(const Word* data, int index) {
        const int wordA = firstWord(index);
        const int wordB = lastWord(index);
        const int shifts = bitOffset(index);
        if(ALIGNED || wordA == wordB) {
            return static_cast<int> ((data[wordA] >> shifts) & MASK);
        }
        // The element straddles two words: its low bitsInA bits are the top
        // bits of word A, the remaining bits are the low bits of word B.
        // shifts is at least 1 here, so bitsInA is smaller than INT_BITS.
        const int bitsInA = INT_BITS - shifts;
        const Word low = data[wordA] >> shifts;
        const Word high = data[wordB] << bitsInA;
        return static_cast<int> ((low | high) & MASK);
    }

public:

    /**
     * Proxy returned by the non-const operator[]; it refers to one element
     * of the array and can be assigned to or converted to int.
     */
    class BitInt {
        friend BitArray;

        Word* data;
        int index;

        BitInt(Word* data, int index) : data(data), index(index) {
        }

        BitInt(const BitInt& other) = default;
        BitInt(BitInt&& other) = default;

    public:

        // Assigning one proxy to another copies the referenced value, it does
        // not rebind the proxy.
        BitInt& operator=(const BitInt& other) {
            return (*this) = static_cast<int> (other);
        }

        BitInt& operator=(BitInt&& other) {
            return (*this) = static_cast<int> (other);
        }

        /**
         * Stores the lowest BITS bits of i in the referenced element and
         * leaves all neighbouring elements untouched.
         */
        BitInt& operator=(int i) {
            const Word value = static_cast<Word> (i) & MASK;
            const int wordA = firstWord(index);
            const int wordB = lastWord(index);
            const int shifts = bitOffset(index);
            // Bits shifted beyond the top of word A are dropped here and
            // written to word B below.
            data[wordA] = (data[wordA] & ~(MASK << shifts)) | (value << shifts);
            if(!ALIGNED && wordA != wordB) {
                const int bitsInA = INT_BITS - shifts;
                // MASK >> bitsInA covers exactly the bits that spill into B.
                data[wordB] = (data[wordB] & ~(MASK >> bitsInA)) | (value >> bitsInA);
            }
            return *this;
        }

        // Reads the referenced element.
        operator int() const {
            return readBits(data, index);
        }
    };

    /**
     * Creates an array in which every element holds the lowest BITS bits of
     * the given value. The storage is zeroed first so that no uninitialized
     * word is ever read and unused padding bits have a defined value.
     */
    BitArray(int bits = 0) : data() {
        fill(bits);
    }

    // Sets every element to the lowest BITS bits of the given value.
    void fill(int bits) {
        for(int i = 0; i < N; i++) {
            (*this)[i] = bits;
        }
    }

    // Returns an assignable proxy for the element at index (0 <= index < N).
    BitInt operator[](int index) {
        return BitInt(data, index);
    }

    // Returns the value of the element at index (0 <= index < N).
    int operator[](int index) const {
        return readBits(data, index);
    }

    // Number of elements, i.e. the template parameter N.
    constexpr int getLength() const {
        return N;
    }
    
    // Appends the elements to s as "[a, b, c]"; an empty array yields "[]".
    template<int L>
    void toString(StringBuffer<L>& s) const {
        s.append("[");
        for(int i = 0; i < N - 1; i++) {
            s.append((*this)[i]);
            s.append(", ");
        }
        if(N > 0) {
            s.append((*this)[N - 1]);
        }
        s.append("]");
    }
};

#endif