#include <atomic>
#include <cerrno>
#include <cstring>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

#include <sys/socket.h>
#include <netinet/in.h>
#include <unistd.h>
#include <poll.h>

#include "server/network/Server.h"

/**
 * Prints `message`, a colon and the textual description of the current errno
 * value to standard output, terminated by a newline.
 *
 * @param message short description of the operation that failed
 */
static void printError(const char* message) {
    // Snapshot errno before any stream I/O: the writes below may themselves
    // modify errno, which would make strerror() describe the wrong failure.
    const int errorCode = errno;
    std::cout << message << ": " << strerror(errorCode) << "\n";
}

/**
 * One client slot of the server: the socket of the connection (or -1 when the
 * slot is free) together with the thread that services it.
 */
struct ConnectedClient {

    /**
     * Waits for the worker thread and then releases the socket.
     *
     * The thread is joined BEFORE the descriptor is touched. The worker polls
     * and reads `socket` and may close it and reset it to -1 on its own, so
     * closing it from here while the worker is still running could operate on
     * a descriptor that is in use or close it twice. After the join, `socket`
     * is stable: either -1 or a descriptor this object still owns.
     *
     * The worker only returns once the connection ends or the server's run
     * flag has been cleared, so the owner must request shutdown before
     * destroying its clients.
     */
    ~ConnectedClient() {
        if(th.joinable()) {
            th.join();
        } else {
            std::cout << "cannot join client connection thread\n";
        }
        if(socket != -1) {
            if(shutdown(socket, SHUT_RDWR) == -1) {
                printError("cannot shutdown client socket");
            }
            if(close(socket) == -1) {
                printError("cannot close client socket");
            }
            socket = -1;
        }
    }

    // Starts out as an already finished placeholder thread so that the slot
    // is always joinable before a real worker is assigned to it.
    std::thread th = std::thread([]() {
    });
    // Descriptor of the connected client, -1 marks the slot as unused.
    int socket = -1;
};

/**
 * Owns the process-wide server state: the listening socket, the thread that
 * accepts connections, the run flag polled by all threads and the array of
 * client slots.
 */
struct InternServer final {
    /**
     * Creates an idle server. The listener thread starts as an already
     * finished placeholder so that it can be joined unconditionally before a
     * real listener is assigned.
     */
    InternServer()
        : listenerSocket(-1),
          listenerThread([]() {
          }),
          shouldRun(true),
          clients(nullptr) {
    }

    // The object owns a socket, threads and a raw array - it must not be copied.
    InternServer(const InternServer&) = delete;
    InternServer& operator=(const InternServer&) = delete;

    /**
     * Stops all threads and releases every resource.
     *
     * Clearing the run flag first lets the listener and the client threads
     * leave their poll loops; the clients are destroyed last because their
     * destructors wait for the client threads.
     */
    ~InternServer() {
        shouldRun = false;
        // join() on a non-joinable thread throws std::system_error, which
        // would terminate the process when it escapes a destructor
        if(listenerThread.joinable()) {
            listenerThread.join();
        }

        if(listenerSocket != -1 && close(listenerSocket) == -1) {
            printError("cannot close listener socket");
        }
        // delete[] on a null pointer is a no-op, no check required
        delete[] clients;
    }

    // Descriptor of the listening socket, -1 while no socket is open.
    int listenerSocket;
    // Thread accepting incoming connections.
    std::thread listenerThread;
    // Cleared to ask the listener and all client threads to stop.
    std::atomic_bool shouldRun;
    // Array of client slots allocated with new[], nullptr until allocated.
    ConnectedClient* clients;
};

// The single server instance of this file. Its destructor clears the run
// flag, joins the listener thread and releases the socket and client slots.
static InternServer server;

/**
 * Placeholder callback used until a handler for connections to a full server
 * is installed; it only reports that it was called.
 */
static void defaultFullServerClientConnect(int /*clientSocket*/) {
    static const char notice[] = "default onFullServerClientConnectFunction\n";
    std::cout << notice;
}

/**
 * Placeholder callback used until a client-connect handler is installed;
 * it only reports that it was called.
 */
static void defaultClientConnect(int /*clientSocket*/) {
    static const char notice[] = "default onClientConnectFunction\n";
    std::cout << notice;
}

/**
 * Placeholder callback used until a package handler is installed; the
 * received data is ignored and only a notice is printed.
 */
static void defaultClientPackage(int /*clientSocket*/, Stream& /*data*/) {
    static const char notice[] = "default onClientPackageFunction\n";
    std::cout << notice;
}

/**
 * Placeholder callback used until a client-disconnect handler is installed;
 * it only reports that it was called.
 */
static void defaultClientDisconnect(int /*clientSocket*/) {
    static const char notice[] = "default onClientDisconnectFunction\n";
    std::cout << notice;
}

