#include "core/Network.h"

#define ENET_IMPLEMENTATION
// NOTE(review): the names of the angle-bracket headers were missing in the
// copy under review; this list is reconstructed from what the file uses
// (assert, memcpy, HASHMAP*, LOG_*, ENet) - confirm against the build.
#include <assert.h>
#include <core/HashMap.h>
#include <core/Logger.h>
#include <core/Utility.h>
#include <enet.h>
#include <string.h>

#include "ErrorSimulator.h"

// Hash map type from a Client id to its ENetPeer*; the server uses it to
// look up connected peers by id.
HASHMAP(Client, ENetPeer*, Client)
#define equalClient equalInt
#define isInvalidKeyClient isInvalidKeyInt
#define hashClient hashInt
HASHMAP_SOURCE(Client, ENetPeer*, Client)

// ---- Incoming packets -------------------------------------------------------
// Every readInPacket* function returns true on failure (not enough bytes
// left) and false on success. On failure the read position is unchanged.
// Multi-byte integers are in network byte order. Signed integers are stored
// biased (i8 as i + 128, i16 as i + 32768, i32 as i + 2^31) so they travel
// as unsigned values.

// Starts reading the n bytes at data from the beginning.
void initInPacket(InPacket* in, const void* data, size_t n) {
    in->data = data;
    in->size = n;
    in->index = 0;
}

bool readInPacketU8(InPacket* in, u8* u) {
    return readInPacket(in, u, sizeof(*u));
}

bool readInPacketU16(InPacket* in, u16* u) {
    if(readInPacket(in, u, sizeof(*u))) {
        return true;
    }
    *u = ntohs(*u);
    return false;
}

bool readInPacketU32(InPacket* in, u32* u) {
    if(readInPacket(in, u, sizeof(*u))) {
        return true;
    }
    *u = ntohl(*u);
    return false;
}

bool readInPacketI8(InPacket* in, i8* i) {
    u8 u;
    if(readInPacketU8(in, &u)) {
        return true;
    }
    *i = (i8)((i32)u - (i32)128);
    return false;
}

