Browse Source

correct iterators for hashmaps, package from server to client

Kajetan Johannes Hammerle 3 years ago
parent
commit
48800a9c38
7 changed files with 142 additions and 20 deletions
  1. 8 5
      network/Client.h
  2. 2 0
      network/Packet.h
  3. 22 0
      network/Server.cpp
  4. 10 4
      network/Server.h
  5. 25 0
      tests/HashMapTests.cpp
  6. 28 3
      tests/NetworkTests.cpp
  7. 47 8
      utils/HashMap.h

+ 8 - 5
network/Client.h

@@ -30,15 +30,18 @@ public:
 
     template<typename T>
     void consumeEvents(T& consumer) {
-        (void)consumer;
         ENetEvent e;
         while(enet_host_service(client, &e, 0) > 0) {
             switch(e.type) {
-                case ENET_EVENT_TYPE_CONNECT: std::cout << "1\n"; break;
-                case ENET_EVENT_TYPE_RECEIVE: std::cout << "2\n"; break;
+                case ENET_EVENT_TYPE_CONNECT: break;
                 case ENET_EVENT_TYPE_DISCONNECT_TIMEOUT:
-                case ENET_EVENT_TYPE_DISCONNECT: std::cout << "3\n"; break;
-                case ENET_EVENT_TYPE_NONE: std::cout << "4\n"; return;
+                case ENET_EVENT_TYPE_DISCONNECT: consumer.onDisconnect(); break;
+                case ENET_EVENT_TYPE_NONE: return;
+                case ENET_EVENT_TYPE_RECEIVE:
+                    InPacket in(e.packet);
+                    consumer.onPacket(in);
+                    enet_packet_destroy(e.packet);
+                    break;
             }
         }
     }

+ 2 - 0
network/Packet.h

