#ifndef HASHMAP_H
#define HASHMAP_H

#include "utils/Array.h"
#include "utils/List.h"
#include "utils/StringBuffer.h"
#include "utils/Types.h"
#include "utils/Utils.h"

// Separate-chaining hash map. Every bucket is a List<Node>; the bucket count is
// always a power of two so a key's bucket is its hash masked with
// (bucket count - 1). The table doubles once the number of stored elements
// reaches the number of buckets (load factor 1).
//
// Keys are hashed via fullHash(): int / unsigned int are used directly, every
// other key type must provide a hashCode() member and operator==.
//
// Pointers to stored values (e.g. returned by search()) are invalidated by any
// insertion that grows the table. remove() uses List::removeBySwap, which
// presumably moves another node of the bucket into the freed slot, so treat
// pointers into that bucket as invalidated too.
template<typename K, typename V>
struct HashMap final {
    // One key/value entry. The key is only readable through getKey() so it
    // cannot be changed while the node sits in the bucket chosen by its hash.
    class Node final {
        friend HashMap;
        // presumably lets List construct nodes through the private ctors
        friend List<Node>;
        K key;

    public:
        V value;

        const K& getKey() const {
            return key;
        }

    private:
        // NOTE(review): initialized to -1 but never read in this file;
        // possibly a leftover from an index-linked chain design.
        int next;

        Node(const K& key, const V& value) : key(key), value(value), next(-1) {
        }

        Node(const K& key, V&& value)
            : key(key), value(std::move(value)), next(-1) {
        }

        // Constructs the value in place from arbitrary constructor arguments.
        template<typename... Args>
        Node(const K& key, Args&&... args)
            : key(key), value(std::forward<Args>(args)...), next(-1) {
        }
    };

    // Forward iterator yielding whole nodes. indexA is the bucket index,
    // indexB the position inside that bucket. skip() moves past exhausted or
    // empty buckets, so the iterator always rests either on a node or on the
    // end position (indexA == bucket count, indexB == 0).
    template<typename N, typename R>
    class BaseEntryIterator final {
        N& nodes;
        int indexA;
        int indexB;

    public:
        BaseEntryIterator(N& nodes, int indexA, int indexB)
            : nodes(nodes), indexA(indexA), indexB(indexB) {
            skip();
        }

        BaseEntryIterator& operator++() {
            indexB++;
            skip();
            return *this;
        }

        bool operator!=(const BaseEntryIterator& other) const {
            return indexA != other.indexA || indexB != other.indexB;
        }

        R& operator*() const {
            return nodes[indexA][indexB];
        }

    private:
        void skip() {
            while(indexA < nodes.getLength() &&
                  indexB >= nodes[indexA].getLength()) {
                indexA++;
                indexB = 0;
            }
        }
    };

    typedef BaseEntryIterator<List<List<Node>>, Node> EntryIterator;
    typedef BaseEntryIterator<const List<List<Node>>, const Node>
        ConstEntryIterator;

    // Range adapter returned by entries() so callers can range-for over nodes.
    struct EntryIteratorAdapter final {
        HashMap& map;

        EntryIterator begin() {
            return EntryIterator(map.nodes, 0, 0);
        }

        EntryIterator end() {
            return EntryIterator(map.nodes, map.nodes.getLength(), 0);
        }
    };

    struct ConstEntryIteratorAdapter final {
        const HashMap& map;

        ConstEntryIterator begin() const {
            return ConstEntryIterator(map.nodes, 0, 0);
        }

        ConstEntryIterator end() const {
            return ConstEntryIterator(map.nodes, map.nodes.getLength(), 0);
        }
    };

    // Same traversal as BaseEntryIterator, but dereferences to the value only.
    template<typename N, typename R>
    class BaseValueIterator final {
        N& nodes;
        int indexA;
        int indexB;

    public:
        BaseValueIterator(N& nodes, int indexA, int indexB)
            : nodes(nodes), indexA(indexA), indexB(indexB) {
            skip();
        }

