#include <iostream>
#include <cerrno>
#include <cstring>
#include <vector>
#include <atomic>
#include <thread>
#include <mutex>

#include <sys/socket.h>
#include <netinet/in.h>
#include <unistd.h>
#include <poll.h>

#include "server/network/Server.h"

/**
 * Writes "<message>: <description of the current errno>" followed by a
 * newline to standard output.
 *
 * @param message text describing the operation that failed
 */
static void printError(const char* message)
{
    // resolve the description first, so the output below cannot disturb errno
    const char* const reason = strerror(errno);
    std::cout << message << ": " << reason << "\n";
}

/**
 * One client slot: the socket of a connected client together with the thread
 * that services it. A socket value of -1 marks the slot as unused.
 */
struct ConnectedClient
{
    /**
     * Ends the connection and waits for the service thread.
     *
     * The descriptor is released only after the thread has been joined, so
     * the thread can never operate on a descriptor that was already closed
     * (and possibly handed out again by the operating system).
     */
    ~ConnectedClient()
    {
        if(socket != -1)
        {
            // shut the connection down first so that nothing can be
            // transferred on it any more while the thread is winding down
            if(shutdown(socket, SHUT_RDWR))
            {
                printError("cannot shutdown client socket");
            }
        }
        if(th.joinable())
        {
            th.join();
        }
        else
        {
            std::cout << "cannot join client connection thread\n";
        }
        // the thread may have closed the socket and reset it to -1 itself,
        // therefore the value is checked again after the join
        if(socket != -1)
        {
            if(close(socket) == -1)
            {
                printError("cannot close client socket");
            }
            socket = -1;
        }
    }

    // starts as an already finished dummy thread, so the slot is always joinable
    std::thread th = std::thread([]() {});
    // descriptor of the connected client, -1 if the slot is free
    int socket = -1;
};

/**
 * Bundles the state of the server: the listening socket, the thread that
 * accepts connections, the run flag polled by all server threads and the
 * array of client slots.
 */
struct InternServer final
{
    /**
     * Creates a stopped server: no listener socket, no client slots and an
     * already finished dummy listener thread.
     */
    InternServer() : listenerSocket(-1), listenerThread([](){}), shouldRun(true), clients(nullptr)
    {
    }

    /**
     * Signals all server threads to stop, waits for the listener thread,
     * closes the listener socket and destroys the client slots.
     */
    ~InternServer()
    {
        shouldRun = false;
        // join() on a non-joinable thread throws, which would terminate the
        // program when it happens inside a destructor
        if(listenerThread.joinable())
        {
            listenerThread.join();
        }

        if(listenerSocket != -1)
        {
            if(close(listenerSocket) == -1)
            {
                printError("cannot close listener socket");
            }
        }
        // deleting a null pointer is a no-op, no check needed
        delete[] clients;
    }

    // the object owns a socket, a thread and the slot array - never copy it
    InternServer(const InternServer&) = delete;
    InternServer& operator=(const InternServer&) = delete;

    // descriptor of the listening socket, -1 if there is none
    int listenerSocket;
    // thread accepting incoming connections
    std::thread listenerThread;
    // loop condition of the listener thread and of all client threads
    std::atomic_bool shouldRun;
    // array of client slots allocated with new[], nullptr if not allocated
    ConnectedClient* clients;
};

// the single server instance of this file; its destructor stops the listener
// thread, closes the listener socket and destroys all client slots
static InternServer server;

/**
 * Placeholder used until a full-server callback is installed. The passed
 * value is ignored, the function only reports that it was called.
 */
static void defaultFullServerClientConnect(int)
{
    static const char notice[] = "default onFullServerClientConnectFunction\n";
    std::cout << notice;
}

/**
 * Placeholder used until a connect callback is installed. The passed value
 * is ignored, the function only reports that it was called.
 */
static void defaultClientConnect(int)
{
    static const char notice[] = "default onClientConnectFunction\n";
    std::cout << notice;
}

/**
 * Placeholder used until a package callback is installed. Both arguments
 * are ignored, the function only reports that it was called.
 */
static void defaultClientPackage(int, Stream&)
{
    static const char notice[] = "default onClientPackageFunction\n";
    std::cout << notice;
}

/**
 * Placeholder used until a disconnect callback is installed. The passed
 * value is ignored, the function only reports that it was called.
 */
static void defaultClientDisconnect(int)
{
    static const char notice[] = "default onClientDisconnectFunction\n";
    std::cout << notice;
}

// locked by addClient while a slot is assigned and by listenOnClient while a
// slot is released
static std::mutex clientMutex;
// number of occupied slots: incremented in addClient, decremented in
// listenOnClient
static u16 clientAmount = 0;
// number of client slots, set by Server::start
static u16 maxClients = 0;
// called by addClient for a connection that arrives while all slots are taken
static Server::FullServerClientConnectFunction onFullServerClientConnect = defaultFullServerClientConnect;
// called by listenOnClient before it starts polling the client socket
static Server::ClientConnectFunction onClientConnect = defaultClientConnect;
// called by listenOnClient whenever data was read from the client socket
static Server::ClientPackageFunction onClientPackage = defaultClientPackage;
// called by listenOnClient after its polling loop has ended
static Server::ClientDisconnectFunction onClientDisconnect = defaultClientDisconnect;

