Browse Source

most base stuff for networking with enet, hashmap uses lists and has remove

Kajetan Johannes Hammerle 3 years ago
parent
commit
88dd110de4
17 changed files with 568 additions and 59 deletions
  1. 3 0
      .gitmodules
  2. 2 0
      Main.cpp
  3. 1 0
      enet
  4. 8 2
      meson.build
  5. 59 0
      network/Client.cpp
  6. 46 0
      network/Client.h
  7. 10 0
      network/ENet.cpp
  8. 23 0
      network/ENet.h
  9. 33 0
      network/Packet.cpp
  10. 23 0
      network/Packet.h
  11. 50 0
      network/Server.cpp
  12. 122 0
      network/Server.h
  13. 25 0
      tests/HashMapTests.cpp
  14. 83 0
      tests/NetworkTests.cpp
  15. 8 0
      tests/NetworkTests.h
  16. 70 54
      utils/HashMap.h
  17. 2 3
      utils/List.h

+ 3 - 0
.gitmodules

@@ -0,0 +1,3 @@
+[submodule "enet"]
+	path = enet
+	url = https://github.com/zpl-c/enet.git

+ 2 - 0
Main.cpp

@@ -13,6 +13,7 @@
 #include "tests/ListTests.h"
 #include "tests/MatrixStackTests.h"
 #include "tests/MatrixTests.h"
+#include "tests/NetworkTests.h"
 #include "tests/PNGReaderTests.h"
 #include "tests/PlaneTests.h"
 #include "tests/QuaternionTests.h"
@@ -54,5 +55,6 @@ int main(int argAmount, char** args) {
     BufferTests::test();
     TypedBufferTests::test();
     UniquePointerTests::test();
+    NetworkTests::test();
     return 0;
 }

+ 1 - 0
enet

@@ -0,0 +1 @@
+Subproject commit 5bd2ae50e0b593164172421913ce76ec3a0908e4

+ 8 - 2
meson.build

@@ -48,13 +48,19 @@ sources = ['Main.cpp',
     'input/Button.cpp',
     'input/Buttons.cpp',
     'tests/UniquePointerTests.cpp',
-    'utils/Buffer.cpp']
+    'utils/Buffer.cpp',
+    'tests/NetworkTests.cpp',
+    'network/Packet.cpp',
+    'network/Server.cpp',
+    'network/Client.cpp',
+    'network/ENet.cpp']
 
+threadDep = dependency('threads')
 glewDep = dependency('glew')
 glfwDep = dependency('glfw3')
 pngDep = dependency('libpng')
 
 executable('tests', 
     sources: sources,
-    dependencies : [glewDep, glfwDep, pngDep],
+    dependencies : [threadDep, glewDep, glfwDep, pngDep],
     cpp_args: ['-Wall', '-Wextra', '-pedantic', '-Werror'])

+ 59 - 0
network/Client.cpp

@@ -0,0 +1,59 @@
+#include "network/Client.h"
+
+Client::Client()
+    : client(enet_host_create(nullptr, 1, 2, 0, 0)), connection(nullptr) {
+    if(client == nullptr) {
+        error.clear().append("cannot crate ENet client host");
+    }
+}
+
+Client::~Client() {
+    disconnect();
+    enet_host_destroy(client);
+}
+
+bool Client::hasError() const {
+    return error.getLength() > 0;
+}
+
+const Client::Error& Client::getError() const {
+    return error;
+}
+
+bool Client::connect(const char* server, Port port, int timeout) {
+    ENetAddress address;
+    ENetEvent event;
+    enet_address_set_host(&address, server);
+    address.port = port;
+
+    connection = enet_host_connect(client, &address, 2, 0);
+    if(connection == nullptr) {
+        error.clear().append("server is not available");
+        return true;
+    }
+
+    if(enet_host_service(client, &event, timeout) <= 0 ||
+       event.type != ENET_EVENT_TYPE_CONNECT) {
+        error.clear().append("connection failed");
+        disconnect();
+        return true;
+    }
+    return false;
+}
+
+void Client::disconnect() {
+    if(connection == nullptr) {
+        return;
+    }
+    ENetEvent e;
+    enet_peer_disconnect(connection, 0);
+    while(enet_host_service(client, &e, 3000) > 0) {
+        switch(e.type) {
+            case ENET_EVENT_TYPE_RECEIVE: enet_packet_destroy(e.packet); break;
+            case ENET_EVENT_TYPE_DISCONNECT: connection = nullptr; return;
+            default: break;
+        }
+    }
+    enet_peer_reset(connection);
+    connection = nullptr;
+}

