Browse Source

hashmap is on heap, improved hashmap iterator

Kajetan Johannes Hammerle 3 years ago
parent
commit
c836197dbc
3 changed files with 243 additions and 154 deletions
  1. 98 36
      tests/HashMapTests.cpp
  2. 35 14
      tests/StringBufferTests.cpp
  3. 110 104
      utils/HashMap.h

+ 98 - 36
tests/HashMapTests.cpp

@@ -3,15 +3,18 @@
 #include "utils/HashMap.h"
 #include "utils/StringBuffer.h"
 
-constexpr int MAP_MIN_CAPACITY = 5;
-typedef HashMap<int, int, MAP_MIN_CAPACITY> IntMap;
+typedef HashMap<int, int> IntMap;
 typedef StringBuffer<50> String;
 
 static void testAdd(Test& test) {
+    // search() yields nullptr for a missing key, so guard the dereference.
     IntMap map;
     map.add(5, 4);
-    test.checkEqual(true, map.contains(5), "contains added value");
-    test.checkEqual(4, map.search(5, -1), "search finds added value");
+    int* value = map.search(5);
+    test.checkEqual(true, value != nullptr, "contains added value");
+    if(value != nullptr) {
+        test.checkEqual(4, *value, "search finds added value");
+    }
 }
 
 static void testMultipleAdd(Test& test) {
@@ -20,28 +23,43 @@ static void testMultipleAdd(Test& test) {
     test.checkEqual(true, map.contains(5), "contains added value 1");
     test.checkEqual(true, map.contains(10), "contains added value 2");
     test.checkEqual(true, map.contains(15), "contains added value 3");
-    test.checkEqual(4, map.search(5, -1), "search finds added value 1");
-    test.checkEqual(3, map.search(10, -1), "search finds added value 2");
-    test.checkEqual(2, map.search(15, -1), "search finds added value 3");
+    int* a = map.search(5);
+    int* b = map.search(10);
+    int* c = map.search(15);
+    test.checkEqual(true, a != nullptr, "contains added value 1");
+    test.checkEqual(true, b != nullptr, "contains added value 2");
+    test.checkEqual(true, c != nullptr, "contains added value 3");
+    if(a != nullptr && b != nullptr && c != nullptr) {
+        test.checkEqual(4, *a, "search finds added value 1");
+        test.checkEqual(3, *b, "search finds added value 2");
+        test.checkEqual(2, *c, "search finds added value 3");
+    }
 }
 
 static void testSearch(Test& test) {
     IntMap map;
-    test.checkEqual(-1, map.search(6, -1), "search does not find missing key");
+    test.checkEqual(true, nullptr == map.search(6),
+                    "search does not find missing key");
 }
 
 static void testAddReplace(Test& test) {
     IntMap map;
     map.add(5, 4).add(5, 10);
     test.checkEqual(true, map.contains(5), "contains replaced value");
-    test.checkEqual(10, map.search(5, -1), "search finds replaced value");
+    int* a = map.search(5);
+    test.checkEqual(true, a != nullptr, "contains replaced value");
+    if(a != nullptr) {
+        test.checkEqual(10, *a, "search finds replaced value");
+    }
 }
 
 static void testClear(Test& test) {
     IntMap map;
     map.add(5, 4).add(4, 10).clear();
-    test.checkEqual(false, map.contains(5), "does not contain cleared values 1");
-    test.checkEqual(false, map.contains(4), "does not contain cleared values 2");
+    test.checkEqual(false, map.contains(5),
+                    "does not contain cleared values 1");
+    test.checkEqual(false, map.contains(4),
+                    "does not contain cleared values 2");
 }
 
 static void testOverflow(Test& test) {
@@ -49,8 +67,9 @@ static void testOverflow(Test& test) {
     for(int i = 0; i < 1000000; i++) {
         map.add(i, i);
     }
-    for(int i = 0; i < MAP_MIN_CAPACITY; i++) {
-        test.checkEqual(true, map.contains(i), "still contains values after overflow");
+    for(int i = 0; i < 1000000; i++) {
+        test.checkEqual(true, map.contains(i),
+                        "still contains values after overflow");
     }
     test.checkEqual(true, true, "survives overflow");
 }
@@ -72,7 +91,7 @@ std::ostream& operator<<(std::ostream& os, const A& a) {
 }
 
 static void testEmplace(Test& test) {
-    HashMap<int, A, 5> map;
+    HashMap<int, A> map;
 
     bool r1 = map.tryEmplace(0, 3, 4);
     bool r2 = map.tryEmplace(3, 4, 5);
@@ -80,9 +99,19 @@ static void testEmplace(Test& test) {
     bool r4 = map.tryEmplace(3, 6, 7);
     bool r5 = map.tryEmplace(20, 7, 8);
 
-    test.checkEqual(A(3, 4), map.search(0, A(0, 0)), "contains emplaced value 1");
-    test.checkEqual(A(4, 5), map.search(3, A(0, 0)), "contains emplaced value 2");
-    test.checkEqual(A(5, 6), map.search(20, A(0, 0)), "contains emplaced value 3");
+    A* a = map.search(0);
+    A* b = map.search(3);
+    A* c = map.search(20);
+
+    test.checkEqual(true, a != nullptr, "contains emplaced value 1");
+    test.checkEqual(true, b != nullptr, "contains emplaced value 2");
+    test.checkEqual(true, c != nullptr, "contains emplaced value 3");
+
+    if(a != nullptr && b != nullptr && c != nullptr) {
+        test.checkEqual(A(3, 4), *a, "contains emplaced value 1");
+        test.checkEqual(A(4, 5), *b, "contains emplaced value 2");
+        test.checkEqual(A(5, 6), *c, "contains emplaced value 3");
+    }
 
     test.checkEqual(false, r1, "emplacing returns correct value 1");
     test.checkEqual(false, r2, "emplacing returns correct value 2");
@@ -94,7 +123,8 @@ static void testEmplace(Test& test) {
 static void testToString1(Test& test) {
     IntMap map;
     map.add(1, 3).add(2, 4).add(3, 5);
-    test.checkEqual(String("[1 = 3, 2 = 4, 3 = 5]"), String(map), "to string 1");
+    test.checkEqual(String("[1 = 3, 2 = 4, 3 = 5]"), String(map),
+                    "to string 1");
 }
 
 static void testToString2(Test& test) {
@@ -112,9 +142,16 @@ static void testCopy(Test& test) {
     IntMap map;
     map.add(1, 3).add(2, 4).add(3, 5);
     IntMap copy(map);
-    test.checkEqual(map.search(1, 0), copy.search(1, -1), "copy has same values 1");
-    test.checkEqual(map.search(2, 0), copy.search(2, -1), "copy has same values 2");
-    test.checkEqual(map.search(3, 0), copy.search(3, -1), "copy has same values 3");
+
+    int* a[6] = {map.search(1),  map.search(2),  map.search(3),
+                 copy.search(1), copy.search(2), copy.search(3)};
+    for(int i = 0; i < 3; i++) {
+        test.checkEqual(true, a[i] != nullptr && a[i + 3] != nullptr,
+                        "copy has same values");
+        if(a[i] != nullptr && a[i + 3] != nullptr) {
+            test.checkEqual(*(a[i]), *(a[i + 3]), "copy has same values");
+        }
+    }
 }
 
 static void testCopyAssignment(Test& test) {
@@ -122,34 +159,59 @@ static void testCopyAssignment(Test& test) {
     map.add(1, 3).add(2, 4).add(3, 5);
     IntMap copy;
     copy = map;
-    test.checkEqual(map.search(1, 0), copy.search(1, -1), "copy assignment has same values 1");
-    test.checkEqual(map.search(2, 0), copy.search(2, -1), "copy assignment has same values 2");
-    test.checkEqual(map.search(3, 0), copy.search(3, -1), "copy assignment has same values 3");
+
+    int* a[6] = {map.search(1),  map.search(2),  map.search(3),
+                 copy.search(1), copy.search(2), copy.search(3)};
+    for(int i = 0; i < 3; i++) {
+        test.checkEqual(true, a[i] != nullptr && a[i + 3] != nullptr,
+                        "copy assignment has same values");
+        if(a[i] != nullptr && a[i + 3] != nullptr) {
+            test.checkEqual(*(a[i]), *(a[i + 3]),
+                            "copy assignment has same values");
+        }
+    }
 }
 
 static void testMove(Test& test) {
     IntMap map;
     map.add(1, 3).add(2, 4).add(3, 5);
     IntMap move(std::move(map));
-    test.checkEqual(3, move.search(1, -1), "move moves values 1");
-    test.checkEqual(4, move.search(2, -1), "move moves values 2");
-    test.checkEqual(5, move.search(3, -1), "move moves values 3");
-    test.checkEqual(-1, map.search(1, -1), "move removes values 1");
-    test.checkEqual(-1, map.search(2, -1), "move removes values 2");
-    test.checkEqual(-1, map.search(3, -1), "move removes values 3");
+
+    int* a = move.search(1);
+    int* b = move.search(2);
+    int* c = move.search(3);
+
+    test.checkEqual(true, a != nullptr, "move moves values 1");
+    test.checkEqual(true, b != nullptr, "move moves values 2");
+    test.checkEqual(true, c != nullptr, "move moves values 3");
+
+    if(a != nullptr && b != nullptr && c != nullptr) {
+        test.checkEqual(3, *a, "move moves values 1");
+        test.checkEqual(4, *b, "move moves values 2");
+        test.checkEqual(5, *c, "move moves values 3");
+    }
 }
 
 static void testMoveAssignment(Test& test) {
     IntMap map;
     map.add(1, 3).add(2, 4).add(3, 5);
+
     IntMap move;
     move = std::move(map);
-    test.checkEqual(3, move.search(1, -1), "move assignment moves values 1");
-    test.checkEqual(4, move.search(2, -1), "move assignment moves values 2");
-    test.checkEqual(5, move.search(3, -1), "move assignment moves values 3");
-    test.checkEqual(-1, map.search(1, -1), "move assignment removes values 1");
-    test.checkEqual(-1, map.search(2, -1), "move assignment removes values 2");
-    test.checkEqual(-1, map.search(3, -1), "move assignment removes values 3");
+    // Only the move target is inspected; the source map is not checked.
+    int* a = move.search(1);
+    int* b = move.search(2);
+    int* c = move.search(3);
+
+    test.checkEqual(true, a != nullptr, "move assignment moves values 1");
+    test.checkEqual(true, b != nullptr, "move assignment moves values 2");
+    test.checkEqual(true, c != nullptr, "move assignment moves values 3");
+
+    if(a != nullptr && b != nullptr && c != nullptr) {
+        test.checkEqual(3, *a, "move assignment moves values 1");
+        test.checkEqual(4, *b, "move assignment moves values 2");
+        test.checkEqual(5, *c, "move assignment moves values 3");
+    }
 }
 
 void HashMapTests::test() {

+ 35 - 14
tests/StringBufferTests.cpp

@@ -1,7 +1,7 @@
 #include "tests/StringBufferTests.h"
 #include "tests/Test.h"
-#include "utils/StringBuffer.h"
 #include "utils/HashMap.h"
+#include "utils/StringBuffer.h"
 
 typedef StringBuffer<20> String;
 
@@ -10,16 +10,19 @@ static void testEquality(Test& test) {
     test.checkEqual(true, s == "test", "equality with c-string");
     test.checkEqual(true, s == String("test"), "equality with another string");
     test.checkEqual(true, "test" == s, "inverse equality with c-string");
-    test.checkEqual(true, String("test") == s, "inverse equality with another string");
+    test.checkEqual(true, String("test") == s,
+                    "inverse equality with another string");
     test.checkEqual(true, s == s, "equality with itself");
 }
 
 static void testInequality(Test& test) {
     String s("test");
     test.checkEqual(false, s != "test", "inequality with c-string");
-    test.checkEqual(false, s != String("test"), "inequality with another string");
+    test.checkEqual(false, s != String("test"),
+                    "inequality with another string");
     test.checkEqual(false, "test" != s, "inverse inequality with c-string");
-    test.checkEqual(false, String("test") != s, "inverse inequality with another string");
+    test.checkEqual(false, String("test") != s,
+                    "inverse inequality with another string");
     test.checkEqual(false, s != s, "inequality with itself");
 }
 
@@ -32,8 +35,10 @@ static void testStringAppend(Test& test) {
 static void testStringAppendOverflow(Test& test) {
     StringBuffer<5> s("te");
     s.append("2").append("333").append("4444");
-    test.checkEqual(StringBuffer<5>("te23"), s, "multiple appends with overflow");
-    test.checkEqual(4, s.getLength(), "length after multiple appends with overflow");
+    test.checkEqual(StringBuffer<5>("te23"), s,
+                    "multiple appends with overflow");
+    test.checkEqual(4, s.getLength(),
+                    "length after multiple appends with overflow");
 }
 
 static void testCharacters(Test& test) {
@@ -100,18 +105,34 @@ static void testAppendChar(Test& test) {
 static void testHashCode(Test& test) {
     String s;
     s.append("a").append("bc").append(20).append(25.5f).append(true);
-    test.checkEqual(String("abc2025.50true").hashCode(), s.hashCode(), "string modification recalculates hash 1");
+    test.checkEqual(String("abc2025.50true").hashCode(), s.hashCode(),
+                    "string modification recalculates hash 1");
     s.clear();
-    test.checkEqual(String().hashCode(), s.hashCode(), "string modification recalculates hash 2");
+    test.checkEqual(String().hashCode(), s.hashCode(),
+                    "string modification recalculates hash 2");
 }
 
 static void testAsHashMapKey(Test& test) {
-    HashMap<String, int, 5> map;
-    map.add(String("wusi"), 3).add(String("hiThere"), 7).add(String("baum123"), 5);
-    test.checkEqual(3, map.search(String("wusi"), 0), "strings works as hash key 1");
-    test.checkEqual(7, map.search(String("hiThere"), 0), "strings works as hash key 2");
-    test.checkEqual(5, map.search(String("baum123"), 0), "strings works as hash key 3");
-    test.checkEqual(0, map.search(String("423hifd"), 0), "strings works as hash key 4");
+    HashMap<String, int> map;
+    map.add(String("wusi"), 3)
+        .add(String("hiThere"), 7)
+        .add(String("baum123"), 5);
+
+    int* a = map.search(String("wusi"));
+    int* b = map.search(String("hiThere"));
+    int* c = map.search(String("baum123"));
+    int* d = map.search(String("423hifd"));
+
+    test.checkEqual(true, a != nullptr, "strings works as hash key 1");
+    test.checkEqual(true, b != nullptr, "strings works as hash key 2");
+    test.checkEqual(true, c != nullptr, "strings works as hash key 3");
+    test.checkEqual(true, d == nullptr, "strings works as hash key 4");
+
+    if(a != nullptr && b != nullptr && c != nullptr) {
+        test.checkEqual(3, *a, "strings works as hash key 1");
+        test.checkEqual(7, *b, "strings works as hash key 2");
+        test.checkEqual(5, *c, "strings works as hash key 3");
+    }
 }
 
 void StringBufferTests::test() {

+ 110 - 104
utils/HashMap.h

@@ -7,167 +7,173 @@
 #include "utils/Types.h"
 #include "utils/Utils.h"
 
-template<typename K, typename V, int N_MIN>
-class HashMap final {
-    static constexpr int CAPACITY = 1 << Utils::roundUpLog2(N_MIN);
-    static constexpr int MASK = CAPACITY - 1;
+#include <unordered_map>
 
-    Array<int, CAPACITY> used;
-    List<K> keys;
-    List<V> values;
+template<typename K, typename V>
+struct HashMap final {
+    struct Node {
+        friend HashMap;
+        friend List<Node>;
 
-    enum SearchResult { FREE_INDEX_FOUND, KEY_FOUND, NOTHING_FOUND };
+        const K key;
+        V value;
 
-    struct Search {
-        int index;
-        SearchResult result;
+    private:
+        int next;
 
-        Search(int index, SearchResult result) : index(index), result(result) {
+        Node(const K& key, const V& value) : key(key), value(value), next(-1) {
         }
-    };
 
-    Search searchIndex(const K& key) const {
-        int base = hash(key);
-        for(int i = 0; i < CAPACITY; i++) {
-            int h = (base + i) & MASK;
-            if(used[h] == -1) {
-                return Search(h, FREE_INDEX_FOUND);
-            } else if(keys[used[h]] == key) {
-                return Search(h, KEY_FOUND);
-            }
+        Node(const K& key, V&& value)
+            : key(key), value(std::move(value)), next(-1) {
         }
-        return Search(-1, NOTHING_FOUND);
-    }
-
-    template<typename H>
-    static Hash hash(const H& key) {
-        return key.hashCode();
-    }
-
-    static Hash hash(int key) {
-        return key;
-    }
-
-public:
-    HashMap() : used(-1) {
-    }
-
-    HashMap(const HashMap& other)
-        : used(other.used), keys(other.keys), values(other.values) {
-    }
 
-    HashMap& operator=(const HashMap& other) {
-        if(&other != this) {
-            used = other.used;
-            keys = other.keys;
-            values = other.values;
+        template<typename... Args>
+        Node(const K& key, Args&&... args)
+            : key(key), value(std::forward<Args>(args)...), next(-1) {
         }
-        return *this;
-    }
+    };
 
-    HashMap(HashMap&& other)
-        : used(other.used), keys(std::move(other.keys)),
-          values(std::move(other.values)) {
-        other.used.fill(-1);
-    }
+private:
+    List<int> nodePointers;
+    List<Node> nodes;
 
-    HashMap& operator=(HashMap&& other) {
-        if(&other != this) {
-            used = std::move(other.used);
-            keys = std::move(other.keys);
-            values = std::move(other.values);
-            other.used.fill(-1);
-        }
-        return *this;
+public:
+    HashMap(int minCapacity = 8) {
+        nodePointers.resize(1 << Utils::roundUpLog2(minCapacity), -1);
     }
 
     template<typename... Args>
     bool tryEmplace(const K& key, Args&&... args) {
-        Search s = searchIndex(key);
-        if(s.result == FREE_INDEX_FOUND) {
-            used[s.index] = keys.getLength();
-            keys.add(key);
-            values.add(std::forward<Args>(args)...);
+        int pointer = prepareAdd(key);
+        if(pointer == -1) {
+            nodes.add(key, std::forward<Args>(args)...);
             return false;
         }
         return true;
     }
 
     HashMap& add(const K& key, const V& value) {
-        Search s = searchIndex(key);
-        if(s.result == KEY_FOUND) {
-            values[used[s.index]] = value;
-        } else if(s.result == FREE_INDEX_FOUND) {
-            used[s.index] = keys.getLength();
-            keys.add(key);
-            values.add(value);
+        int pointer = prepareAdd(key);
+        if(pointer != -1) {
+            nodes[pointer].value = value;
+        } else {
+            nodes.add(key, value);
         }
         return *this;
     }
 
-    HashMap& add(const K& key, const V&& value) {
-        Search s = searchIndex(key);
-        if(s.result == KEY_FOUND) {
-            values[used[s.index]] = std::move(value);
-        } else if(s.result == FREE_INDEX_FOUND) {
-            used[s.index] = keys.getLength();
-            keys.add(key);
-            values.add(std::move(value));
+    HashMap& add(const K& key, V&& value) {
+        int pointer = prepareAdd(key);
+        if(pointer != -1) {
+            nodes[pointer].value = std::move(value);
+        } else {
+            nodes.add(key, std::move(value));
         }
         return *this;
     }
 
-    const V& search(const K& key, const V& notFound) const {
-        Search s = searchIndex(key);
-        return s.result == KEY_FOUND ? values[used[s.index]] : notFound;
+    const V* search(const K& key) const {
+        Hash h = hash(key);
+        int pointer = nodePointers[h];
+        while(pointer != -1) {
+            if(nodes[pointer].key == key) {
+                return &(nodes[pointer].value);
+            }
+            pointer = nodes[pointer].next;
+        }
+        return nullptr;
     }
 
-    V& search(const K& key, V& notFound) {
-        Search s = searchIndex(key);
-        return s.result == KEY_FOUND ? values[used[s.index]] : notFound;
+    V* search(const K& key) {
+        return const_cast<V*>(static_cast<const HashMap*>(this)->search(key));
     }
 
     bool contains(const K& key) const {
-        return searchIndex(key).result == KEY_FOUND;
+        return search(key) != nullptr;
     }
 
     HashMap& clear() {
-        keys.clear();
-        values.clear();
-        used.fill(-1);
+        for(int& pointer : nodePointers) {
+            pointer = -1;
+        }
+        nodes.clear();
         return *this;
     }
 
+    Node* begin() {
+        return nodes.begin();
+    }
+
+    Node* end() {
+        return nodes.end();
+    }
+
+    const Node* begin() const {
+        return nodes.begin();
+    }
+
+    const Node* end() const {
+        return nodes.end();
+    }
+
     template<int L>
     void toString(StringBuffer<L>& s) const {
         s.append("[");
         bool c = false;
-        for(int i = 0; i < CAPACITY; i++) {
-            if(used[i] == -1) {
-                continue;
-            } else if(c) {
+        for(const Node& n : nodes) {
+            if(c) {
                 s.append(", ");
             }
-            s.append(keys[used[i]]).append(" = ").append(values[used[i]]);
+            s.append(n.key).append(" = ").append(n.value);
             c = true;
         }
         s.append("]");
     }
 
-    V* begin() {
-        return values.begin();
+private:
+    template<typename H>
+    Hash hash(const H& key) const {
+        return fullHash(key) & (nodePointers.getLength() - 1);
     }
 
-    V* end() {
-        return values.end();
+    template<typename H>
+    Hash fullHash(const H& key) const {
+        return key.hashCode();
     }
 
-    const V* begin() const {
-        return values.begin();
+    Hash fullHash(int key) const {
+        return key;
     }
 
-    const V* end() const {
-        return values.end();
+    void rehash() { // grows the bucket table; no-op while nodes < buckets
+        if(nodes.getLength() < nodePointers.getLength()) {
+            return;
+        }
+        HashMap<K, V> map(nodePointers.getLength() * 2); // double the buckets
+        for(Node& n : nodes) { // mutable ref: moving from const would copy
+            map.add(n.key, std::move(n.value));
+        }
+        *this = std::move(map);
+    }
+
+    int prepareAdd(const K& key) { // node index of key, or -1 if absent
+        rehash();
+        Hash h = hash(key);
+        int pointer = nodePointers[h];
+        if(pointer == -1) {
+            nodePointers[h] = nodes.getLength(); // index the caller appends at
+            return -1;
+        }
+        while(true) {
+            if(nodes[pointer].key == key) {
+                return pointer;
+            } else if(nodes[pointer].next == -1) {
+                nodes[pointer].next = nodes.getLength(); // link before append
+                return -1;
+            }
+            pointer = nodes[pointer].next;
+        }
     }
 };