/**
 * Installs the callback that is invoked with the socket of a client which
 * connects while all client slots are occupied.
 *
 * The value is stored as given, without any validity check.
 * NOTE(review): the assignment is not synchronized with the server threads
 * that invoke the callback - presumably meant to be called before
 * Server::start; confirm.
 *
 * @param f the new callback
 */
void Server::setFullServerClientConnectFunction(Server::FullServerClientConnectFunction f)
{
    onFullServerClientConnect = f;
}

/**
 * Installs the callback that is invoked with the client socket when the
 * thread of a newly accepted client starts.
 *
 * The value is stored as given, without any validity check.
 * NOTE(review): the assignment is not synchronized with the server threads
 * that invoke the callback - presumably meant to be called before
 * Server::start; confirm.
 *
 * @param f the new callback
 */
void Server::setClientConnectFunction(Server::ClientConnectFunction f)
{
    onClientConnect = f;
}

/**
 * Installs the callback that is invoked with the client socket and the
 * stream holding the received data whenever data was read from a client.
 *
 * The value is stored as given, without any validity check.
 * NOTE(review): the assignment is not synchronized with the server threads
 * that invoke the callback - presumably meant to be called before
 * Server::start; confirm.
 *
 * @param f the new callback
 */
void Server::setClientPackageFunction(Server::ClientPackageFunction f)
{
    onClientPackage = f;
}

/**
 * Installs the callback that is invoked with the client socket after the
 * polling loop of a client thread has ended.
 *
 * The value is stored as given, without any validity check.
 * NOTE(review): the assignment is not synchronized with the server threads
 * that invoke the callback - presumably meant to be called before
 * Server::start; confirm.
 *
 * @param f the new callback
 */
void Server::setClientDisconnectFunction(Server::ClientDisconnectFunction f)
{
    onClientDisconnect = f;
}

/**
 * Thread function servicing a single client.
 *
 * Calls the connect callback, then polls the client socket until the server
 * stops, a read yields no data or polling fails. Every read that yields data
 * is passed to the package callback. Afterwards the disconnect callback is
 * called and, if the server is still running, the slot is released for the
 * next client.
 *
 * @param cc the slot of the serviced client, its socket must be valid
 */
static void listenOnClient(ConnectedClient& cc)
{
    // poll data
    struct pollfd fds;
    fds.fd = cc.socket; // file descriptor for polling
    fds.events = POLLIN; // wait until data is ready to read
    fds.revents = 0; // return events - none

    onClientConnect(cc.socket);

    Stream st;
    while(server.shouldRun)
    {
        // nfds_t - 1 - amount of passed in structs
        // timeout - 100 - milliseconds to wait until an event occurs, this
        //                 also bounds the reaction time to a server stop
        // returns 0 on timeout, -1 on error, and >0 on success
        int pollData = poll(&fds, 1, 100);
        if(pollData > 0)
        {
            st.readSocket(cc.socket);
            if(st.hasData())
            {
                onClientPackage(cc.socket, st);
            }
            else
            {
                // client closed connection
                break;
            }
        }
        else if(pollData == -1)
        {
            if(errno == EINTR)
            {
                // interrupted by a signal - this says nothing about the
                // connection, so the client must not be dropped
                continue;
            }
            printError("cannot poll from client");
            break;
        }
    }

    onClientDisconnect(cc.socket);

    // reset slot for another client; if the server is stopping the slot is
    // left untouched and the socket is closed by the slot's destructor
    if(server.shouldRun)
    {
        std::lock_guard<std::mutex> lg(clientMutex);
        if(close(cc.socket) == -1)
        {
            printError("cannot close socket of client");
        }
        cc.socket = -1;
        clientAmount--;
    }
}

static bool addClient(int clientSocket)
{
    std::lock_guard<std::mutex> lg(clientMutex);
    if(clientAmount >= maxClients)
    {
        onFullServerClientConnect(clientSocket);
        return true;
    }
    else
    {
        // search for free slot
        uint16_t index = 0;
        while(index < maxClients)
        {
            if(server.clients[index].socket == -1)
            {
                break;
            }
            index++;
        }
        
        if(index >= maxClients)
        {
            std::cout << "cannot find free slot - even if there should be one\n";
            return true;
        }

        //ensure old thread has ended
        if(!server.clients[index].th.joinable())
        {
            std::cout << "cannot join thread of non used client connection\n";
            return true;
        }
        server.clients[index].th.join();

        server.clients[index].socket = clientSocket;
        server.clients[index].th = std::thread(listenOnClient, std::ref(server.clients[index]));

        clientAmount++;

        return false;
    }
}

/**
 * Thread function of the listener: accepts incoming connections until the
 * server stops or an unrecoverable error occurs.
 *
 * Every accepted connection is offered to addClient; if it is not taken
 * over, the connection is closed right away. Failures that only concern a
 * single call or a single pending connection (EINTR, ECONNABORTED, EPROTO)
 * do not end the listener.
 */
