#include <utility>

#include "network/Packet.h"

// Wraps an ENet packet for sequential reading. Only the pointer is stored
// (no copy is made here) and the read position starts at offset 0.
InPacket::InPacket(ENetPacket* packet) : packet(packet), index(0) {
}

// Copies the next `length` bytes of the packet payload into `buffer` and
// advances the read position past them.
//
// Returns true on FAILURE (fewer than `length` bytes remain; nothing is
// copied and the read position is unchanged) and false on success.
bool InPacket::read(void* buffer, unsigned int length) {
    // Compare against the number of remaining bytes instead of computing
    // index + length: that sum can wrap around for a very large length and
    // would let an oversized copy slip past the check.
    if(index > packet->dataLength || length > packet->dataLength - index) {
        return true;
    }
    memcpy(buffer, packet->data + index, length);
    index += length;
    return false;
}

// Reads a single unsigned byte into `u`.
// Returns true on failure (not enough data left), false on success.
bool InPacket::readU8(uint8& u) {
    // A single byte has no byte order, so it is copied as-is.
    const bool failed = read(&u, sizeof(uint8));
    return failed;
}

// Reads a 16-bit unsigned integer stored in network byte order and converts
// it to host byte order. Returns true on failure (in which case `u` is left
// untouched), false on success.
bool InPacket::readU16(uint16& u) {
    uint16 networkValue;
    const bool failed = read(&networkValue, sizeof(networkValue));
    if(!failed) {
        u = ntohs(networkValue);
    }
    return failed;
}

// Reads a 32-bit unsigned integer stored in network byte order and converts
// it to host byte order. Returns true on failure (in which case `u` is left
// untouched), false on success.
bool InPacket::readU32(uint32& u) {
    uint32 networkValue;
    const bool failed = read(&networkValue, sizeof(networkValue));
    if(!failed) {
        u = ntohl(networkValue);
    }
    return failed;
}

// Reads a signed 8-bit integer that was written with a +128 offset
// (wire value 0 means -128, wire value 255 means 127).
// Returns true on failure (`s` untouched), false on success.
bool InPacket::readS8(int8& s) {
    uint8 biased;
    if(readU8(biased)) {
        return true;
    }
    // Remove the offset in int arithmetic; the result always fits in int8.
    s = static_cast<int8>(static_cast<int>(biased) - 128);
    return false;
}

// Reads a signed 16-bit integer that was written with a +32768 offset
// (wire value 0 means -32768, wire value 65535 means 32767).
// Returns true on failure (`s` untouched), false on success.
bool InPacket::readS16(int16& s) {
    uint16 biased;
    if(readU16(biased)) {
        return true;
    }
    // Remove the offset in int arithmetic; the result always fits in int16.
    s = static_cast<int16>(static_cast<int>(biased) - 32768);
    return false;
}

// Reads a signed 32-bit integer that was written with a +2147483648 offset
// (wire value 0 means the smallest int32, 4294967295 the largest).
// Returns true on failure (`s` untouched), false on success.
bool InPacket::readS32(int32& s) {
    uint32 biased;
    if(readU32(biased)) {
        return true;
    }
    // The subtraction is done in 64-bit arithmetic so it cannot overflow;
    // the result always fits in int32.
    s = static_cast<int32>(static_cast<long long>(biased) - 2147483648LL);
    return false;
}

// Creates an outgoing packet backed by a freshly created ENet packet of
// `size` bytes with the given ENet `flags`. No initial data is supplied
// (nullptr is passed as the data pointer); the write position starts at 0.
// `channel` is stored alongside the packet.
// NOTE(review): enet_packet_create may return nullptr on failure; write()
// checks for a null packet before touching it.
OutPacket::OutPacket(unsigned int size, int flags, int channel)
    : packet(enet_packet_create(nullptr, size, flags)), index(0),
      channel(channel) {
}

// Builds a packet of `size` bytes flagged ENET_PACKET_FLAG_RELIABLE and
// assigned to channel 0.
OutPacket OutPacket::reliable(unsigned int size) {
    const int reliableFlags = ENET_PACKET_FLAG_RELIABLE;
    const int reliableChannel = 0;
    return OutPacket(size, reliableFlags, reliableChannel);
}

// Builds a packet of `size` bytes with no ENet flags set, assigned to
// channel 1.
// NOTE(review): presumably this selects ENet's default unreliable but
// sequenced delivery - confirm against the ENet documentation.
OutPacket OutPacket::sequenced(unsigned int size) {
    const int noFlags = 0;
    const int sequencedChannel = 1;
    return OutPacket(size, noFlags, sequencedChannel);
}

