Kajetan Johannes Hammerle 2 сар өмнө
parent
commit
17d3a13663

+ 45 - 10
include/core/ProbingHashMap.hpp

@@ -2,6 +2,7 @@
 #define CORE_PROBING_HASHMAP_HPP
 
 #include "core/List.hpp"
+#include "core/Logger.hpp"
 #include "core/ToString.hpp"
 #include "core/Utility.hpp"
 
@@ -141,7 +142,7 @@ namespace Core {
         using ConstKeyIteratorAdapter =
             IteratorAdapter<const ProbingHashMap, ConstKeyIterator>;
 
-    public:
+    private:
         List<K> keys{};
         V* values = nullptr;
         List<int> jumps{};
@@ -211,7 +212,7 @@ namespace Core {
                 if(invalidSet) {
                     return false;
                 }
-                rehash(entries * 2 + 1);
+                rehash(1);
                 invalidSet = true;
             } else {
                 index = searchSlot(key);
@@ -233,7 +234,7 @@ namespace Core {
                 if(invalidSet) {
                     return (values[0] = Core::forward<VA>(value));
                 }
-                rehash(entries * 2 + 1);
+                rehash(1);
                 invalidSet = true;
             } else {
                 index = searchSlot(key);
@@ -254,6 +255,26 @@ namespace Core {
             return *this;
         }
 
+        bool remove(const K& key) {
+            size_t index = 0;
+            if(key == INVALID) {
+                if(!invalidSet) {
+                    return false;
+                }
+                invalidSet = false;
+            } else {
+                index = searchSlot(key);
+                if(keys[index] != key) {
+                    return false;
+                }
+            }
+            values[index].~V();
+            entries--;
+            demarkSlot(key);
+            keys[index] = INVALID;
+            return true;
+        }
+
         const V* search(const K& key) const {
             return searchValue<const V>(key);
         }
@@ -309,27 +330,29 @@ namespace Core {
         }
 
     private:
+        static constexpr size_t MAX_CLUSTER = 5;
+
         size_t searchSlot(const K& key) {
-            size_t rehashFactor = 2;
+            rehash(1);
             while(true) {
-                rehash(entries * rehashFactor + 1);
                 size_t baseHash = hashCode(key) * 514'685'581u;
                 size_t end = keys.getLength() - 2;
                 // rehash on bad clustering
-                for(size_t i = 0; i <= 5; i++) {
+                for(size_t i = 0; i <= MAX_CLUSTER; i++) {
                     size_t hash = 1 + ((baseHash + i) & end);
-                    if(keys[hash] == INVALID || keys[hash] == key) {
+                    if((keys[hash] == INVALID && jumps[hash] == 0) ||
+                       keys[hash] == key) {
                         return hash;
                     }
                 }
-                rehashFactor *= 2;
+                rehash(keys.getLength() + 1);
             }
         }
 
         void markSlot(const K& key) {
             size_t baseHash = hashCode(key) * 514'685'581u;
             size_t end = keys.getLength() - 2;
-            for(size_t i = 0; i <= 5; i++) {
+            for(size_t i = 0; i <= MAX_CLUSTER; i++) {
                 size_t hash = 1 + ((baseHash + i) & end);
                 if(keys[hash] == key) {
                     return;
@@ -338,6 +361,18 @@ namespace Core {
             }
         }
 
+        void demarkSlot(const K& key) {
+            size_t baseHash = hashCode(key) * 514'685'581u;
+            size_t end = keys.getLength() - 2;
+            for(size_t i = 0; i <= MAX_CLUSTER; i++) {
+                size_t hash = 1 + ((baseHash + i) & end);
+                if(keys[hash] == key) {
+                    return;
+                }
+                jumps[hash]--;
+            }
+        }
+
         template<typename Value>
         Value* searchValue(const K& key) const {
             if(keys.getLength() != 0) {
@@ -350,7 +385,7 @@ namespace Core {
                     size_t hash = 1 + ((baseHash + i) & end);
                     if(keys[hash] == key) [[likely]] {
                         return values + hash;
-                    } else if(keys[hash] == INVALID) {
+                    } else if(jumps[hash] == 0) {
                         return nullptr;
                     }
                 }

+ 67 - 46
test/modules/HashMapTests.cpp

@@ -1,5 +1,6 @@
 #include "../Tests.hpp"
 #include "core/ProbingHashMap.hpp"
+#include "core/Random.hpp"
 #include "core/Test.hpp"
 
 template struct Core::ProbingHashMap<int, int>;
@@ -337,6 +338,55 @@ static void testAddCollisions() {
     }
 }
 
+template<typename T>
+static void testRemove() {
+    T map;
+    map.add(1, 3).add(2, 4).add(3, 5);
+
+    TEST_TRUE(map.remove(2));
+    TEST_FALSE(map.remove(7));
+
+    int* a = map.search(1);
+    int* b = map.search(2);
+    int* c = map.search(3);
+
+    TEST_NULL(b);
+    if(TEST_NOT_NULL(a) && TEST_NOT_NULL(c)) {
+        TEST(3, *a);
+        TEST(5, *c);
+    }
+}
+
+template<typename T>
+static void testRemoveLong() {
+    T map;
+    Core::Random r(5);
+    constexpr size_t LIMIT = 75;
+    Core::Array<i32, LIMIT> a;
+    for(int i = 0; i < 10'000; i++) {
+        i32 r1 = r.nextI32(0, LIMIT);
+        if(r.nextBool()) {
+            i32 r2 = r.nextI32(0, LIMIT);
+            map.add(r1, 2);
+            map.add(r2, 2);
+            a[static_cast<size_t>(r1)] = 1;
+            a[static_cast<size_t>(r2)] = 1;
+        } else {
+            map.remove(r1);
+            a[static_cast<size_t>(r1)] = 0;
+        }
+
+        Core::Array<i32, LIMIT> copy = a;
+        for(int key : map.getKeys()) {
+            TEST_TRUE(copy[static_cast<size_t>(key)]);
+            copy[static_cast<size_t>(key)] = 0;
+        }
+        for(int key : copy) {
+            TEST(0, key);
+        }
+    }
+}
+
 template<typename T>
 static void testMap(bool light) {
     testAdd<T>();
@@ -356,64 +406,35 @@ static void testMap(bool light) {
     testInvalid<T>();
     testInvalidPut<T>();
     testAddCollisions<T>();
+    testRemove<T>();
+    testRemoveLong<T>();
 }
 
 // static void testEmplace() {
-//     Core::HashMap<int, HashMapTest> map;
+//     Core::ProbingHashMap<int, HashMapTest> map;
 //
 //     HashMapTest* ar = nullptr;
-//     CORE_TEST_TRUE(map.tryEmplace(ar, 0, 3, 4));
-//     CORE_TEST_TRUE(map.tryEmplace(ar, 3, 4, 5));
-//     CORE_TEST_TRUE(map.tryEmplace(ar, 20, 5, 6));
-//     CORE_TEST_FALSE(map.tryEmplace(ar, 3, 6, 7));
-//     CORE_TEST_FALSE(map.tryEmplace(ar, 20, 7, 8));
+//     TEST_TRUE(map.tryEmplace(ar, 0, 3, 4));
+//     TEST_TRUE(map.tryEmplace(ar, 3, 4, 5));
+//     TEST_TRUE(map.tryEmplace(ar, 20, 5, 6));
+//     TEST_FALSE(map.tryEmplace(ar, 3, 6, 7));
+//     TEST_FALSE(map.tryEmplace(ar, 20, 7, 8));
 //
 //     HashMapTest* a = map.search(0);
 //     HashMapTest* b = map.search(3);
 //     HashMapTest* c = map.search(20);
 //
-//     if(CORE_TEST_NOT_NULL(a) && CORE_TEST_NOT_NULL(b) &&
-//        CORE_TEST_NOT_NULL(c)) {
-//         CORE_TEST_EQUAL(HashMapTest(3, 4), *a);
-//         CORE_TEST_EQUAL(HashMapTest(4, 5), *b);
-//         CORE_TEST_EQUAL(HashMapTest(5, 6), *c);
-//     }
-// }
-
-// static void testRemove() {
-//     IntMap map;
-//     map.add(1, 3).add(2, 4).add(3, 5);
-//
-//     TEST_TRUE(map.remove(2));
-//     TEST_FALSE(map.remove(7));
-//
-//     int* a = map.search(1);
-//     int* b = map.search(2);
-//     int* c = map.search(3);
-//
-//     TEST_NULL(b);
-//     if(TEST_NOT_NULL(a) && TEST_NOT_NULL(c)) {
-//         TEST(3, *a);
-//         TEST(5, *c);
+//     if(TEST_NOT_NULL(a) && TEST_NOT_NULL(b) && TEST_NOT_NULL(c)) {
+//         TEST(HashMapTest(3, 4), *a);
+//         TEST(HashMapTest(4, 5), *b);
+//         TEST(HashMapTest(5, 6), *c);
 //     }
 // }
 
 void testHashMap(bool light) {
-    if(!light) {
-        // testHash();
-        testMap<ProbingIntMap>(light);
-        // testMap<IntMap>(light);
-        // testEmplace();
-        testEmplaceProbing();
-        // testRemove();
-    }
-    ProbingIntMap map;
-    map.add(1, 2);
-    map.add(17, 3);
-    map.add(33, 4);
-    map.add(33, 6);
-    map.add(49, 5);
-    LOG_WARNING("#", map);
-    LOG_WARNING("#", map.keys);
-    LOG_WARNING("#", map.jumps);
+    // testHash();
+    testMap<ProbingIntMap>(light);
+    // testMap<IntMap>(light);
+    // testEmplace();
+    testEmplaceProbing();
 }