#include "libs/enet/include/enet.h"

#include "network/Packet.h"

// The float reader/writer below shuttle a float's bytes through a uint32 with
// memcpy, so both types must occupy exactly the same number of bytes.
static_assert(sizeof(uint32) == sizeof(float), "sizeof(float) != sizeof(uint32)");

// Wraps an existing byte range for sequential reading. Only the pointer, the
// byte count and a read cursor (starting at offset 0) are stored; the bytes
// themselves are not copied.
// NOTE(review): the caller presumably keeps `data` alive for as long as this
// packet is read from — confirm at call sites.
InPacket::InPacket(const void* data, int size)
    : data(static_cast<const char*>(data)),
      size(size),
      index(0) {
}

// Copies the next `length` bytes of the packet into `buffer` and advances the
// read cursor past them.
//
// Returns true on failure: either fewer than `length` bytes remain or `length`
// is negative. On failure nothing is copied and the cursor is unchanged.
// Returns false on success.
bool InPacket::read(void* buffer, int length) {
    // Compare against the number of remaining bytes instead of computing
    // `index + length`. A large `length` (e.g. a length field taken from an
    // untrusted packet) would overflow that int sum, which is undefined
    // behavior. A negative `length` would slip past the check and reach memcpy
    // as a huge size_t. The cursor never moves past `size`, so `size - index`
    // cannot overflow.
    if(length < 0 || length > size - index) {
        return true;
    }
    memcpy(buffer, data + index, length);
    index += length;
    return false;
}

// Reads one unsigned byte verbatim; a single byte needs no byte-order
// conversion. Returns true if no byte is left (leaving `u` untouched) and
// false on success.
bool InPacket::read(uint8& u) {
    const bool failed = read(&u, sizeof u);
    return failed;
}

// Reads a 16-bit unsigned integer sent in network byte order and converts it
// to host order with ntohs. Returns true if the packet runs out of bytes, in
// which case `u` is left untouched; returns false on success.
bool InPacket::read(uint16& u) {
    uint16 raw;
    if(!read(&raw, sizeof raw)) {
        u = ntohs(raw);
        return false;
    }
    return true;
}

// Reads a 32-bit unsigned integer sent in network byte order and converts it
// to host order with ntohl. Returns true if the packet runs out of bytes, in
// which case `u` is left untouched; returns false on success.
bool InPacket::read(uint32& u) {
    uint32 raw;
    if(!read(&raw, sizeof raw)) {
        u = ntohl(raw);
        return false;
    }
    return true;
}

// Reads a signed byte stored in offset-binary form (the value plus 128, as
// written by OutPacket::writeS8) and removes the offset. Returns true if no
// byte is left, leaving `s` untouched; returns false on success.
bool InPacket::read(int8& s) {
    uint8 biased;
    if(read(biased)) {
        return true;
    }
    // Subtract in int so nothing wraps: [0, 255] - 128 lands in [-128, 127],
    // the full range of an 8-bit signed value.
    s = static_cast<int8>(static_cast<int>(biased) - 128);
    return false;
}

// Reads a signed 16-bit value stored in offset-binary form (the value plus
// 32768, as written by OutPacket::writeS16) and removes the offset. Returns
// true if the packet runs out of bytes, leaving `s` untouched; false otherwise.
bool InPacket::read(int16& s) {
    uint16 biased;
    if(read(biased)) {
        return true;
    }
    // `long` is at least 32 bits, so [0, 65535] - 32768 is computed exactly
    // and lands in [-32768, 32767].
    s = static_cast<int16>(static_cast<long>(biased) - 32768L);
    return false;
}

// Reads a signed 32-bit value stored in offset-binary form (the value plus
// 2^31, as written by OutPacket::writeS32) and removes the offset. Returns
// true if the packet runs out of bytes, leaving `s` untouched; false otherwise.
bool InPacket::read(int32& s) {
    uint32 biased;
    if(read(biased)) {
        return true;
    }
    // `long long` is at least 64 bits, so [0, 2^32 - 1] - 2^31 is computed
    // exactly and lands in [-2^31, 2^31 - 1].
    s = static_cast<int32>(static_cast<long long>(biased) - 2147483648LL);
    return false;
}

bool InPacket::read(float& f) {
    uint32 u;
    if(read(u)) {
        return true;
    }
    memcpy(&f, &u, sizeof(float));
    return false;
}

// Constructs an outgoing packet, forwarding `initialSize` to the underlying
// buffer's constructor.
// NOTE(review): presumably an initial capacity hint — confirm against the
// buffer type's constructor.
OutPacket::OutPacket(int initialSize)
    : buffer(initialSize) {}

// Appends one unsigned byte as-is; a single byte needs no byte-order
// conversion. Returns *this so writes can be chained.
OutPacket& OutPacket::writeU8(uint8 u) {
    this->buffer.add(u);
    return *this;
}

// Appends a 16-bit unsigned integer converted to network byte order with
// htons. Returns *this so writes can be chained.
OutPacket& OutPacket::writeU16(uint16 u) {
    uint16 networkOrder = htons(u);
    buffer.add(networkOrder);
    return *this;
}

// Appends a 32-bit unsigned integer converted to network byte order with
// htonl. Returns *this so writes can be chained.
OutPacket& OutPacket::writeU32(uint32 u) {
    uint32 networkOrder = htonl(u);
    buffer.add(networkOrder);
    return *this;
}

// Appends a signed byte in offset-binary form: the value plus 128, so -128
// becomes 0 and 127 becomes 255. InPacket::read(int8&) reverses this.
// Returns *this so writes can be chained.
OutPacket& OutPacket::writeS8(int8 s) {
    // Add in int so the sum is exact; [-128, 127] + 128 fits in [0, 255].
    return writeU8(static_cast<uint8>(static_cast<int>(s) + 128));
}

// Appends a signed 16-bit value in offset-binary form (value plus 32768).
// InPacket::read(int16&) reverses this. Returns *this so writes can be chained.
OutPacket& OutPacket::writeS16(int16 s) {
    // `long` is at least 32 bits: [-32768, 32767] + 32768 fits in [0, 65535].
    return writeU16(static_cast<uint16>(static_cast<long>(s) + 32768L));
}

// Appends a signed 32-bit value in offset-binary form (value plus 2^31).
// InPacket::read(int32&) reverses this. Returns *this so writes can be chained.
OutPacket& OutPacket::writeS32(int32 s) {
    // `long long` is at least 64 bits, so the sum is exact and fits in
    // [0, 2^32 - 1].
    return writeU32(static_cast<uint32>(static_cast<long long>(s) + 2147483648LL));
}

// Appends a float as the raw bit pattern of a uint32, which writeU32 then puts
// into network byte order. InPacket::read(float&) reverses this. Returns
// *this so writes can be chained.
OutPacket& OutPacket::writeFloat(float f) {
    uint32 bits;
    memcpy(&bits, &f, sizeof f);
    return writeU32(bits);
}