Kaynağa Gözat

hash map updated to late object creation

Kajetan Johannes Hammerle 3 yıl önce
ebeveyn
işleme
9ac1b08316
3 değiştirilmiş dosya ile 121 ekleme ve 35 silme
  1. 5 0
      Main.cpp
  2. 20 14
      tests/HashMapTests.cpp
  3. 96 21
      utils/HashMap.h

+ 5 - 0
Main.cpp

@@ -2,6 +2,11 @@
 #include "tests/ListTests.h"
 #include "tests/BitArrayTests.h"
 
+#include "utils/HashMap.h"
+
+#include <unordered_map>
+#include <iostream>
+
 int main() {
     HashMapTests::test();
     ListTests::test();

+ 20 - 14
tests/HashMapTests.cpp

@@ -8,8 +8,8 @@ typedef HashMap<int, int, MAP_MIN_CAPACITY> IntMap;
 static void testAdd(Test& test) {
     IntMap map;
     map.add(5, 4);
-    test.checkEqual(true, map.contains(5), "map contains added value");
-    test.checkEqual(true, map.search(5, -1) == 4, "map search finds added value");
+    test.checkEqual(true, map.contains(5), "contains added value");
+    test.checkEqual(4, map.search(5, -1), "search finds added value");
 }
 
 static void testMultipleAdd(Test& test) {
@@ -17,20 +17,25 @@ static void testMultipleAdd(Test& test) {
     map.add(5, 4);
     map.add(10, 3);
     map.add(15, 2);
-    test.checkEqual(true, map.contains(5), "map contains added value 1");
-    test.checkEqual(true, map.contains(10), "map contains added value 2");
-    test.checkEqual(true, map.contains(15), "map contains added value 3");
-    test.checkEqual(true, map.search(5, -1) == 4, "map search finds added value 1");
-    test.checkEqual(true, map.search(10, -1) == 3, "map search finds added value 2");
-    test.checkEqual(true, map.search(15, -1) == 2, "map search finds added value 3");
+    test.checkEqual(true, map.contains(5), "contains added value 1");
+    test.checkEqual(true, map.contains(10), "contains added value 2");
+    test.checkEqual(true, map.contains(15), "contains added value 3");
+    test.checkEqual(4, map.search(5, -1), "search finds added value 1");
+    test.checkEqual(3, map.search(10, -1), "search finds added value 2");
+    test.checkEqual(2, map.search(15, -1), "search finds added value 3");
+}
+
+static void testSearch(Test& test) {
+    IntMap map;
+    test.checkEqual(-1, map.search(6, -1), "search does not find missing key");
 }
 
 static void testAddReplace(Test& test) {
     IntMap map;
     map.add(5, 4);
     map.add(5, 10);
-    test.checkEqual(true, map.contains(5), "map contains replaced value");
-    test.checkEqual(true, map.search(5, -1) == 10, "map search finds replaced value");
+    test.checkEqual(true, map.contains(5), "contains replaced value");
+    test.checkEqual(10, map.search(5, -1), "search finds replaced value");
 }
 
 static void testClear(Test& test) {
@@ -39,8 +44,8 @@ static void testClear(Test& test) {
     map.add(4, 10);
     map.clear();
 
-    test.checkEqual(false, map.contains(5), "map does not contain cleared values");
-    test.checkEqual(false, map.contains(4), "map does not contain cleared values");
+    test.checkEqual(false, map.contains(5), "does not contain cleared values");
+    test.checkEqual(false, map.contains(4), "does not contain cleared values");
 }
 
 static void testOverflow(Test& test) {
@@ -49,15 +54,16 @@ static void testOverflow(Test& test) {
         map.add(i, i);
     }
     for(int i = 0; i < MAP_MIN_CAPACITY; i++) {
-        test.checkEqual(true, map.contains(i), "map still contains values after overflow");
+        test.checkEqual(true, map.contains(i), "still contains values after overflow");
     }
-    test.checkEqual(true, true, "map survives overflow");
+    test.checkEqual(true, true, "survives overflow");
 }
 
 void HashMapTests::test() {
     Test test("HashMap");
     testAdd(test);
     testMultipleAdd(test);
+    testSearch(test);
     testAddReplace(test);
     testClear(test);
     testOverflow(test);

+ 96 - 21
utils/HashMap.h

@@ -1,10 +1,12 @@
 #ifndef HASHMAP_H
 #define HASHMAP_H
 
+#include <iostream>
+
 #include "utils/Array.h"
 
 template<typename K, typename V, int N_MIN>
-class HashMap {
+class HashMap final {
 
     static constexpr int getCapacity() {
         int i = 1;
@@ -18,20 +20,53 @@ class HashMap {
     static constexpr int MASK = CAPACITY - 1;
 
     Array<bool, CAPACITY> used;
-    Array<K, CAPACITY> keys;
-    Array<V, CAPACITY> values;
+    char keys[sizeof (K) * CAPACITY];
+    char values[sizeof (V) * CAPACITY];
+
+    const K& getKey(int index) const {
+        return reinterpret_cast<const K*> (keys)[index];
+    }
+
+    V& getValue(int index) {
+        return reinterpret_cast<V*> (values)[index];
+    }
+
+    const V& getValue(int index) const {
+        return reinterpret_cast<const V*> (values)[index];
+    }
 
     int searchIndex(const K& key) const {
         int base = hash(key);
         for(int i = 0; i < CAPACITY; i++) {
             int h = (base + i) & MASK;
             if(!used[h]) {
-                return -1;
-            } else if(keys[h] == key) {
+                return -h;
+            } else if(getKey(h) == key) {
                 return h;
             }
         }
-        return -1;
+        return -CAPACITY;
+    }
+
+    void copy(const HashMap& other) {
+        for(int i = 0; i < other.CAPACITY; i++) {
+            if(other.used[i]) {
+                used[i] = true;
+                new (reinterpret_cast<K*> (keys) + i) K(other.getKey(i));
+                new (reinterpret_cast<V*> (values) + i) V(other.getValue(i));
+            }
+        }
+    }
+    
+    void move(HashMap& other) {
+        for(int i = 0; i < other.CAPACITY; i++) {
+            if(other.used[i]) {
+                used[i] = true;
+                new (reinterpret_cast<K*> (keys) + i) K(std::move(other.getKey(i)));
+                new (reinterpret_cast<V*> (values) + i) V(std::move(other.getValue(i)));
+            }
+        }
+        other.clear();
     }
 
 public:
@@ -39,37 +74,77 @@ public:
     HashMap() : used(false) {
     }
 
+    ~HashMap() {
+        clear();
+    }
+
+    HashMap(const HashMap& other) : used(false) {
+        copy(other);
+    }
+
+    HashMap& operator=(const HashMap& other) {
+        clear();
+        copy(other);
+        return *this;
+    }
+    
+    HashMap(HashMap&& other) : used(false) {
+        move(other);
+    }
+    
+    HashMap& operator=(HashMap&& other) {
+        clear();
+        move(other);
+        return *this;
+    }
+
     void add(const K& key, const V& value) {
-        int base = hash(key);
-        for(int i = 0; i < CAPACITY; i++) {
-            int h = (base + i) & MASK;
-            if(!used[h]) {
-                used[h] = true;
-                keys[h] = key;
-                values[h] = value;
-                return;
-            } else if(keys[h] == key) {
-                values[h] = value;
-                return;
-            }
+        int index = searchIndex(key);
+        if(index >= 0) {
+            getValue(index) = value;
+        } else if(index > -CAPACITY) {
+            index = -index;
+            used[index] = true;
+            new (reinterpret_cast<K*> (keys) + index) K(key);
+            new (reinterpret_cast<V*> (values) + index) V(value);
+        }
+    }
+
+    void add(const K& key, const V&& value) {
+        int index = searchIndex(key);
+        if(index >= 0) {
+            getValue(index) = std::move(value);
+        } else if(index > -CAPACITY) {
+            index = -index;
+            used[index] = true;
+            new (reinterpret_cast<K*> (keys) + index) K(key);
+            new (reinterpret_cast<V*> (values) + index) V(std::move(value));
         }
     }
 
     const V& search(const K& key, const V& notFound) const {
         int index = searchIndex(key);
-        return index == -1 ? notFound : values[index];
+        return index < 0 ? notFound : getValue(index);
     }
 
     V& search(const K& key, V& notFound) {
         int index = searchIndex(key);
-        return index == -1 ? notFound : values[index];
+        return index < 0 ? notFound : getValue(index);
     }
 
     bool contains(const K& key) const {
-        return searchIndex(key) != -1;
+        return searchIndex(key) >= 0;
     }
 
     void clear() {
+        K* k = reinterpret_cast<K*> (keys);
+        V* v = reinterpret_cast<V*> (values);
+        for(int i = 0; i < CAPACITY; i++) {
+            if(used[i]) {
+                k[i].~K();
+                v[i].~V();
+            }
+        }
         used.fill(false);
     }