@@ -9,6 +9,7 @@ class InPacket {
     ENetPacket* packet;
     unsigned int index;
 
+    friend class Client;
     friend class Server;
 
     InPacket(ENetPacket* packet);
@@ -47,6 +48,7 @@ class OutPacket {
     int channel;
 
     friend class Client;
+    friend class Server;
 
     OutPacket(unsigned int size, int flags, int channel);
 

+ 22 - 0
network/Server.cpp

@@ -26,6 +26,13 @@ int Server::Client::getId() const {
     return id;
 }
 
+void Server::Client::send(OutPacket& p) {
+    if(p.packet != nullptr) {
+        enet_peer_send(peer, p.channel, p.packet);
+        p.packet = nullptr;
+    }
+}
+
 Server::Server(Port port, int maxClients) : server(nullptr), idCounter(1) {
     ENetAddress address;
     address.host = ENET_HOST_ANY;
@@ -38,6 +45,9 @@ Server::Server(Port port, int maxClients) : server(nullptr), idCounter(1) {
 }
 
 Server::~Server() {
+    for(auto& client : clients) {
+        enet_peer_reset(client.value.peer);
+    }
     enet_host_destroy(server);
 }
 
@@ -47,4 +57,16 @@ bool Server::hasError() const {
 
 const Server::Error& Server::getError() const {
     return error;
+}
+
+void Server::send(OutPacket& p) {
+    if(p.packet != nullptr) {
+        enet_host_broadcast(server, p.channel, p.packet);
+        p.packet = nullptr;
+    }
+}
+
+void Server::disconnect(Client& client) {
+    enet_peer_reset(client.peer);
+    clients.remove(client.getId());
 }

+ 10 - 4
network/Server.h

@@ -12,6 +12,8 @@ struct Server final {
     typedef StringBuffer<256> Error;
 
     class Client final {
+        friend class Server;
+
         ENetPeer* peer;
         int id;
 
@@ -27,6 +29,7 @@ struct Server final {
         ~Client();
 
         int getId() const;
+        void send(OutPacket& p);
     };
 
 private:
@@ -53,7 +56,7 @@ public:
             switch(e.type) {
                 case ENET_EVENT_TYPE_CONNECT: onConnect(e, consumer); break;
                 case ENET_EVENT_TYPE_RECEIVE:
-                    onPackage(e, consumer);
+                    onPacket(e, consumer);
                     enet_packet_destroy(e.packet);
                     break;
                 case ENET_EVENT_TYPE_DISCONNECT_TIMEOUT:
@@ -65,6 +68,9 @@ public:
         }
     }
 
+    void send(OutPacket& p);
+    void disconnect(Client& client);
+
 private:
     template<typename T>
     void onConnect(ENetEvent& e, T& consumer) {
@@ -78,14 +84,14 @@ private:
         memcpy(&(e.peer->data), &id, sizeof(id));
         Client* client = clients.search(id);
         if(client != nullptr) {
-            consumer.onConnection(*client);
+            consumer.onConnect(*client);
         } else {
             error.clear().append("cannot find added client");
         }
     }
 
     template<typename T>
-    void onPackage(ENetEvent& e, T& consumer) {
+    void onPacket(ENetEvent& e, T& consumer) {
         if(e.peer->data == nullptr) {
             error.clear().append("client without data sent package");
             return;
@@ -95,7 +101,7 @@ private:
         Client* client = clients.search(id);
         if(client != nullptr) {
             InPacket in(e.packet);
-            consumer.onPackage(*client, in);
+            consumer.onPacket(*client, in);
         } else {
             error.clear().append("client with invalid id sent package");
         }

+ 25 - 0
tests/HashMapTests.cpp

@@ -238,6 +238,30 @@ static void testRemove(Test& test) {
     }
 }
 
+static void testForEach(Test& test) {
+    IntMap map;
+    map.add(5, 4).add(10, 3).add(15, 2);
+
+    auto iter = map.begin();
+    test.checkEqual(true, iter != map.end(), "not at end 1");
+    ++iter;
+    test.checkEqual(true, iter != map.end(), "not at end 2");
+    ++iter;
+    test.checkEqual(true, iter != map.end(), "not at end 3");
+    ++iter;
+    test.checkEqual(false, iter != map.end(), "at end");
+
+    const IntMap& cmap = map;
+    auto citer = cmap.begin();
+    test.checkEqual(true, citer != cmap.end(), "not at end 1");
+    ++citer;
+    test.checkEqual(true, citer != cmap.end(), "not at end 2");
+    ++citer;
+    test.checkEqual(true, citer != cmap.end(), "not at end 3");
+    ++citer;
+    test.checkEqual(false, citer != cmap.end(), "at end");
+}
+
 void HashMapTests::test() {
     Test test("HashMap");
     testAdd(test);
@@ -255,5 +279,6 @@ void HashMapTests::test() {
     testMove(test);
     testMoveAssignment(test);
     testRemove(test);
+    testForEach(test);
     test.finalize();
 }

+ 28 - 3
tests/NetworkTests.cpp

@@ -10,6 +10,8 @@ static void sleep(int millis) {
     std::this_thread::sleep_for(std::chrono::milliseconds(millis));
 }
 
+static int packageCounter = 0;
+
 struct ServerConsumer {
     bool connected = false;
     bool disconnect = false;
@@ -25,7 +27,7 @@ struct ServerConsumer {
     int32 data9 = 0;
     StringBuffer<20> data10;
 
-    void onConnection(Server::Client& client) {
+    void onConnect(Server::Client& client) {
         (void)client;
         connected = true;
     }
@@ -35,7 +37,7 @@ struct ServerConsumer {
         disconnect = true;
     }
 
-    void onPackage(Server::Client& client, InPacket& in) {
+    void onPacket(Server::Client& client, InPacket& in) {
         (void)client;
         in.readU8(data1);
         in.readU16(data2);
@@ -47,10 +49,32 @@ struct ServerConsumer {
         in.readS16(data8);
         in.readS32(data9);
         in.readString(data10);
+
+        if(packageCounter == 0) {
+            OutPacket out = OutPacket::reliable(0);
+            client.send(out);
+        } else if(packageCounter == 1) {
+            OutPacket out = OutPacket::sequenced(0);
+            client.send(out);
+        } else if(packageCounter == 2) {
+            OutPacket out = OutPacket::unsequenced(0);
+            client.send(out);
+        }
+        packageCounter++;
     }
 };
 
-struct ClientConsumer {};
+struct ClientConsumer {
+    bool package = false;
+
+    void onDisconnect() {
+    }
+
+    void onPacket(InPacket& in) {
+        (void)in;
+        package = true;
+    }
+};
 
 static void testConnect(Test& test, OutPacket out) {
     Server server(54321, 5);
@@ -97,6 +121,7 @@ static void testConnect(Test& test, OutPacket out) {
         client.consumeEvents(clientConsumer);
     }
 
+    test.checkEqual(true, clientConsumer.package, "client has received data");
     test.checkEqual(true, serverConsumer.connected, "server has connection");
 
     test.checkEqual(static_cast<uint8>(0xF1), serverConsumer.data1,

+ 47 - 8
utils/HashMap.h

@@ -39,6 +39,45 @@ struct HashMap final {
         }
     };
 
+    template<typename N, typename R>
+    class BaseIterator {
+        N& nodes;
+        int indexA;
+        int indexB;
+
+    public:
+        BaseIterator(N& nodes, int indexA, int indexB)
+            : nodes(nodes), indexA(indexA), indexB(indexB) {
+            skip();
+        }
+
+        BaseIterator& operator++() {
+            indexB++;
+            skip();
+            return *this;
+        }
+
+        bool operator!=(const BaseIterator& other) const {
+            return indexA != other.indexA || indexB != other.indexB;
+        }
+
+        R& operator*() {
+            return nodes[indexA][indexB];
+        }
+
+    private:
+        void skip() {
+            while(indexA < nodes.getLength() &&
+                  indexB >= nodes[indexA].getLength()) {
+                indexA++;
+                indexB = 0;
+            }
+        }
+    };
+
+    typedef BaseIterator<List<List<Node>>, Node> Iterator;
+    typedef BaseIterator<const List<List<Node>>, const Node> ConstIterator;
+
 private:
     List<List<Node>> nodes;
     int elements;
@@ -118,20 +157,20 @@ public:
         return *this;
     }
 
-    Node* begin() {
-        return nodes.begin();
+    Iterator begin() {
+        return Iterator(nodes, 0, 0);
     }
 
-    Node* end() {
-        return nodes.end();
+    Iterator end() {
+        return Iterator(nodes, nodes.getLength(), 0);
     }
 
-    const Node* begin() const {
-        return nodes.begin();
+    ConstIterator begin() const {
+        return ConstIterator(nodes, 0, 0);
     }
 
-    const Node* end() const {
-        return nodes.end();
+    ConstIterator end() const {
+        return ConstIterator(nodes, nodes.getLength(), 0);
     }
 
     template<int L>