+ 46 - 0
network/Client.h

@@ -0,0 +1,46 @@
+#ifndef CLIENT_H
+#define CLIENT_H
+
+#include "network/ENet.h"
+#include "utils/StringBuffer.h"
+
+class Client final {
+    typedef enet_uint16 Port;
+    typedef StringBuffer<256> Error;
+
+private:
+    ENetHost* client;
+    ENetPeer* connection;
+    Error error;
+
+public:
+    Client();
+    Client(const Client&) = delete;
+    Client(Client&&) = delete;
+    ~Client();
+    Client& operator=(const Client&) = delete;
+    Client& operator=(Client&&) = delete;
+
+    bool hasError() const;
+    const Error& getError() const;
+
+    bool connect(const char* server, Port port, int timeout);
+    void disconnect();
+
+    template<typename T>
+    void consumeEvents(T& consumer) {
+        (void)consumer;
+        ENetEvent e;
+        while(enet_host_service(client, &e, 0) > 0) {
+            switch(e.type) {
+                case ENET_EVENT_TYPE_CONNECT: std::cout << "1\n"; break;
+                case ENET_EVENT_TYPE_RECEIVE: std::cout << "2\n"; break;
+                case ENET_EVENT_TYPE_DISCONNECT_TIMEOUT:
+                case ENET_EVENT_TYPE_DISCONNECT: std::cout << "3\n"; break;
+                case ENET_EVENT_TYPE_NONE: std::cout << "4\n"; return;
+            }
+        }
+    }
+};
+
+#endif

+ 10 - 0
network/ENet.cpp

@@ -0,0 +1,10 @@
+#define ENET_IMPLEMENTATION
+#include "network/ENet.h"
+
+ENet::~ENet() {
+    enet_deinitialize();
+}
+
+bool ENet::init() {
+    return enet_initialize() != 0;
+}

+ 23 - 0
network/ENet.h

@@ -0,0 +1,23 @@
+#ifndef ENET_H
+#define ENET_H
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#pragma GCC diagnostic ignored "-Wpragmas"
+#pragma GCC diagnostic ignored "-Wextra"
+#pragma GCC diagnostic ignored "-Wpedantic"
+#include "enet/include/enet.h"
+#pragma GCC diagnostic pop
+
+struct ENet final {
+    ENet() = default;
+    ENet(const ENet&) = delete;
+    ENet(ENet&&) = delete;
+    ~ENet();
+    ENet& operator=(const ENet&) = delete;
+    ENet& operator=(ENet&&) = delete;
+
+    bool init();
+};
+
+#endif

+ 33 - 0
network/Packet.cpp

@@ -0,0 +1,33 @@
+#include "network/Packet.h"
+
+InPacket::InPacket(ENetPacket* packet) : packet(packet), readIndex(0) {
+}
+
+bool InPacket::read(void* buffer, unsigned int length) {
+    if(readIndex + length > packet->dataLength) {
+        return true;
+    }
+    memcpy(buffer, packet->data + readIndex, length);
+    readIndex += length;
+    return false;
+}
+
+bool InPacket::read(uint8& u) {
+    return read(&u, sizeof(u));
+}
+
+bool InPacket::read(uint16& u) {
+    if(read(&u, sizeof(u))) {
+        return true;
+    }
+    u = ntohs(u);
+    return false;
+}
+
+bool InPacket::read(uint32& u) {
+    if(read(&u, sizeof(u))) {
+        return true;
+    }
+    u = ntohl(u);
+    return false;
+}

