#ifndef HASHMAP_H
#define HASHMAP_H

#include <climits>
#include <iostream>

// Bucket counts used by HashMap, in increasing order. Each entry is roughly
// double the previous one, so stepping to the next entry about doubles the
// table size.
const static int PRIMES[26] = 
{
    17, 37, 79, 163, 331, 673, 1361, 2729, 5471, 10949, 21911, 43853, 87719, 
    175447, 350899, 701819, 1403641, 2807303, 5614657, 11229331, 22458671,
    44917381, 89834777, 179669557, 359339171, 718678369
};

/**
 * Hash map with separate chaining, driven by a caller-supplied hash function
 * and key-equality predicate.
 *
 * Lookups are stateful: search(k) records its outcome, and the parameterless
 * insert(v), remove(), isFound() and getValue() act on that recorded outcome.
 * insert(v) and remove() consume it, so call search() again before the next
 * insert(v) / remove(); with no pending search they do nothing.
 *
 * The bucket count is always an entry of PRIMES. Once the number of entries
 * reaches about 3/4 of the bucket count the table grows to the next entry;
 * after the last entry it stops growing and chains simply get longer.
 *
 * Requirements: K and V copy-constructible and copy-assignable; K also
 * default-constructible (it is stored as the remembered search key).
 */
template<class K, class V> 
class HashMap
{
    // One chain link: a key/value pair plus the next node in the same bucket.
    class Node
    {
    public:
        Node* next;
        K k;
        V v;
        
        Node(K key, V value) : next(nullptr), k(key), v(value)
        {
        }
    };
    
private:
    // hasher, key comparison
    int (*hasher)(K); 
    bool (*equal)(K, K);
    
    // size specs
    int primeIndex;   // index into PRIMES of the current bucket count
    int capacity;     // number of buckets, always PRIMES[primeIndex]
    int resizeCap;    // entry count at which the table grows
    
    // entries
    Node** data;      // `capacity` bucket heads; nullptr marks an empty bucket
    int entries;
    
    // Search state, set by search() and consumed by insert(V) / remove():
    //   found != nullptr, lastIndex != -1 : key present; found is its node,
    //                                       lastIndex its bucket
    //   found == nullptr, lastIndex != -1 : key absent, bucket lastIndex empty
    //   found != nullptr, lastIndex == -1 : key absent; found is the tail of
    //                                       the key's non-empty bucket
    //   found == nullptr, lastIndex == -1 : no pending search
    K lastKey;
    Node* found;
    int lastIndex;
    
    // Number of entries in PRIMES.
    static int primeCount()
    {
        return static_cast<int>(sizeof(PRIMES) / sizeof(PRIMES[0]));
    }
    
    // Entry count that triggers growth for a table with `buckets` buckets
    // (roughly a 3/4 load factor).
    static int growthThreshold(int buckets)
    {
        return (buckets >> 2) * 3;
    }
    
    // Index of the smallest PRIMES entry >= lower, or the last index when
    // every entry is smaller (binary search over the ascending table).
    static int getHigherPrimeIndex(int lower)
    {
        int low = 0;
        int high = primeCount() - 1;
        while(low < high)
        {
            int mid = low + ((high - low) >> 1);
            if(PRIMES[mid] >= lower)
            {
                high = mid;
            }
            else
            {
                low = mid + 1;
            }
        }
        return low;
    }
    
    // Bucket index for k in [0, capacity); negative hash codes wrap around.
    int getHash(K k) const
    {
        int hash = (*hasher)(k) % capacity;
        return hash < 0 ? hash + capacity : hash;
    }
    
    // Deletes every node in `count` buckets, then the bucket array itself.
    static void freeBuckets(Node** buckets, int count)
    {
        for(int i = 0; i < count; i++)
        {
            Node* n = buckets[i];
            while(n != nullptr)
            {
                Node* next = n->next;
                delete n;
                n = next;
            }
        }
        delete[] buckets;
    }
    
    // Returns a deep copy of src's buckets, preserving chain order.
    // NOTE: if copying a key or value throws part-way, the nodes cloned so
    // far are leaked (no try/catch, so the header also builds without
    // exception support).
    static Node** cloneBuckets(const HashMap& src)
    {
        Node** copy = new Node*[src.capacity]();
        for(int i = 0; i < src.capacity; i++)
        {
            Node** tail = &copy[i];
            for(const Node* n = src.data[i]; n != nullptr; n = n->next)
            {
                *tail = new Node(n->k, n->v);
                tail = &(*tail)->next;
            }
        }
        return copy;
    }
    
    // Grows to the next PRIMES entry once `entries` reaches resizeCap and
    // relinks every existing node into its new bucket (nodes are reused,
    // not copied).
    void ensureCapacity()
    {
        if(entries < resizeCap)
        {
            return;
        }
        
        if(primeIndex + 1 >= primeCount())
        {
            // Already at the largest table size: never try to grow again.
            resizeCap = INT_MAX;
            return;
        }
        
        // Allocate before touching any member so a failed allocation leaves
        // the map unchanged. "()" value-initializes every bucket to nullptr,
        // which the relinking loop below relies on.
        int newCapacity = PRIMES[primeIndex + 1];
        Node** newData = new Node*[newCapacity]();
        
        Node** oldData = data;
        int oldCapacity = capacity;
        
        primeIndex++;
        capacity = newCapacity;   // getHash() below must use the new size
        resizeCap = growthThreshold(capacity);
        data = newData;
        
        for(int i = 0; i < oldCapacity; i++)
        {
            Node* n = oldData[i];
            while(n != nullptr)
            {
                Node* next = n->next;
                int hash = getHash(n->k);
                // Push onto the front of the new chain: O(1) per node.
                n->next = data[hash];
                data[hash] = n;
                n = next;
            }
        }
        
        delete[] oldData;
    }

public:
    /**
     * @param initialLoad expected number of entries; the initial bucket count
     *                    is the smallest PRIMES entry >= initialLoad (or the
     *                    largest entry if none is that big)
     * @param hasher      hash function for keys (negative results allowed)
     * @param equal       key-equality predicate
     */
    HashMap(int initialLoad, int (*hasher)(K), bool (*equal)(K, K))
        : hasher(hasher),
          equal(equal),
          primeIndex(getHigherPrimeIndex(initialLoad)),
          capacity(PRIMES[primeIndex]),
          resizeCap(growthThreshold(capacity)),
          data(new Node*[capacity]()),
          entries(0),
          lastKey(),
          found(nullptr),
          lastIndex(-1)
    {
    }
    
