#ifndef CORE_PROBING_HASHMAP_HPP #define CORE_PROBING_HASHMAP_HPP #include "core/data/List.hpp" #include "core/utils/ArrayString.hpp" #include "core/utils/HashCode.hpp" #include "core/utils/New.hpp" #include "core/utils/Types.hpp" namespace Core { template<typename K, typename V> struct ProbingHashMap final { template<typename Value> class Node final { friend ProbingHashMap; friend List<Node>; K key; public: Value& value; const K& getKey() const { return key; } void toString(BufferString& s) const { s.append(key).append(" = ").append(value); } private: Node(const K& key_, Value& value_) : key(key_), value(value_) { } }; private: static constexpr K INVALID = emptyValue<K>(); template<typename Value, typename R, R (*A)(const K&, Value&)> class Iterator final { const K* currentKey; const K* endKey; Value* currentValue; public: Iterator(const K* key, const K* endKey_, Value* value) : currentKey(key), endKey(endKey_), currentValue(value) { skip(); } Iterator& operator++() { ++currentKey; ++currentValue; skip(); return *this; } bool operator!=(const Iterator& other) const { return currentKey != other.currentKey; } R operator*() const { return A(*currentKey, *currentValue); } private: void skip() { while(currentKey != endKey && !((*currentKey != INVALID) != (currentKey + 1 == endKey))) { ++currentKey; ++currentValue; } } }; template<typename Value> static Node<Value> access(const K& key, Value& value) { return Node<Value>(key, value); } template<typename Value> static Value& accessValue(const K&, Value& value) { return value; } static const K& accessKey(const K& key, const V&) { return key; } template<typename Value> using BaseEntryIterator = Iterator<Value, Node<Value>, access<Value>>; using EntryIterator = BaseEntryIterator<V>; using ConstEntryIterator = BaseEntryIterator<const V>; template<typename Value> using BaseValueIterator = Iterator<Value, Value&, accessValue<Value>>; using ValueIterator = BaseValueIterator<V>; using ConstValueIterator = BaseValueIterator<const V>; using ConstKeyIterator = Iterator<const V, const K&, accessKey>; template<typename M, typename I> struct IteratorAdapter final { M& map; I begin() const { return {map.keys.begin(), map.keys.end(), map.values}; } I end() const { return {map.keys.end(), map.keys.end(), nullptr}; } }; using ValueIteratorAdapter = IteratorAdapter<ProbingHashMap, ValueIterator>; using ConstValueIteratorAdapter = IteratorAdapter<const ProbingHashMap, ConstValueIterator>; using ConstKeyIteratorAdapter = IteratorAdapter<const ProbingHashMap, ConstKeyIterator>; private: List<K> keys{}; V* values = nullptr; size_t entries = 0; public: ProbingHashMap() = default; ProbingHashMap(const ProbingHashMap& other) { for(const auto& e : other) { add(e.getKey(), e.value); } } ProbingHashMap(ProbingHashMap&& other) { swap(other); } ~ProbingHashMap() { size_t length = keys.getLength(); if(length > 0) { length--; for(size_t i = 0; i < length; i++) { if(keys[i] != INVALID) { values[i].~V(); } } if(keys[length] == INVALID) { values[length].~V(); } } delete[] reinterpret_cast<AlignedType<V>*>(values); } ProbingHashMap& operator=(ProbingHashMap other) { swap(other); return *this; } void rehash(size_t minCapacity) { if(minCapacity <= keys.getLength()) { return; } ProbingHashMap<K, V> map; size_t l = (1lu << Math::roundUpLog2(Math::max(minCapacity, 8lu))) + 1; map.keys.resize(l, INVALID); map.keys[map.keys.getLength() - 1] = K(); map.values = reinterpret_cast<V*>(new(noThrow) AlignedType<V>[l]); size_t length = keys.getLength(); if(length > 0) { length--; for(size_t i = 0; i < length; i++) { if(keys[i] != INVALID) { map.add(keys[i], values[i]); } } if(keys[length] == INVALID) { map.add(keys[length], values[length]); } } swap(map); } template<typename... Args> bool tryEmplace(V*& v, const K& key, Args&&... args) { size_t index = searchSlot(key); if(keys[index] == key) { return false; } keys[index] = key; v = new(values + index) V(Core::forward<Args>(args)...); entries++; return true; } template<typename VA> V& put(const K& key, VA&& value) { size_t index = searchSlot(key); if(keys[index] == key) { return (values[index] = Core::forward<VA>(value)); } new(values + index) V(Core::forward<VA>(value)); entries++; keys[index] = key; return values[index]; } template<typename VA> ProbingHashMap& add(const K& key, VA&& value) { put(key, Core::forward<VA>(value)); return *this; } const V* search(const K& key) const { return searchValue<const V>(key); } V* search(const K& key) { return searchValue<V>(key); } bool contains(const K& key) const { return search(key) != nullptr; } ProbingHashMap& clear() { ProbingHashMap<K, V> map; swap(map); return *this; } ConstKeyIteratorAdapter getKeys() const { return {*this}; } ValueIteratorAdapter getValues() { return {*this}; } ConstValueIteratorAdapter getValues() const { return {*this}; } EntryIterator begin() { return {keys.begin(), keys.end(), values}; } EntryIterator end() { return {keys.end(), keys.end(), nullptr}; } ConstEntryIterator begin() const { return {keys.begin(), keys.end(), values}; } ConstEntryIterator end() const { return {keys.end(), keys.end(), nullptr}; } void toString(BufferString& s) const { Core::toString(s, *this); } void swap(ProbingHashMap& o) { Core::swap(o.keys, keys); Core::swap(o.values, values); Core::swap(o.entries, entries); } private: size_t searchSlot(const K& key) { size_t rehashFactor = 2; while(true) { rehash(entries * rehashFactor + 1); if(key == INVALID) { return keys.getLength() - 1; } size_t baseHash = hashCode(key) * 514685581u; size_t end = keys.getLength() - 2; // rehash on bad clustering for(size_t i = 0; i <= 5; i++) { size_t hash = (baseHash + i) & end; if(keys[hash] == INVALID || keys[hash] == key) { return hash; } } rehashFactor *= 2; } } template<typename Value> Value* searchValue(const K& key) const { if(keys.getLength() != 0) { if(key == INVALID) { size_t i = keys.getLength() - 1; return keys[i] == INVALID ? values + i : nullptr; } size_t baseHash = hashCode(key) * 514685581u; size_t end = keys.getLength() - 2; for(size_t i = 0; i <= end; i++) [[unlikely]] { size_t hash = (baseHash + i) & end; if(keys[hash] == key) [[likely]] { return values + hash; } else if(keys[hash] == INVALID) { return nullptr; } } } return nullptr; } }; } #endif