// Held by addClient while a slot is claimed and by listenOnClient while a
// slot is released; protects clientAmount and the socket field of the slots.
static std::mutex clientMutex;
// Number of currently occupied client slots.
static u16 clientAmount = 0;
// Number of client slots; assigned in Server::start.
static u16 maxClients = 0;
// Called by addClient when a connection arrives while all slots are occupied.
static Server::FullServerClientConnectFunction onFullServerClientConnect = defaultFullServerClientConnect;
// Called by listenOnClient once before it starts polling the client socket.
static Server::ClientConnectFunction onClientConnect = defaultClientConnect;
// Called by listenOnClient whenever data was read from the client socket.
static Server::ClientPackageFunction onClientPackage = defaultClientPackage;
// Called by listenOnClient after its poll loop has ended.
static Server::ClientDisconnectFunction onClientDisconnect = defaultClientDisconnect;

// Installs the callback that is invoked with the socket of a connection that
// arrives while all client slots are occupied.
// NOTE(review): the assignment is not synchronized with the listener thread
// that reads this callback - presumably meant to be called before
// Server::start; confirm.
void Server::setFullServerClientConnectFunction(Server::FullServerClientConnectFunction f) {
    onFullServerClientConnect = f;
}

// Installs the callback that is invoked with the client socket when the
// thread of a newly accepted client starts.
// NOTE(review): the assignment is not synchronized with the client threads
// that read this callback - presumably meant to be called before
// Server::start; confirm.
void Server::setClientConnectFunction(Server::ClientConnectFunction f) {
    onClientConnect = f;
}

// Installs the callback that is invoked with the client socket and the
// stream holding the data that was read from it.
// NOTE(review): the assignment is not synchronized with the client threads
// that read this callback - presumably meant to be called before
// Server::start; confirm.
void Server::setClientPackageFunction(Server::ClientPackageFunction f) {
    onClientPackage = f;
}

// Installs the callback that is invoked with the client socket after the
// poll loop of a client thread has ended.
// NOTE(review): the assignment is not synchronized with the client threads
// that read this callback - presumably meant to be called before
// Server::start; confirm.
void Server::setClientDisconnectFunction(Server::ClientDisconnectFunction f) {
    onClientDisconnect = f;
}

/**
 * Thread body servicing one connected client.
 *
 * Reports the connection, then polls the client socket until the server is
 * asked to stop, the client closes the connection or polling fails. Every
 * chunk of received data is handed to the package callback. Afterwards the
 * disconnect callback is invoked and - if the server is still running - the
 * socket is closed and the slot is marked as free again.
 *
 * @param cc slot of the client; cc.socket must be a connected socket
 */
static void listenOnClient(ConnectedClient& cc) {
    // poll data
    struct pollfd fds;
    fds.fd = cc.socket; // file descriptor for polling
    fds.events = POLLIN; // wait until data is ready to read
    fds.revents = 0; // return events - none

    onClientConnect(cc.socket);

    Stream st;
    while(server.shouldRun) {
        // nfds_t - 1 - amount of passed in structs
        // timeout - 100 - milliseconds to wait until an event occurs
        // returns 0 on timeout, -1 on error, and >0 on success
        int pollData = poll(&fds, 1, 100);
        if(pollData > 0) {
            st.readSocket(cc.socket);
            if(!st.hasData()) {
                // client closed connection
                break;
            }
            onClientPackage(cc.socket, st);
        } else if(pollData == -1) {
            // a signal interrupting poll is not a connection failure,
            // dropping a healthy client because of it would be wrong
            if(errno == EINTR) {
                continue;
            }
            printError("cannot poll from client");
            break;
        }
    }

    onClientDisconnect(cc.socket);

    // reset slot for another client; on server shutdown the socket stays
    // open here and is left to the owner of the slot
    if(server.shouldRun) {
        std::lock_guard<std::mutex> lg(clientMutex);
        if(close(cc.socket) == -1) {
            printError("cannot close socket of client");
        }
        cc.socket = -1;
        clientAmount--;
    }
}

static bool addClient(int clientSocket) {
    std::lock_guard<std::mutex> lg(clientMutex);
    if(clientAmount >= maxClients) {
        onFullServerClientConnect(clientSocket);
        return true;
    } else {
        // search for free slot
        uint16_t index = 0;
        while(index < maxClients) {
            if(server.clients[index].socket == -1) {
                break;
            }
            index++;
        }

        if(index >= maxClients) {
            std::cout << "cannot find free slot - even if there should be one\n";
            return true;
        }

        //ensure old thread has ended
        if(!server.clients[index].th.joinable()) {
            std::cout << "cannot join thread of non used client connection\n";
            return true;
        }
        server.clients[index].th.join();

        server.clients[index].socket = clientSocket;
        server.clients[index].th = std::thread(listenOnClient, std::ref(server.clients[index]));

        clientAmount++;

        return false;
    }
}

/**
 * Tells whether a failed accept call may simply be retried: the call was
 * interrupted by a signal or the pending connection vanished before it could
 * be accepted.
 */
static bool isTransientAcceptError(int errorCode) {
    return errorCode == EINTR || errorCode == ECONNABORTED || errorCode == EAGAIN || errorCode == EWOULDBLOCK;
}