+ 23 - 0
network/Packet.h

@@ -0,0 +1,23 @@
+#ifndef PACKET_H
+#define PACKET_H
+
+#include "network/ENet.h"
+#include "utils/Types.h"
+
+class InPacket {
+    ENetPacket* packet;
+    unsigned int readIndex;
+
+    friend class Server;
+
+    InPacket(ENetPacket* packet);
+
+    bool read(void* buffer, unsigned int length);
+
+public:
+    bool read(uint8& u);
+    bool read(uint16& u);
+    bool read(uint32& u);
+};
+
+#endif

+ 50 - 0
network/Server.cpp

@@ -0,0 +1,50 @@
+#include <utility>
+
+#include "network/Server.h"
+
+Server::Client::Client(ENetPeer* peer, int id) : peer(peer), id(id) {
+}
+
+Server::Client::~Client() {
+    if(peer != nullptr) {
+        enet_peer_disconnect(peer, 0);
+    }
+}
+
+Server::Client::Client(Client&& other) : Client(nullptr, -1) {
+    std::swap(peer, other.peer);
+    std::swap(id, other.id);
+}
+
+Server::Client& Server::Client::operator=(Client&& other) {
+    std::swap(peer, other.peer);
+    std::swap(id, other.id);
+    return *this;
+}
+
+int Server::Client::getId() const {
+    return id;
+}
+
+Server::Server(Port port, int maxClients) : server(nullptr), idCounter(1) {
+    ENetAddress address;
+    address.host = ENET_HOST_ANY;
+    address.port = port;
+
+    server = enet_host_create(&address, maxClients, 2, 0, 0);
+    if(server == nullptr) {
+        error.clear().append("cannot crate ENet server host");
+    }
+}
+
+Server::~Server() {
+    enet_host_destroy(server);
+}
+
+bool Server::hasError() const {
+    return error.getLength() > 0;
+}
+
+const Server::Error& Server::getError() const {
+    return error;
+}

+ 122 - 0
network/Server.h