// Builds a packet of `size` bytes flagged ENET_PACKET_FLAG_UNSEQUENCED and
// assigned to channel 2.
OutPacket OutPacket::unsequenced(unsigned int size) {
    const int unsequencedFlags = ENET_PACKET_FLAG_UNSEQUENCED;
    const int unsequencedChannel = 2;
    return OutPacket(size, unsequencedFlags, unsequencedChannel);
}

// Destroys the wrapped ENet packet.
// NOTE(review): a moved-from OutPacket holds a null packet pointer, so this
// relies on enet_packet_destroy accepting nullptr - confirm for the ENet
// version in use.
OutPacket::~OutPacket() {
    enet_packet_destroy(packet);
}

// Copy constructor: makes an independent copy of the underlying ENet packet
// and carries over both the write position and the channel.
//
// The channel was previously not copied at all, so a copied packet lost the
// channel it was created for.
OutPacket::OutPacket(const OutPacket& other)
    // A moved-from source holds no packet; copy that state as "no packet"
    // instead of handing nullptr to enet_packet_copy.
    : packet(other.packet != nullptr ? enet_packet_copy(other.packet)
                                     : nullptr),
      index(other.index), channel(other.channel) {
}

// Move constructor: takes over the ENet packet, write position and channel
// of `other`. Afterwards `other` holds no packet (nullptr) and a write
// position of 0, so its destructor will not free the transferred packet.
//
// The channel was previously left uninitialized here, so a moved-to packet
// did not keep the channel of its source.
OutPacket::OutPacket(OutPacket&& other)
    : packet(nullptr), index(0), channel(other.channel) {
    std::swap(packet, other.packet);
    std::swap(index, other.index);
}

// Assignment via copy-and-swap: `other` is taken by value (copy- or
// move-constructed by the caller), its state is exchanged with this object,
// and the previous state of this object is released when `other` is
// destroyed at the end of the function.
//
// The channel is exchanged as well; previously it was skipped, so the
// assigned-to packet kept its old channel.
OutPacket& OutPacket::operator=(OutPacket other) {
    std::swap(packet, other.packet);
    std::swap(index, other.index);
    std::swap(channel, other.channel);
    return *this;
}

// Appends `length` bytes from `buffer` to the packet payload at the current
// write position and advances the position.
//
// The write is silently dropped (nothing is copied, position unchanged) when
// there is no packet or when the data would not fit into the remaining space.
void OutPacket::write(const void* buffer, unsigned int length) {
    if(packet == nullptr) {
        return;
    }
    // Compare against the remaining space instead of computing
    // index + length: that sum can wrap around for a very large length and
    // would let an oversized copy slip past the check.
    if(index > packet->dataLength || length > packet->dataLength - index) {
        return;
    }
    memcpy(packet->data + index, buffer, length);
    index += length;
}

// Appends a single unsigned byte to the packet.
void OutPacket::writeU8(uint8 u) {
    // A single byte has no byte order, so it is written as-is.
    write(&u, sizeof(uint8));
}

// Appends a 16-bit unsigned integer to the packet in network byte order.
void OutPacket::writeU16(uint16 u) {
    const uint16 networkValue = htons(u);
    write(&networkValue, sizeof(networkValue));
}

// Appends a 32-bit unsigned integer to the packet in network byte order.
void OutPacket::writeU32(uint32 u) {
    const uint32 networkValue = htonl(u);
    write(&networkValue, sizeof(networkValue));
}

// Appends a signed 8-bit integer, encoded as an unsigned byte with a +128
// offset (-128 is written as 0, 127 as 255).
void OutPacket::writeS8(int8 s) {
    // The offset is added in int arithmetic; the result is always 0..255.
    const int biased = static_cast<int>(s) + 128;
    writeU8(static_cast<uint8>(biased));
}

// Appends a signed 16-bit integer, encoded as an unsigned value with a
// +32768 offset (-32768 is written as 0, 32767 as 65535).
void OutPacket::writeS16(int16 s) {
    // The offset is added in int arithmetic; the result is always 0..65535.
    const int biased = static_cast<int>(s) + 32768;
    writeU16(static_cast<uint16>(biased));
}

// Appends a signed 32-bit integer, encoded as an unsigned value with a
// +2147483648 offset (the smallest int32 is written as 0, the largest as
// 4294967295).
void OutPacket::writeS32(int32 s) {
    // The offset is added in 64-bit arithmetic so it cannot overflow; the
    // result is always within the uint32 range.
    const long long biased = static_cast<long long>(s) + 2147483648LL;
    writeU32(static_cast<uint32>(biased));
}