#include <cstring>
#include <iostream>
#include <thread>

#include <errno.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <poll.h>

#include "DataVector.h"

/**
 * Creates a buffer of `capacity` bytes with every byte set to zero.
 */
DataVector::DataVector(size_t capacity)
    : capacity(capacity),
      data(new char[capacity]()) {
    // The trailing "()" value-initialises the array, which zero-fills it,
    // so no separate memset is required.
}

/**
 * Deep-copies another buffer: allocates `orig.capacity` bytes and copies the
 * source contents into them.
 *
 * The allocation size is read from `orig.capacity` rather than the member so
 * the result does not depend on the order in which `capacity` and `data` are
 * declared in the class (members are initialised in declaration order, not in
 * mem-initializer order).
 */
DataVector::DataVector(const DataVector& orig) : capacity(orig.capacity), data(new char[orig.capacity]) {
    // A moved-from source has data == nullptr and capacity == 0; passing a null
    // pointer to memcpy is undefined even for a zero length, so skip the copy.
    if(capacity > 0) {
        memcpy(data, orig.data, capacity);
    }
}

/**
 * Takes ownership of `other`'s buffer. Afterwards `other` is empty
 * (capacity 0, data nullptr) and safe to destroy or assign to.
 */
DataVector::DataVector(DataVector&& other) : capacity(0), data(nullptr) {
    data = other.data;
    capacity = other.capacity;

    // Detach the source so its destructor does not free the buffer we now own.
    other.data = nullptr;
    other.capacity = 0;
}

/**
 * Frees the owned buffer. A moved-from object holds nullptr, for which
 * delete[] is a no-op.
 */
DataVector::~DataVector() {
    char* const owned = data;
    delete[] owned;
}

/**
 * Copy-assigns: replaces this buffer with a deep copy of `other`.
 *
 * Strong exception guarantee: the new buffer is allocated and filled before
 * the old one is released. If `new` throws std::bad_alloc, *this is left
 * untouched, instead of holding a dangling pointer that the destructor would
 * delete a second time.
 */
DataVector& DataVector::operator=(const DataVector& other) {
    if(this != &other) {
        char* newData = new char[other.capacity];
        // A moved-from source has data == nullptr; memcpy with a null pointer
        // is undefined even for zero bytes.
        if(other.capacity > 0) {
            memcpy(newData, other.data, other.capacity);
        }
        delete[] data;
        data = newData;
        capacity = other.capacity;
    }
    return *this;
}

/**
 * Move-assigns: frees the current buffer and takes over `other`'s buffer,
 * leaving `other` empty (capacity 0, data nullptr). Self-move is a no-op.
 */
DataVector& DataVector::operator=(DataVector&& other) {
    if(this == &other) {
        return *this;
    }

    delete[] data;

    char* const stolenData = other.data;
    const size_t stolenCapacity = other.capacity;
    other.data = nullptr;
    other.capacity = 0;

    data = stolenData;
    capacity = stolenCapacity;
    return *this;
}

/**
 * Copies `length` bytes starting at `fromIndex` out of the buffer.
 *
 * @param fromIndex first byte to read
 * @param buffer    destination; must have room for `length` bytes
 * @param length    number of bytes to copy
 * @return true if [fromIndex, fromIndex + length) lies inside the buffer and
 *         was copied; false (nothing copied) otherwise.
 */
bool DataVector::read(size_t fromIndex, void* buffer, size_t length) const {
    // Two comparisons instead of `fromIndex + length > capacity` so the sum
    // cannot wrap around and slip past the check. A range ending exactly at
    // `capacity` is valid; the old `>=` test wrongly rejected the last byte.
    if(length > capacity || fromIndex > capacity - length) {
        return false;
    }
    // Skip memcpy for empty ranges: `data` may be nullptr in a moved-from object.
    if(length > 0) {
        memcpy(buffer, data + fromIndex, length);
    }
    return true;
}

/**
 * Copies `length` bytes from `writeData` into the buffer starting at `toIndex`.
 *
 * @param toIndex   first byte of the buffer to overwrite
 * @param writeData source bytes; must hold at least `length` bytes
 * @param length    number of bytes to copy
 * @return true if [toIndex, toIndex + length) lies inside the buffer and was
 *         written; false (buffer unchanged) otherwise.
 */
