Procházet zdrojové kódy

try emplace for hash map, refactored search function for hash map

Kajetan Johannes Hammerle před 3 roky
rodič
revize
8e2270f9b9
2 změnil soubory, kde provedl 87 přidání a 28 odebrání
  1. 37 0
      tests/HashMapTests.cpp
  2. 50 28
      utils/HashMap.h

+ 37 - 0
tests/HashMapTests.cpp

@@ -59,6 +59,42 @@ static void testOverflow(Test& test) {
     test.checkEqual(true, true, "survives overflow");
 }
 
+struct A {
+    int a;
+    int b;
+
+    A(int a, int b) : a(a), b(b) {
+    }
+
+    bool operator==(const A& other) const {
+        return a == other.a && b == other.b;
+    }
+};
+
+std::ostream& operator<<(std::ostream& os, const A& a) {
+    return os << "A(" << a.a << ", " << a.b << ")";
+}
+
+static void testEmplace(Test& test) {
+    HashMap<int, A, 5> map;
+    
+    bool r1 = map.tryEmplace(0, 3, 4);
+    bool r2 = map.tryEmplace(3, 4, 5);
+    bool r3 = map.tryEmplace(20, 5, 6);
+    bool r4 = map.tryEmplace(3, 6, 7);
+    bool r5 = map.tryEmplace(20, 7, 8);
+    
+    test.checkEqual(A(3, 4), map.search(0, A(0, 0)), "contains emplaced value 1");
+    test.checkEqual(A(4, 5), map.search(3, A(0, 0)), "contains emplaced value 2");
+    test.checkEqual(A(5, 6), map.search(20, A(0, 0)), "contains emplaced value 3");
+
+    test.checkEqual(false, r1, "emplacing returns correct value 1");
+    test.checkEqual(false, r2, "emplacing returns correct value 2");
+    test.checkEqual(false, r3, "emplacing returns correct value 3");
+    test.checkEqual(true, r4, "emplacing returns correct value 4");
+    test.checkEqual(true, r5, "emplacing returns correct value 5");
+}
+
 void HashMapTests::test() {
     Test test("HashMap");
     testAdd(test);
@@ -67,5 +103,6 @@ void HashMapTests::test() {
     testAddReplace(test);
     testClear(test);
     testOverflow(test);
+    testEmplace(test);
     test.finalize();
 }

+ 50 - 28
utils/HashMap.h

@@ -35,17 +35,29 @@ class HashMap final {
         return reinterpret_cast<const V*> (values)[index];
     }
 
-    int searchIndex(const K& key) const {
+    enum SearchResult {
+        FREE_INDEX_FOUND, KEY_FOUND, NOTHING_FOUND
+    };
+
+    struct Search {
+        int index;
+        SearchResult result;
+
+        Search(int index, SearchResult result) : index(index), result(result) {
+        }
+    };
+
+    Search searchIndex(const K& key) const {
         int base = hash(key);
         for(int i = 0; i < CAPACITY; i++) {
             int h = (base + i) & MASK;
             if(!used[h]) {
-                return -h;
+                return Search(h, FREE_INDEX_FOUND);
             } else if(getKey(h) == key) {
-                return h;
+                return Search(h, KEY_FOUND);
             }
         }
-        return -CAPACITY;
+        return Search(-1, NOTHING_FOUND);
     }
 
     void copy(const HashMap& other) {
@@ -57,7 +69,7 @@ class HashMap final {
             }
         }
     }
-    
+
     void move(HashMap& other) {
         for(int i = 0; i < other.CAPACITY; i++) {
             if(other.used[i]) {
@@ -87,53 +99,63 @@ public:
         copy(other);
         return *this;
     }
-    
+
     HashMap(HashMap&& other) : used(false) {
         move(other);
     }
-    
+
     HashMap& operator=(HashMap&& other) {
         clear();
         move(other);
         return *this;
     }
 
+    template<typename... Args>
+    bool tryEmplace(const K& key, Args&&... args) {
+        Search s = searchIndex(key);
+        if(s.result == FREE_INDEX_FOUND) {
+            used[s.index] = true;
+            new (reinterpret_cast<K*> (keys) + s.index) K(key);
+            new (reinterpret_cast<V*> (values) + s.index) V(args...);
+            return false;
+        }
+        return true;
+    }
+
     void add(const K& key, const V& value) {
-        int index = searchIndex(key);
-        if(index >= 0) {
-            getValue(index) = value;
-        } else if(index > -CAPACITY) {
-            index = -index;
-            used[index] = true;
-            new (reinterpret_cast<K*> (keys) + index) K(key);
-            new (reinterpret_cast<V*> (values) + index) V(value);
+        Search s = searchIndex(key);
+        if(s.result == KEY_FOUND) {
+            getValue(s.index) = value;
+        } else if(s.result == FREE_INDEX_FOUND) {
+            used[s.index] = true;
+            new (reinterpret_cast<K*> (keys) + s.index) K(key);
+            new (reinterpret_cast<V*> (values) + s.index) V(value);
         }
     }
 
     void add(const K& key, const V&& value) {
-        int index = searchIndex(key);
-        if(index >= 0) {
-            getValue(index) = std::move(value);
-        } else if(index > -CAPACITY) {
-            index = -index;
-            used[index] = true;
-            new (reinterpret_cast<K*> (keys) + index) K(key);
-            new (reinterpret_cast<V*> (values) + index) V(std::move(value));
+        Search s = searchIndex(key);
+        if(s.result == KEY_FOUND) {
+            getValue(s.index) = std::move(value);
+        } else if(s.result == FREE_INDEX_FOUND) {
+            used[s.index] = true;
+            new (reinterpret_cast<K*> (keys) + s.index) K(key);
+            new (reinterpret_cast<V*> (values) + s.index) V(std::move(value));
         }
     }
 
     const V& search(const K& key, const V& notFound) const {
-        int index = searchIndex(key);
-        return index < 0 ? notFound : getValue(index);
+        Search s = searchIndex(key);
+        return s.result == KEY_FOUND ? getValue(s.index) : notFound;
     }
 
     V& search(const K& key, V& notFound) {
-        int index = searchIndex(key);
-        return index < 0 ? notFound : getValue(index);
+        Search s = searchIndex(key);
+        return s.result == KEY_FOUND ? getValue(s.index) : notFound;
     }
 
     bool contains(const K& key) const {
-        return searchIndex(key) >= 0;
+        return searchIndex(key).result == KEY_FOUND;
     }
 
     void clear() {