@@ -0,0 +1,122 @@
+#ifndef SERVER_H
+#define SERVER_H
+
+#include "network/ENet.h"
+#include "network/Packet.h"
+#include "utils/HashMap.h"
+#include "utils/StringBuffer.h"
+#include "utils/Types.h"
+
+struct Server final {
+    typedef enet_uint16 Port;
+    typedef StringBuffer<256> Error;
+
+    class Client final {
+        ENetPeer* peer;
+        int id;
+
+        friend HashMap<int, Client>;
+
+        Client(ENetPeer* peer, int id);
+        Client(const Client&) = delete;
+        Client(Client&& other);
+        Client& operator=(const Client&) = delete;
+        Client& operator=(Client&& other);
+
+    public:
+        ~Client();
+
+        int getId() const;
+    };
+
+private:
+    ENetHost* server;
+    Error error;
+    HashMap<int, Client> clients;
+    int idCounter;
+
+public:
+    Server(Port port, int maxClients);
+    Server(const Server&) = delete;
+    Server(Server&&) = delete;
+    ~Server();
+    Server& operator=(const Server&) = delete;
+    Server& operator=(Server&&) = delete;
+
+    bool hasError() const;
+    const Error& getError() const;
+
+    template<typename T>
+    void consumeEvents(T& consumer) {
+        ENetEvent e;
+        while(!hasError() && enet_host_service(server, &e, 0) > 0) {
+            switch(e.type) {
+                case ENET_EVENT_TYPE_CONNECT: onConnect(e, consumer); break;
+                case ENET_EVENT_TYPE_RECEIVE:
+                    onPackage(e, consumer);
+                    enet_packet_destroy(e.packet);
+                    break;
+                case ENET_EVENT_TYPE_DISCONNECT_TIMEOUT:
+                case ENET_EVENT_TYPE_DISCONNECT:
+                    onDisconnect(e, consumer);
+                    break;
+                case ENET_EVENT_TYPE_NONE: return;
+            }
+        }
+    }
+
+private:
+    template<typename T>
+    void onConnect(ENetEvent& e, T& consumer) {
+        int id = idCounter++;
+        if(clients.tryEmplace(id, e.peer, id)) {
+            error.clear().append("id is connected twice");
+            return;
+        }
+        static_assert(sizeof(e.peer->data) >= sizeof(id),
+                      "private data not big enough for id");
+        memcpy(&(e.peer->data), &id, sizeof(id));
+        Client* client = clients.search(id);
+        if(client != nullptr) {
+            consumer.onConnection(*client);
+        } else {
+            error.clear().append("cannot find added client");
+        }
+    }
+
+    template<typename T>
+    void onPackage(ENetEvent& e, T& consumer) {
+        if(e.peer->data == nullptr) {
+            error.clear().append("client without data sent package");
+            return;
+        }
+        int id = -1;
+        memcpy(&id, &(e.peer->data), sizeof(id));
+        Client* client = clients.search(id);
+        if(client != nullptr) {
+            InPacket in(e.packet);
+            consumer.onPackage(*client, in);
+        } else {
+            error.clear().append("client with invalid id sent package");
+        }
+    }
+
+    template<typename T>
+    void onDisconnect(ENetEvent& e, T& consumer) {
+        if(e.peer->data == nullptr) {
+            error.clear().append("client without data disconnected");
+            return;
+        }
+        int id = -1;
+        memcpy(&id, &(e.peer->data), sizeof(id));
+        Client* client = clients.search(id);
+        if(client != nullptr) {
+            consumer.onDisconnect(*client);
+            clients.remove(id);
+        } else {
+            error.clear().append("client has invalid id");
+        }
+    }
+};
+
+#endif

+ 25 - 0
tests/HashMapTests.cpp

@@ -214,6 +214,30 @@ static void testMoveAssignment(Test& test) {
     }
 }
 
+static void testRemove(Test& test) {
+    IntMap map;
+    map.add(1, 3).add(2, 4).add(3, 5);
+
+    bool remove1 = map.remove(2);
+    bool remove2 = map.remove(7);
+
+    int* a = map.search(1);
+    int* b = map.search(2);
+    int* c = map.search(3);
+
+    test.checkEqual(true, a != nullptr, "move moves values 1");
+    test.checkEqual(true, b == nullptr, "move moves values 2");
+    test.checkEqual(true, c != nullptr, "move moves values 3");
+
+    test.checkEqual(true, remove1, "remove returns true");
+    test.checkEqual(false, remove2, "remove returns false");
+
+    if(a != nullptr && c != nullptr) {
+        test.checkEqual(3, *a, "move moves values 1");
+        test.checkEqual(5, *c, "move moves values 3");
+    }
+}
+
 void HashMapTests::test() {
     Test test("HashMap");
     testAdd(test);
@@ -230,5 +254,6 @@ void HashMapTests::test() {
     testCopyAssignment(test);
     testMove(test);
     testMoveAssignment(test);
+    testRemove(test);
     test.finalize();
 }

+ 83 - 0
tests/NetworkTests.cpp

@@ -0,0 +1,83 @@
+#include <atomic>
+#include <thread>
+
+#include "network/Client.h"
+#include "network/Server.h"
+#include "tests/NetworkTests.h"
+#include "tests/Test.h"
+
+static void sleep(int millis) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(millis));
+}
+
+struct ServerConsumer {
+    bool connected = false;
+    bool disconnect = false;
+
+    void onConnection(Server::Client& client) {
+        (void)client;
+        connected = true;
+    }
+
+    void onDisconnect(Server::Client& client) {
+        (void)client;
+        disconnect = true;
+    }
+
+    void onPackage(Server::Client& client, InPacket& in) {
+        (void)client;
+        (void)in;
+    }
+};
+
+struct ClientConsumer {};
+
+static void testConnect(Test& test) {
+    Server server(54321, 5);
+    if(server.hasError()) {
+        test.checkEqual(false, true, "server can initialize");
+        return;
+    }
+    Client client;
+    if(client.hasError()) {
+        test.checkEqual(false, true, "client can initialize");
+        return;
+    }
+
+    std::atomic_bool running(true);
+    ServerConsumer serverConsumer;
+    std::thread listen([&running, &server, &serverConsumer]() {
+        while(running) {
+            server.consumeEvents(serverConsumer);
+        }
+    });
+
+    test.checkEqual(false, client.connect("127.0.0.1", 54321, 5),
+                    "connection failed");
+
+    ClientConsumer clientConsumer;
+    for(int i = 0; i < 100; i++) {
+        client.consumeEvents(clientConsumer);
+    }
+
+    test.checkEqual(true, serverConsumer.connected, "server has connection");
+
+    client.disconnect();
+    sleep(100);
+
+    test.checkEqual(true, serverConsumer.disconnect, "client has disconnected");
+
+    running = false;
+    listen.join();
+}
+
+void NetworkTests::test() {
+    Test test("Network");
+    ENet enet;
+    if(enet.init()) {
+        test.checkEqual(false, true, "enet init failed");
+    } else {
+        testConnect(test);
+    }
+    test.finalize();
+}