        BaseValueIterator& operator++() {
            indexB++;
            skip();
            return *this;
        }

        bool operator!=(const BaseValueIterator& other) const {
            return indexA != other.indexA || indexB != other.indexB;
        }

        R& operator*() const {
            return nodes[indexA][indexB].value;
        }

    private:
        void skip() {
            while(indexA < nodes.getLength() &&
                  indexB >= nodes[indexA].getLength()) {
                indexA++;
                indexB = 0;
            }
        }
    };

    typedef BaseValueIterator<List<List<Node>>, V> ValueIterator;
    typedef BaseValueIterator<const List<List<Node>>, const V>
        ConstValueIterator;

    struct ValueIteratorAdapter final {
        HashMap& map;

        ValueIterator begin() {
            return ValueIterator(map.nodes, 0, 0);
        }

        ValueIterator end() {
            return ValueIterator(map.nodes, map.nodes.getLength(), 0);
        }
    };

    struct ConstValueIteratorAdapter final {
        const HashMap& map;

        ConstValueIterator begin() const {
            return ConstValueIterator(map.nodes, 0, 0);
        }

        ConstValueIterator end() const {
            return ConstValueIterator(map.nodes, map.nodes.getLength(), 0);
        }
    };

    // Same traversal as BaseEntryIterator, but yields keys read-only (keys
    // must never be mutated in place because they determine the bucket).
    class ConstKeyIterator final {
        const List<List<Node>>& nodes;
        int indexA;
        int indexB;

    public:
        ConstKeyIterator(const List<List<Node>>& nodes, int indexA, int indexB)
            : nodes(nodes), indexA(indexA), indexB(indexB) {
            skip();
        }

        ConstKeyIterator& operator++() {
            indexB++;
            skip();
            return *this;
        }

        bool operator!=(const ConstKeyIterator& other) const {
            return indexA != other.indexA || indexB != other.indexB;
        }

        const K& operator*() const {
            return nodes[indexA][indexB].getKey();
        }

    private:
        void skip() {
            while(indexA < nodes.getLength() &&
                  indexB >= nodes[indexA].getLength()) {
                indexA++;
                indexB = 0;
            }
        }
    };

    struct ConstKeyIteratorAdapter final {
        const HashMap& map;

        ConstKeyIterator begin() const {
            return ConstKeyIterator(map.nodes, 0, 0);
        }

        ConstKeyIterator end() const {
            return ConstKeyIterator(map.nodes, map.nodes.getLength(), 0);
        }
    };

private:
    // Buckets; the length is always a power of two (see constructor).
    List<List<Node>> nodes;
    // Number of stored entries across all buckets.
    int elements;

public:
    // Creates an empty map with at least minCapacity buckets (rounded up to a
    // power of two so hash() can mask instead of taking a modulo).
    HashMap(int minCapacity = 8) : elements(0) {
        nodes.resize(1 << Utils::roundUpLog2(minCapacity));
    }

    // Constructs a value from args under key only if key is not present yet.
    // Note the convention: returns false when the entry WAS inserted and true
    // when the key already existed (in which case nothing is constructed).
    template<typename... Args>
    bool tryEmplace(const K& key, Args&&... args) {
        if(search(key) != nullptr) {
            return true;
        }
        insertNew(key, std::forward<Args>(args)...);
        return false;
    }

    // Inserts a copy of value under key, or overwrites the existing value.
    // Overwriting never grows the table. Returns *this for chaining.
    HashMap& add(const K& key, const V& value) {
        V* v = search(key);
        if(v != nullptr) {
            *v = value;
        } else {
            insertNew(key, value);
        }
        return *this;
    }

    // Move-inserting overload of add(const K&, const V&).
    HashMap& add(const K& key, V&& value) {
        V* v = search(key);
        if(v != nullptr) {
            *v = std::move(value);
        } else {
            insertNew(key, std::move(value));
        }
        return *this;
    }