    // Deep copy of all entries. The copy starts with no pending search.
    HashMap(const HashMap& other)
        : hasher(other.hasher),
          equal(other.equal),
          primeIndex(other.primeIndex),
          capacity(other.capacity),
          resizeCap(other.resizeCap),
          data(cloneBuckets(other)),
          entries(other.entries),
          lastKey(other.lastKey),
          found(nullptr),
          lastIndex(-1)
    {
    }
    
    // Replaces this map's contents with a deep copy of other's entries.
    // The result has no pending search.
    HashMap& operator=(const HashMap& other)
    {
        if(this != &other)
        {
            // Clone first so *this is untouched if allocation throws.
            Node** copy = cloneBuckets(other);
            freeBuckets(data, capacity);
            data = copy;
            hasher = other.hasher;
            equal = other.equal;
            primeIndex = other.primeIndex;
            capacity = other.capacity;
            resizeCap = other.resizeCap;
            entries = other.entries;
            lastKey = other.lastKey;
            found = nullptr;
            lastIndex = -1;
        }
        return *this;
    }
    
    virtual ~HashMap()
    {
        freeBuckets(data, capacity);
    }
    
    // Looks up k and records the outcome for isFound(), getValue(),
    // insert(v) and remove().
    void search(K k)
    {
        lastKey = k;
        int hash = getHash(k);
        Node* n = data[hash];
        if(n == nullptr)
        {
            // Empty bucket: remember it so insert(v) can fill it.
            found = nullptr;
            lastIndex = hash;
            return;
        }
        while(true)
        {
            if((*equal)(k, n->k))
            {
                // Key present: remember its node for overwrite / removal.
                found = n;
                lastIndex = hash;
                return;
            }
            if(n->next == nullptr)
            {
                break;
            }
            n = n->next;
        }
        // Key absent: remember the chain's tail so insert(v) can append.
        found = n;
        lastIndex = -1;
    }
    
    // Writes each non-empty bucket as one line "k - v, k - v, ..." to std::cout.
    void print() const
    {
        for(int i = 0; i < capacity; i++)
        {
            if(data[i] == nullptr)
            {
                continue;
            }
            for(const Node* n = data[i]; n != nullptr; n = n->next)
            {
                if(n != data[i])
                {
                    std::cout << ", ";
                }
                std::cout << n->k << " - " << n->v;
            }
            std::cout << "\n";
        }
    }
    
    // Inserts k -> v, overwriting the value if k is already present.
    void insert(K k, V v)
    {
        search(k);
        insert(v);
    }
    
    // Stores v for the key given to the last search(): overwrites it if
    // found, otherwise adds a new entry (possibly growing the table).
    // Consumes the search state; without a pending search this does nothing.
    void insert(V v)
    {
        Node* target = found;
        int index = lastIndex;
        // Consume the search up front so no later call can act on it again.
        found = nullptr;
        lastIndex = -1;
        
        if(target == nullptr)
        {
            if(index == -1)
            {
                return;   // no pending search
            }
            // Key absent, bucket empty: start a new chain.
            data[index] = new Node(lastKey, v);
        }
        else if(index != -1)
        {
            // Key present: overwrite its value; the entry count is unchanged.
            target->v = v;
            return;
        }
        else
        {
            // Key absent, bucket non-empty: append after the recorded tail.
            target->next = new Node(lastKey, v);
        }
        entries++;
        ensureCapacity();
    }
    
    // Removes the entry found by the last search(), if any, and consumes the
    // search state.
    void remove()
    {
        if(!isFound())
        {
            return;
        }
        // Walk the links of the bucket until we reach the one pointing at
        // `found`, then splice it out.
        Node** link = &data[lastIndex];
        while(*link != found)
        {
            link = &(*link)->next;
        }
        *link = found->next;
        delete found;
        entries--;
        found = nullptr;
        lastIndex = -1;
    }
    
    // True when the last search() found its key and it has not since been
    // consumed by insert(v) or remove().
    bool isFound() const
    {
        return found != nullptr && lastIndex != -1;
    }
    
    // Value of the node recorded by search(). Only meaningful when isFound()
    // is true: with no recorded node this dereferences nullptr, and after an
    // unsuccessful search in a non-empty bucket it returns the value of that
    // bucket's last node.
    V getValue() const
    {
        return found->v;
    }
    
    // Current number of buckets.
    int getCapacity() const
    {
        return capacity;
    }
    
    // Number of stored entries.
    int getSize() const
    {
        return entries;
    }
    
    // Calls w(key, value) for every entry, bucket by bucket.
    void forEach(void (*w)(K, V)) const
    {
        for(int i = 0; i < capacity; i++)
        {
            for(const Node* n = data[i]; n != nullptr; n = n->next)
            {
                (*w)(n->k, n->v);
            }
        }
    }
};

#endif