+ 8 - 0
tests/NetworkTests.h

@@ -0,0 +1,8 @@
+#ifndef NETWORKTESTS_H
+#define NETWORKTESTS_H
+
+namespace NetworkTests {
+    void test();
+}
+
+#endif

+ 70 - 54
utils/HashMap.h

@@ -11,13 +11,18 @@
 
 template<typename K, typename V>
 struct HashMap final {
-    struct Node {
+    class Node {
         friend HashMap;
         friend List<Node>;
+        K key;
 
-        const K key;
+    public:
         V value;
 
+        const K& getKey() const {
+            return key;
+        }
+
     private:
         int next;
 
@@ -35,58 +40,70 @@ struct HashMap final {
     };
 
 private:
-    List<int> nodePointers;
-    List<Node> nodes;
+    List<List<Node>> nodes;
+    int elements;
 
 public:
-    HashMap(int minCapacity = 8) {
-        nodePointers.resize(1 << Utils::roundUpLog2(minCapacity), -1);
+    HashMap(int minCapacity = 8) : elements(0) {
+        nodes.resize(1 << Utils::roundUpLog2(minCapacity));
     }
 
     template<typename... Args>
     bool tryEmplace(const K& key, Args&&... args) {
-        int pointer = prepareAdd(key);
-        if(pointer == -1) {
-            nodes.add(key, std::forward<Args>(args)...);
+        rehash();
+        Hash h = hash(key);
+        V* v = searchList(key, h);
+        if(v == nullptr) {
+            nodes[h].add(key, std::forward<Args>(args)...);
+            elements++;
             return false;
         }
         return true;
     }
 
     HashMap& add(const K& key, const V& value) {
-        int pointer = prepareAdd(key);
-        if(pointer != -1) {
-            nodes[pointer].value = value;
+        rehash();
+        Hash h = hash(key);
+        V* v = searchList(key, h);
+        if(v == nullptr) {
+            nodes[h].add(key, value);
+            elements++;
         } else {
-            nodes.add(key, value);
+            *v = value;
         }
         return *this;
     }
 
     HashMap& add(const K& key, V&& value) {
-        int pointer = prepareAdd(key);
-        if(pointer != -1) {
-            nodes[pointer].value = std::move(value);
+        rehash();
+        Hash h = hash(key);
+        V* v = searchList(key, h);
+        if(v == nullptr) {
+            nodes[h].add(key, std::move(value));
+            elements++;
         } else {
-            nodes.add(key, std::move(value));
+            *v = std::move(value);
         }
         return *this;
     }
 
-    const V* search(const K& key) const {
-        Hash h = hash(key);
-        int pointer = nodePointers[h];
-        while(pointer != -1) {
-            if(nodes[pointer].key == key) {
-                return &(nodes[pointer].value);
+    bool remove(const K& key) {
+        List<Node>& list = nodes[hash(key)];
+        for(int i = 0; i < list.getLength(); i++) {
+            if(list[i].key == key) {
+                list.remove(i);
+                return true;
             }
-            pointer = nodes[pointer].next;
         }
-        return nullptr;
+        return false;
+    }
+
+    const V* search(const K& key) const {
+        return searchList(key, hash(key));
     }
 
     V* search(const K& key) {
-        return const_cast<V*>(static_cast<const HashMap*>(this)->search(key));
+        return searchList(key, hash(key));
     }
 
     bool contains(const K& key) const {
@@ -94,10 +111,10 @@ public:
     }
 
     HashMap& clear() {
-        for(int& pointer : nodePointers) {
-            pointer = -1;
+        for(List<Node>& n : nodes) {
+            n.clear();
         }
-        nodes.clear();
+        elements = 0;
         return *this;
     }
 
@@ -121,12 +138,14 @@ public:
     void toString(StringBuffer<L>& s) const {
         s.append("[");
         bool c = false;
-        for(const Node& n : nodes) {
-            if(c) {
-                s.append(", ");
+        for(const List<Node>& list : nodes) {
+            for(const Node& n : list) {
+                if(c) {
+                    s.append(", ");
+                }
+                s.append(n.key).append(" = ").append(n.value);
+                c = true;
             }
-            s.append(n.key).append(" = ").append(n.value);
-            c = true;
         }
         s.append("]");
     }
@@ -134,7 +153,7 @@ public:
 private:
     template<typename H>
     Hash hash(const H& key) const {
-        return fullHash(key) & (nodePointers.getLength() - 1);
+        return fullHash(key) & (nodes.getLength() - 1);
     }
 
     template<typename H>
@@ -147,33 +166,30 @@ private:
     }
 
     void rehash() {
-        if(nodes.getLength() < nodePointers.getLength()) {
+        if(elements < nodes.getLength()) {
             return;
         }
-        HashMap<K, V> map(nodePointers.getLength() * 2);
-        for(const Node& n : nodes) {
-            map.add(n.key, std::move(n.value));
+        HashMap<K, V> map(nodes.getLength() * 2);
+        for(List<Node>& list : nodes) {
+            for(Node& n : list) {
+                map.tryEmplace(n.key, std::move(n.value));
+            }
         }
         *this = std::move(map);
     }
 
-    int prepareAdd(const K& key) {
-        rehash();
-        Hash h = hash(key);
-        int pointer = nodePointers[h];
-        if(pointer == -1) {
-            nodePointers[h] = nodes.getLength();
-            return -1;
-        }
-        while(true) {
-            if(nodes[pointer].key == key) {
-                return pointer;
-            } else if(nodes[pointer].next == -1) {
-                nodes[pointer].next = nodes.getLength();
-                return -1;
+    const V* searchList(const K& key, Hash h) const {
+        for(const Node& n : nodes[h]) {
+            if(n.key == key) {
+                return &n.value;
             }
-            pointer = nodes[pointer].next;
         }
+        return nullptr;
+    }
+
+    V* searchList(const K& key, Hash h) {
+        return const_cast<V*>(
+            static_cast<const HashMap*>(this)->searchList(key, h));
     }
 };
 

+ 2 - 3
utils/List.h

@@ -80,12 +80,11 @@ public:
         }
     }
 
-    template<typename... Args>
-    void resize(int n, Args&&... args) {
+    void resize(int n) {
         if(length < n) {
             reserve(n);
             for(int i = length; i < n; i++) {
-                add(std::forward<Args>(args)...);
+                add(T());
             }
         } else if(length > n) {
             for(int i = n; i < length; i++) {