bool readInPacketI16(InPacket* in, i16* i) {
    u16 u;
    if(readInPacketU16(in, &u)) {
        return true;
    }
    *i = (i16)((i32)u - (i32)32'768);
    return false;
}

bool readInPacketI32(InPacket* in, i32* i) {
    u32 u;
    if(readInPacketU32(in, &u)) {
        return true;
    }
    if(u < (u32)2'147'483'648) {
        // Negative value: (i32)u is in [0, 2^31 - 1]; subtracting 2^31 in two
        // steps stays within i32 (a single "- 2^31" would overflow).
        *i = (i32)u - 2'147'483'647 - 1;
    } else {
        *i = (i32)(u - (u32)2'147'483'648);
    }
    return false;
}

bool readInPacketFloat(InPacket* in, float* f) {
    u32 u;
    static_assert(sizeof(u) == sizeof(*f), "float and u32 size do not match");
    if(readInPacketU32(in, &u)) {
        return true;
    }
    // Reinterpret the transmitted bit pattern without aliasing violations.
    memcpy(f, &u, sizeof(float));
    return false;
}

// Reads a string written by writeOutPacketString into buffer, which has room
// for n bytes including the terminator. Characters that do not fit are
// consumed and dropped. For n > 0 the buffer is always NUL-terminated; it is
// left empty if the packet ends inside the string. Returns the length prefix
// stored in the packet (which counts the sender's terminator and may exceed
// n), or 0 if n is 0 or the prefix could not be read.
size_t readInPacketString(InPacket* in, char* buffer, size_t n) {
    if(n == 0) {
        return 0;
    }
    u16 size;
    if(readInPacketU16(in, &size)) {
        // Never hand an unterminated buffer back to the caller.
        *buffer = '\0';
        return 0;
    }
    size_t end = size;
    char* write = buffer;
    size_t space = n - 1; // one byte stays reserved for the terminator
    while(space > 0 && end > 0) {
        space--;
        end--;
        u8 u;
        if(readInPacketU8(in, &u)) {
            // Truncated packet: report an empty string.
            *buffer = '\0';
            break;
        }
        *(write++) = (char)u;
    }
    // Consume the characters that did not fit so the next read starts after
    // the string; stops early if the packet ends.
    while(end > 0 && !readInPacketU8(in, &(u8){0})) {
        end--;
    }
    *write = '\0';
    return size;
}

// Copies the next n bytes into buffer and advances the read position.
bool readInPacket(InPacket* in, void* buffer, size_t n) {
    // Compare against the remaining byte count rather than index + n so a
    // large n cannot wrap around and pass the check.
    if(in->index > in->size || n > in->size - in->index) {
        return true;
    }
    // Offset through a byte pointer: arithmetic on void* is a GNU extension.
    memcpy(buffer, (const char*)in->data + in->index, n);
    in->index += n;
    return false;
}

// ---- Outgoing packets -------------------------------------------------------
// Encodings mirror the readInPacket* functions above.

void initOutPacket(OutPacket* out) {
    initBuffer(&out->data);
}

void destroyOutPacket(OutPacket* out) {
    destroyBuffer(&out->data);
}

void writeOutPacketU8(OutPacket* out, u8 u) {
    addSizedBufferData(&out->data, &u, sizeof(u));
}

void writeOutPacketU16(OutPacket* out, u16 u) {
    u = htons(u);
    addSizedBufferData(&out->data, &u, sizeof(u));
}

void writeOutPacketU32(OutPacket* out, u32 u) {
    u = htonl(u);
    addSizedBufferData(&out->data, &u, sizeof(u));
}

void writeOutPacketI8(OutPacket* out, i8 i) {
    writeOutPacketU8(out, (u8)((i32)i + (i32)128));
}

void writeOutPacketI16(OutPacket* out, i16 i) {
    writeOutPacketU16(out, (u16)((i32)i + (i32)32'768));
}

void writeOutPacketI32(OutPacket* out, i32 i) {
    // Conversion to u32 is modulo 2^32, so this is i + 2^31 for every i and
    // involves no signed arithmetic at all.
    writeOutPacketU32(out, (u32)i + (u32)2'147'483'648);
}

void writeOutPacketFloat(OutPacket* out, float f) {
    u32 u;
    static_assert(sizeof(u) == sizeof(f), "float and u32 size do not match");
    memcpy(&u, &f, sizeof(float));
    writeOutPacketU32(out, u);
}

// Writes s as a u16 length prefix, at most 65534 characters and a 0 byte.
// The prefix counts the 0 byte; longer strings are silently truncated.
void writeOutPacketString(OutPacket* out, const char* s) {
    // Reserve the prefix now and patch it once the length is known.
    size_t marker = out->data.size;
    writeOutPacketU16(out, 0);
    size_t end = 0;
    while(end < 65'534 && *s != '\0') {
        writeOutPacketU8(out, (u8)(*(s++)));
        end++;
    }
    writeOutPacketU8(out, 0);
    end++;
    size_t endMarker = out->data.size;
    out->data.size = marker;
    writeOutPacketU16(out, (u16)end);
    out->data.size = endMarker;
}

void writeOutPacket(OutPacket* out, const void* buffer, size_t n) {
    addSizedBufferData(&out->data, buffer, n);
}

// ---- ENet lifetime ----------------------------------------------------------

// Number of outstanding addENet references; ENet is initialized by the first
// reference and deinitialized when the last one is released.
static int enetCounter = 0;

// Acquires an ENet reference, initializing the library on first use.
// Returns true on failure.
static bool addENet(void) {
    if(enetCounter == 0 && FAIL(enet_initialize() != 0, true)) {
        return true;
    }
    enetCounter++;
    return false;
}

// Releases an ENet reference; extra calls while the count is 0 are ignored.
static void removeENet(void) {
    if(enetCounter > 0 && --enetCounter == 0) {
        enet_deinitialize();
    }
}

static_assert(sizeof(enet_uint16) == sizeof(Port), "port has wrong type");
static void voidVoidDummy(void) { } static void voidInPacketDummy(InPacket*) { } typedef struct { ENetHost* client; ENetPeer* connection; OnServerConnect onConnect; OnServerDisconnect onDisconnect; OnServerPacket onPacket; int connectTicks; int connectTimeoutTicks; int disconnectTicks; int disconnectTimeoutTicks; } ClientData; static ClientData client = { nullptr, nullptr, voidVoidDummy, voidVoidDummy, voidInPacketDummy, 0, 0, 0, 0}; bool startClient(void) { if(client.client != nullptr) { LOG_WARNING("Client already started"); return true; } else if(addENet()) { LOG_ERROR("Client cannot initialize enet"); return true; } client.client = FAIL(enet_host_create(nullptr, 1, 2, 0, 0), nullptr); if(client.client == nullptr) { stopClient(); LOG_ERROR("Cannot create enet client host"); return true; } return false; } void stopClient(void) { if(client.connection != nullptr) { client.onDisconnect(); FAIL( enet_peer_disconnect_now(client.connection, 0), enet_peer_reset(client.connection)); client.connection = nullptr; } if(client.client != nullptr) { enet_host_destroy(client.client); client.client = nullptr; } removeENet(); client.connectTicks = 0; client.disconnectTicks = 0; } bool connectClient(const char* server, Port port, int timeoutTicks) { if(client.client == nullptr) { LOG_WARNING("Client not started"); return true; } else if(client.connection != nullptr) { LOG_WARNING("Connection already exists"); return true; } ENetAddress address = {0}; enet_address_set_host(&address, server); address.port = port; client.connection = FAIL(enet_host_connect(client.client, &address, 3, 0), nullptr); if(client.connection == nullptr) { LOG_ERROR("Cannot create connection"); return true; } client.connectTicks = 1; client.connectTimeoutTicks = timeoutTicks; return false; } void setClientTimeout(u32 timeout, u32 timeoutMin, u32 timeoutMax) { if(client.connection != nullptr) { enet_peer_timeout(client.connection, timeout, timeoutMin, timeoutMax); } } void disconnectClient(int timeoutTicks) { 
if(client.connection == nullptr) { return; } client.connectTicks = 0; enet_peer_disconnect(client.connection, 0); client.disconnectTicks = 1; client.disconnectTimeoutTicks = timeoutTicks; } void sendClientPacket(const OutPacket* p, PacketSendMode mode) { if(client.client == nullptr || client.connection == nullptr || client.connectTicks >= 0) { return; } static const enet_uint32 flags[] = { ENET_PACKET_FLAG_RELIABLE, 0, ENET_PACKET_FLAG_UNSEQUENCED}; enet_uint8 i = (enet_uint8)mode; enet_peer_send( client.connection, i, enet_packet_create(p->data.buffer, p->data.size, flags[i])); } static void tickClientEvents(void) { ENetEvent e; while(enet_host_service(client.client, &e, 0) >= 0) { switch(e.type) { case ENET_EVENT_TYPE_CONNECT: client.connectTicks = -1; client.onConnect(); break; case ENET_EVENT_TYPE_DISCONNECT_TIMEOUT: case ENET_EVENT_TYPE_DISCONNECT: client.disconnectTicks = 0; client.connectTicks = 0; client.onDisconnect(); client.connection = nullptr; break; case ENET_EVENT_TYPE_RECEIVE: { InPacket in; initInPacket(&in, e.packet->data, e.packet->dataLength); client.onPacket(&in); enet_packet_destroy(e.packet); break; } case ENET_EVENT_TYPE_NONE: return; } } } void tickClient(void) { if(client.client == nullptr) { return; } tickClientEvents(); if(client.connectTicks >= 1 && ++client.connectTicks > client.connectTimeoutTicks) { client.connectTicks = 0; disconnectClient(client.connectTimeoutTicks); } if(client.disconnectTicks >= 1 && ++client.disconnectTicks > client.disconnectTimeoutTicks) { client.disconnectTicks = 0; client.onDisconnect(); if(client.connection != nullptr) { enet_peer_reset(client.connection); client.connection = nullptr; } } } void setClientConnectHandler(OnServerConnect oc) { client.onConnect = oc == nullptr ? voidVoidDummy : oc; } void setClientDisconnectHandler(OnServerDisconnect od) { client.onDisconnect = od == nullptr ? voidVoidDummy : od; } void setClientPacketHandler(OnServerPacket op) { client.onPacket = op == nullptr ? 
voidInPacketDummy : op; } void resetClientHandler(void) { client.onConnect = voidVoidDummy; client.onDisconnect = voidVoidDummy; client.onPacket = voidInPacketDummy; } bool isClientConnecting(void) { return client.connectTicks >= 1; } bool isClientConnected(void) { return client.connectTicks < 0; } static void voidClientDummy(Client) { } static void voidClientInPacketDummy(Client, InPacket*) { } typedef struct { ENetHost* server; HashMapClient clients; Client idCounter; OnClientConnect onConnect; OnClientDisconnect onDisconnect; OnClientPacket onPacket; } ServerData; static ServerData server = { nullptr, {0}, 1, voidClientDummy, voidClientDummy, voidClientInPacketDummy}; bool startServer(Port port, size_t maxClients) { if(maxClients <= 0) { LOG_ERROR("Invalid max client amount"); return true; } else if(server.server != nullptr) { LOG_WARNING("Server already started"); return true; } else if(addENet()) { LOG_ERROR("Server cannot initialize enet"); return true; } ENetAddress address = {.host = ENET_HOST_ANY, .port = port}; server.server = FAIL(enet_host_create(&address, maxClients, 3, 0, 0), nullptr); if(server.server == nullptr) { stopServer(); LOG_ERROR("Cannot create enet server host"); return true; } initHashMapClient(&server.clients); return false; } void stopServer(void) { if(server.server != nullptr) { HashMapIteratorClient i; initHashMapIteratorClient(&i, &server.clients); while(hasNextHashMapNodeClient(&i)) { HashMapNodeClient* n = nextHashMapNodeClient(&i); enet_peer_reset(*n->value); } enet_host_destroy(server.server); server.server = nullptr; destroyHashMapClient(&server.clients); } removeENet(); } static void writeId(ENetPeer* peer, Client id) { static_assert( sizeof(peer->data) >= sizeof(id), "private data not big enough for id"); memcpy(&(peer->data), &id, sizeof(id)); } static Client getId(ENetPeer* peer) { assert(peer->data != nullptr); Client id = -1; memcpy(&id, &(peer->data), sizeof(id)); return id; } static void handleConnect(ENetEvent* e) { 
Client id = server.idCounter++; assert(searchHashMapKeyClient(&server.clients, id) == nullptr); *putHashMapKeyClient(&server.clients, id) = e->peer; writeId(e->peer, id); server.onConnect(id); } static void handlePacket(ENetEvent* e) { Client id = getId(e->peer); InPacket in; initInPacket(&in, e->packet->data, e->packet->dataLength); server.onPacket(id, &in); } static void handleDisconnect(ENetEvent* e) { Client id = getId(e->peer); server.onDisconnect(id); removeHashMapKeyClient(&server.clients, id); } void tickServer(void) { if(server.server == nullptr) { return; } ENetEvent e; while(enet_host_service(server.server, &e, 0) >= 0) { switch(e.type) { case ENET_EVENT_TYPE_CONNECT: handleConnect(&e); break; case ENET_EVENT_TYPE_RECEIVE: handlePacket(&e); enet_packet_destroy(e.packet); break; case ENET_EVENT_TYPE_DISCONNECT_TIMEOUT: case ENET_EVENT_TYPE_DISCONNECT: handleDisconnect(&e); break; case ENET_EVENT_TYPE_NONE: return; } } } static ENetPacket* fromBuffer(const Buffer* buffer, enet_uint8 index) { static const enet_uint32 flags[] = { ENET_PACKET_FLAG_RELIABLE, 0, ENET_PACKET_FLAG_UNSEQUENCED}; return enet_packet_create(buffer->buffer, buffer->size, flags[index]); } void sendServerPacketBroadcast(const OutPacket* p, PacketSendMode mode) { if(server.server != nullptr) { enet_uint8 index = (enet_uint8)mode; enet_host_broadcast(server.server, index, fromBuffer(&p->data, index)); } } void sendServerPacket( Client clientId, const OutPacket* p, PacketSendMode mode) { if(server.server == nullptr) { return; } ENetPeer** peer = searchHashMapKeyClient(&server.clients, clientId); if(peer != nullptr) { enet_uint8 index = (enet_uint8)mode; enet_peer_send(*peer, index, fromBuffer(&p->data, index)); } } void setServerTimeout( Client clientId, u32 timeout, u32 timeoutMin, u32 timeoutMax) { if(server.server == nullptr) { return; } ENetPeer** peer = searchHashMapKeyClient(&server.clients, clientId); if(peer != nullptr) { enet_peer_timeout(*peer, timeout, timeoutMin, timeoutMax); } 
} void disconnectServerClient(Client clientId) { if(server.server == nullptr) { return; } ENetPeer** peer = searchHashMapKeyClient(&server.clients, clientId); if(peer != nullptr) { enet_peer_disconnect(*peer, 0); } } void setServerConnectHandler(OnClientConnect oc) { server.onConnect = oc == nullptr ? voidClientDummy : oc; } void setServerDisconnectHandler(OnClientDisconnect od) { server.onDisconnect = od == nullptr ? voidClientDummy : od; } void setServerPacketHandler(OnClientPacket op) { server.onPacket = op == nullptr ? voidClientInPacketDummy : op; } void resetServerHandler(void) { server.onConnect = voidClientDummy; server.onDisconnect = voidClientDummy; server.onPacket = voidClientInPacketDummy; }