static void listenForClients()
{
    while(server.shouldRun)
    {
        // wait until a connection arrives with timeout, this prevents being
        // stuck in accept
        struct pollfd fds;
        fds.fd = server.listenerSocket; // file descriptor for polling
        fds.events = POLLIN; // wait until data is ready to read
        fds.revents = 0; // return events - none
        // nfds_t - 1 - amount of passed in structs
        // timeout - 100 - milliseconds to wait until an event occurs
        // returns 0 on timeout, -1 on error, and >0 on success
        int pollData = poll(&fds, 1, 100);
        if(pollData == -1)
        {
            if(errno == EINTR)
            {
                // interrupted by a signal, simply wait again
                continue;
            }
            printError("poll error");
            break;
        }
        if(pollData == 0)
        {
            // timeout, check the run flag again
            continue;
        }

        struct sockaddr_in clientSocketData;
        // accepts an incoming client connection and stores the data in the
        // given struct, returns a nonnegative file descriptor on success
        socklen_t addrlen = sizeof(struct sockaddr_in);
        int clientSocket = accept(server.listenerSocket,
                reinterpret_cast<struct sockaddr*>(&clientSocketData), &addrlen);
        if(clientSocket < 0)
        {
            // remember the reason, printing may overwrite errno
            const int acceptError = errno;
            printError("accept error");
            if(acceptError == EINTR || acceptError == ECONNABORTED || acceptError == EPROTO)
            {
                // only this call or this pending connection failed, the
                // listener socket itself is still usable
                continue;
            }
            break;
        }

        // addClient returns true if the connection was not taken over
        if(addClient(clientSocket))
        {
            if(close(clientSocket) == -1)
            {
                printError("cannot close client socket");
            }
        }
    }
}

/**
 * Creates the listener socket, binds it to the given port on all local IPv4
 * addresses and starts the thread that accepts clients.
 *
 * If any setup step fails, the reason is printed, the listener socket is
 * closed again and the server is left without a listener socket.
 *
 * NOTE(review): there is no protection against calling this function while
 * a listener thread is already running - the join below would then wait for
 * a thread that only ends once the run flag is cleared.
 *
 * @param port port to listen on, in host byte order
 * @param inMaxClients number of client slots to allocate
 * @return true if the server is listening, false otherwise
 */
bool Server::start(u16 port, u16 inMaxClients)
{
    // create socket for incoming connections
    // domain - AF_INET - IPv4 Internet protocols
    // type - SOCK_STREAM - two-way, connection-based byte streams
    // protocol - 0 - use standard protocol for the given socket type
    server.listenerSocket = socket(AF_INET, SOCK_STREAM, 0);
    if(server.listenerSocket == -1)
    {
        printError("cannot create listener socket");
        return false;
    }

    // reports a failed setup step and releases the half configured socket;
    // the reason is printed first because close may overwrite errno
    const auto abortStart = [](const char* message)
    {
        printError(message);
        if(close(server.listenerSocket) == -1)
        {
            printError("cannot close listener socket");
        }
        server.listenerSocket = -1;
        return false;
    };

    // prevents clients from blocking the port if the server exits
    // this is useful if server and client run on the same system
    struct linger sl;
    sl.l_onoff = 1; // nonzero to linger on close
    sl.l_linger = 0; // time to linger
    // sockfd - listenerSocket - modified socket
    // level - SOL_SOCKET - manipulate options at the sockets API level
    // optname - SO_LINGER - identifier of the option
    if(setsockopt(server.listenerSocket, SOL_SOCKET, SO_LINGER, &sl, sizeof(struct linger)) == -1)
    {
        return abortStart("cannot set non lingering");
    }

    // specify binding data
    struct sockaddr_in connectSocketData;
    // clear padding
    memset(&connectSocketData, 0, sizeof(struct sockaddr_in));
    // IPv4 Internet protocols
    connectSocketData.sin_family = AF_INET;
    // port in network byte order
    connectSocketData.sin_port = htons(port);
    // address in network byte order, accept any incoming messages;
    // s_addr is a 32 bit value and therefore needs htonl, not htons
    connectSocketData.sin_addr.s_addr = htonl(INADDR_ANY);
    // bind the socket
    if(bind(server.listenerSocket, reinterpret_cast<struct sockaddr*>(&connectSocketData),
            sizeof(struct sockaddr_in)) == -1)
    {
        return abortStart("cannot bind listener socket");
    }

    // mark the socket as handler for connection requests
    // backlog - 5 - max queue length of pending connections
    if(listen(server.listenerSocket, 5) == -1)
    {
        return abortStart("cannot start listening on socket");
    }

    server.shouldRun = true;

    maxClients = inMaxClients;
    server.clients = new ConnectedClient[inMaxClients];
    // join empty spawn thread; joining a non-joinable thread would throw
    if(server.listenerThread.joinable())
    {
        server.listenerThread.join();
    }
    server.listenerThread = std::thread(listenForClients);
    return true;
}