/**
 * Thread body of the listener: accepts incoming connections until the server
 * is asked to stop or a non-recoverable error occurs.
 *
 * Every accepted connection is passed to addClient; connections it rejects
 * are closed right away. Transient failures (interrupted system call, client
 * aborting before accept) no longer end the listener, because leaving the
 * loop means that the server never accepts another client.
 */
static void listenForClients() {
    while(server.shouldRun) {
        // wait until a connection arrives with timeout, this prevents being
        // stuck in accept
        struct pollfd fds;
        fds.fd = server.listenerSocket; // file descriptor for polling
        fds.events = POLLIN; // wait until data is ready to read
        fds.revents = 0; // return events - none
        // nfds_t - 1 - amount of passed in structs
        // timeout - 100 - milliseconds to wait until an event occurs
        // returns 0 on timeout, -1 on error, and >0 on success
        int pollData = poll(&fds, 1, 100);
        if(pollData > 0) {
            struct sockaddr_in clientSocketData;
            // accepts an incoming client connection and stores the data in the
            // given struct, returns a nonnegative file descriptor on success
            socklen_t addrlen = sizeof (struct sockaddr_in);
            int clientSocket = accept(server.listenerSocket, reinterpret_cast<struct sockaddr*>(&clientSocketData), &addrlen);
            if(clientSocket >= 0) {
                if(addClient(clientSocket)) {
                    if(close(clientSocket) == -1) {
                        printError("cannot close client socket");
                    }
                }
            } else {
                // remember the reason before printing can change errno
                const int acceptError = errno;
                printError("accept error");
                if(!isTransientAcceptError(acceptError)) {
                    break;
                }
            }
        } else if(pollData == -1) {
            // interrupted by a signal - just poll again
            if(errno == EINTR) {
                continue;
            }
            printError("poll error");
            break;
        }
    }
}

/**
 * Opens the listening socket on the given port, allocates the client slots
 * and starts the listener thread.
 *
 * If any step of the socket setup fails, the error is printed, the listener
 * socket is closed again and false is returned, so a failed start does not
 * leave an open descriptor behind.
 *
 * NOTE(review): calling this while a listener thread is still running waits
 * in join() until that thread ends, and the previous client array is not
 * released - looks like start is expected to be called once; confirm.
 *
 * @param port TCP port to listen on (host byte order)
 * @param inMaxClients number of client slots to allocate
 * @return true if the server is listening, false on any setup error
 */
bool Server::start(u16 port, u16 inMaxClients) {
    // reports the failure, releases the listener socket and yields the value
    // to return; the message is printed first because close may change errno
    const auto abortStart = [](const char* message) {
        printError(message);
        if(server.listenerSocket != -1) {
            if(close(server.listenerSocket) == -1) {
                printError("cannot close listener socket");
            }
            server.listenerSocket = -1;
        }
        return false;
    };

    // create socket for incoming connections
    // domain - AF_INET - IPv4 Internet protocols
    // type - SOCK_STREAM - two-way, connection-based byte streams
    // protocol - 0 - use standard protocol for the given socket type
    server.listenerSocket = socket(AF_INET, SOCK_STREAM, 0);
    if(server.listenerSocket == -1) {
        return abortStart("cannot create listener socket");
    }

    // prevents clients from blocking the port if the server exits
    // this is useful if server and client run on the same system
    struct linger sl;
    sl.l_onoff = 1; // nonzero to linger on close
    sl.l_linger = 0; // time to linger
    // sockfd - listenerSocket - modified socket
    // level - SOL_SOCKET - manipulate options at the sockets API level
    // optname - SO_LINGER - identifier of the option
    if(setsockopt(server.listenerSocket, SOL_SOCKET, SO_LINGER, &sl, sizeof (struct linger)) == -1) {
        return abortStart("cannot set non lingering");
    }

    // specify binding data
    struct sockaddr_in connectSocketData;
    // clear padding
    memset(&connectSocketData, 0, sizeof (struct sockaddr_in));
    // IPv4 Internet protocols
    connectSocketData.sin_family = AF_INET;
    // port in network byte order
    connectSocketData.sin_port = htons(port);
    // address in network byte order, accept any incoming messages
    // s_addr is a 32 bit value and therefore needs htonl, not htons
    connectSocketData.sin_addr.s_addr = htonl(INADDR_ANY);
    // bind the socket
    if(bind(server.listenerSocket, reinterpret_cast<struct sockaddr*>(&connectSocketData), sizeof (struct sockaddr_in)) == -1) {
        return abortStart("cannot bind listener socket");
    }

    // mark the socket as handler for connection requests
    // backlog - 5 - max queue length of pending connections
    if(listen(server.listenerSocket, 5) == -1) {
        return abortStart("cannot start listening on socket");
    }

    server.shouldRun = true;

    maxClients = inMaxClients;
    server.clients = new ConnectedClient[inMaxClients];
    // join empty spawn thread; a thread object must not be overwritten while
    // it is joinable, and join on a non-joinable thread would throw
    if(server.listenerThread.joinable()) {
        server.listenerThread.join();
    }
    server.listenerThread = std::thread(listenForClients);
    return true;
}