    // Removes the entry for key. Returns true if an entry was removed.
    bool remove(const K& key) {
        List<Node>& list = nodes[hash(key)];
        for(int i = 0; i < list.getLength(); i++) {
            if(list[i].key == key) {
                list.removeBySwap(i);
                // Keep the count in sync; otherwise rehash() would grow the
                // table based on entries that no longer exist.
                elements--;
                return true;
            }
        }
        return false;
    }

    // Returns a pointer to the value stored under key, or nullptr.
    const V* search(const K& key) const {
        return searchList(key, hash(key));
    }

    // Returns a pointer to the value stored under key, or nullptr.
    V* search(const K& key) {
        return searchList(key, hash(key));
    }

    bool contains(const K& key) const {
        return search(key) != nullptr;
    }

    // Removes all entries; the bucket count is kept.
    HashMap& clear() {
        for(List<Node>& n : nodes) {
            n.clear();
        }
        elements = 0;
        return *this;
    }

    EntryIteratorAdapter entries() {
        return {*this};
    }

    const ConstEntryIteratorAdapter entries() const {
        return {*this};
    }

    const ConstKeyIteratorAdapter keys() const {
        return {*this};
    }

    ValueIteratorAdapter values() {
        return {*this};
    }

    const ConstValueIteratorAdapter values() const {
        return {*this};
    }

    EntryIterator begin() {
        return EntryIterator(nodes, 0, 0);
    }

    EntryIterator end() {
        return EntryIterator(nodes, nodes.getLength(), 0);
    }

    ConstEntryIterator begin() const {
        return ConstEntryIterator(nodes, 0, 0);
    }

    ConstEntryIterator end() const {
        return ConstEntryIterator(nodes, nodes.getLength(), 0);
    }

    // Appends "[k1 = v1, k2 = v2, ...]" in bucket order to s.
    template<int L>
    void toString(StringBuffer<L>& s) const {
        s.append("[");
        bool c = false;
        for(const List<Node>& list : nodes) {
            for(const Node& n : list) {
                if(c) {
                    s.append(", ");
                }
                s.append(n.key).append(" = ").append(n.value);
                c = true;
            }
        }
        s.append("]");
    }

private:
    // Bucket index for key; relies on the bucket count being a power of two.
    template<typename H>
    Hash hash(const H& key) const {
        return fullHash(key) & (nodes.getLength() - 1);
    }

    template<typename H>
    Hash fullHash(const H& key) const {
        return key.hashCode();
    }

    Hash fullHash(int key) const {
        return key;
    }

    Hash fullHash(unsigned int key) const {
        return key;
    }

    // Appends a brand-new entry for key (the caller guarantees key is absent),
    // growing the table first if needed.
    template<typename... Args>
    void insertNew(const K& key, Args&&... args) {
        rehash();
        nodes[hash(key)].add(key, std::forward<Args>(args)...);
        elements++;
    }

    // Doubles the bucket count once elements reaches it (load factor 1).
    void rehash() {
        if(elements < nodes.getLength()) {
            return;
        }
        HashMap<K, V> map(nodes.getLength() * 2);
        for(List<Node>& list : nodes) {
            for(Node& n : list) {
                // Keys are already unique, so skip the duplicate search that
                // tryEmplace would do and append straight into the new bucket.
                map.nodes[map.hash(n.key)].add(n.key, std::move(n.value));
                map.elements++;
            }
        }
        *this = std::move(map);
    }

    // Linear scan of bucket h for key.
    const V* searchList(const K& key, Hash h) const {
        for(const Node& n : nodes[h]) {
            if(n.key == key) {
                return &n.value;
            }
        }
        return nullptr;
    }

    // Non-const overload; the cast is safe because *this is non-const here.
    V* searchList(const K& key, Hash h) {
        return const_cast<V*>(
            static_cast<const HashMap*>(this)->searchList(key, h));
    }
};

#endif