bool DataVector::write(size_t toIndex, const void* writeData, size_t length) {
    // Overflow-safe bounds check (no `toIndex + length` sum that could wrap).
    // Writing up to and including the last byte is allowed, consistent with
    // readSocket, which may fill the whole buffer.
    if(length > capacity || toIndex > capacity - length) {
        return false;
    }
    // Skip memcpy for empty ranges: `data` may be nullptr in a moved-from object.
    if(length > 0) {
        memcpy(data + toIndex, writeData, length);
    }
    return true;
}

/**
 * Reads one length-prefixed packet from `socket` into the start of the buffer.
 *
 * Wire format: a 4-byte unsigned length in network byte order, followed by
 * that many payload bytes.
 *
 * The header is read with a blocking recv(). The payload is read with
 * non-blocking recv() calls. Whenever no data is available, the function
 * waits up to 3 s (per wait) for more.
 *
 * @param socket    connected socket descriptor
 * @param readBytes out: number of payload bytes stored in the buffer
 * @return false if the header could not be read completely, the announced
 *         length exceeds the capacity, or the sender stalled for more than 3 s
 *         mid-payload (readBytes is then 0); true otherwise.
 *         NOTE(review): true is also returned when the peer closes the
 *         connection or recv() fails mid-payload. In that case readBytes may
 *         be smaller than the announced length, so callers must check it.
 *         Confirm that this is intended.
 */
bool DataVector::readSocket(int socket, size_t& readBytes) {
    constexpr int kPayloadWaitMs = 3000;

    readBytes = 0;

    // Read the 4-byte header completely. On a stream socket a single recv()
    // may legally return fewer bytes than requested. Giving up after a short
    // read would leave the remaining header bytes in the stream and
    // desynchronize every subsequent packet.
    uint32_t packetSize = 0;
    char* const header = reinterpret_cast<char*>(&packetSize);
    size_t headerRead = 0;
    while(headerRead < sizeof(packetSize)) {
        ssize_t got = recv(socket, header + headerRead, sizeof(packetSize) - headerRead, 0);
        if(got < 0 && errno == EINTR) {
            continue; // interrupted by a signal; just retry
        }
        if(got <= 0) {
            return false; // peer closed or hard error before a full header
        }
        headerRead += static_cast<size_t>(got);
    }

    packetSize = ntohl(packetSize);
    if(packetSize > capacity) {
        return false;
    }

    size_t bytesLeft = packetSize;
    while(bytesLeft > 0) {
        ssize_t chunk = recv(socket, data + readBytes, bytesLeft, MSG_DONTWAIT);
        if(chunk > 0) {
            readBytes += static_cast<size_t>(chunk);
            bytesLeft -= static_cast<size_t>(chunk);
            continue;
        }
        if(chunk == 0) {
            // socket closed / shutdown by the peer
            return true;
        }
        if(errno == EINTR) {
            continue;
        }
        if(errno == EAGAIN || errno == EWOULDBLOCK) {
            struct pollfd fds;
            fds.fd = socket;
            fds.events = POLLIN;
            fds.revents = 0;
            int ready;
            do {
                ready = poll(&fds, 1, kPayloadWaitMs);
            } while(ready < 0 && errno == EINTR);
            if(ready <= 0) {
                // client took too long to send the full packet (or poll failed)
                readBytes = 0;
                return false;
            }
            continue;
        }
        // a real error occurred
        perror("cannot receive data");
        return true;
    }
    // packet fully read (also reached immediately for a zero-length packet)
    return true;
}

void DataVector::sendToSocket(int socket, size_t toIndex) const {
    size_t bufferOffset = 0;
    size_t sendLength = toIndex;

    while(sendLength > 0) {
        ssize_t writtenLength = send(socket, data + bufferOffset, sendLength, MSG_NOSIGNAL);
        if(writtenLength == -1) {
            perror("cannot send data");
            return;
        }
        sendLength -= writtenLength;
        bufferOffset += writtenLength;
    }
}