"""Base for MoreThanTest test cases."""

from asyncio import create_subprocess_exec, gather, open_connection, sleep
from pathlib import Path
from socket import socket
from unittest import IsolatedAsyncioTestCase

from aiohttp import ClientSession


# Loopback address that get_port binds to when asking the OS for a free port.
LOCALHOST = "127.0.0.1"
# Cookie name that run_tests copies from one server's response onto the
# requests sent to the following servers.
SESSION_KEY = "sessionid"
# NOTE(review): not referenced in this chunk; presumably used by test modules
# built on this base (e.g. as a Host header value) — confirm.
HOST = "example.com"


class Server:
    """Launch and stop one ``morethantext`` server subprocess."""

    def __init__(self, *args):
        """Build the command line and base URL for one server.

        Args:
            *args: Command-line flags passed to the release build of
                ``morethantext``. ``-a``/``--address`` and ``-p``/``--port``
                (also in ``--address=VALUE`` / ``--port=VALUE`` form) are also
                read here to compute the URL the server will listen on; they
                default to 127.0.0.1 and 3000.
        """
        app = Path.cwd().joinpath("target", "release", "morethantext")
        self.cmd = [app, *args]
        self.addr = "127.0.0.1"
        self.port = 3000
        flags = iter(args)
        for item in flags:
            option, sep, value = str(item).partition("=")
            # An option without "=VALUE" takes its value from the next flag;
            # a trailing option with no value keeps the default.
            if option in ("-a", "--address"):
                self.addr = value if sep else next(flags, self.addr)
            elif option in ("-p", "--port"):
                self.port = value if sep else next(flags, self.port)
        self.server = None  # asyncio Process once create() has run
        self.host = f"http://{self.addr}:{self.port}"

    async def create(self, *, timeout=10.0, interval=0.1):
        """Start the server process and wait until it accepts connections.

        The port is polled every ``interval`` seconds instead of sleeping for
        a fixed time. Tests therefore neither race a slow start-up nor wait
        longer than needed.

        Args:
            timeout: Approximate number of seconds to wait for the port to open.
            interval: Seconds to sleep between connection attempts.

        Raises:
            RuntimeError: The process exited before it started listening.
            TimeoutError: The port did not open within ``timeout`` seconds.
        """
        self.server = await create_subprocess_exec(*self.cmd)
        for _ in range(max(1, int(timeout / interval))):
            if self.server.returncode is not None:
                await self.destroy()
                raise RuntimeError(
                    f"server exited with code {self.server.returncode} "
                    f"before listening on {self.host}"
                )
            try:
                _, writer = await open_connection(self.addr, self.port)
            except OSError:
                # Not listening yet (e.g. connection refused); retry shortly.
                await sleep(interval)
                continue
            writer.close()
            try:
                await writer.wait_closed()
            except OSError:
                pass  # the probe connection is disposable
            return
        # Do not leave an unresponsive process running after giving up.
        await self.destroy()
        raise TimeoutError(f"server did not start listening on {self.host}")

    async def destroy(self):
        """Terminate the server process (if still running) and reap it.

        Safe to call when create() was never run or the process already exited.
        """
        if self.server is None:
            return
        if self.server.returncode is None:
            try:
                self.server.terminate()
            except ProcessLookupError:
                pass  # exited between the returncode check and the signal
        await self.server.wait()


class MTTClusterTC(IsolatedAsyncioTestCase):
    """Test case for MoreThanText.

    Provides helpers to start one or more ``morethantext`` servers on free
    local ports and a shared ``aiohttp`` client session. ``run_tests`` sends
    the same request to every started server and forwards the session cookie
    from one server to the next.
    """

    async def asyncSetUp(self):
        """Reset per-test state and open the shared HTTP client session."""
        self.servers = []  # started servers, in the order their start finished
        self.cookies = {}  # cookies sent with every request made by run_tests
        self.session = ClientSession()

    async def asyncTearDown(self):
        """Close the HTTP session and stop every server started by the test.

        Every server is destroyed even if closing the session or destroying
        an earlier server fails, so no subprocess is left running. The first
        server error is re-raised once all servers have been handled.
        """
        try:
            await self.session.close()
        finally:
            errors = []
            for server in self.servers:
                try:
                    await server.destroy()
                except Exception as err:  # pylint: disable=broad-except
                    # Keep tearing down the remaining servers first.
                    errors.append(err)
            if errors:
                raise errors[0]

    @staticmethod
    async def get_port():
        """Return a TCP port on LOCALHOST that was free when checked.

        The OS picks the port (bind to port 0). The socket is closed before
        returning, so another process could still take the port before a
        server binds it.
        """
        with socket() as sock:
            sock.bind((LOCALHOST, 0))
            return sock.getsockname()[1]

    async def create_server_with_flags(self, *args):
        """Start one server with the given command-line flags.

        The server is registered for teardown only after it has started.

        Returns:
            The started Server instance.
        """
        server = Server(*args)
        await server.create()
        self.servers.append(server)
        return server

    async def create_server(self):
        """Start one server on a free local port.

        Returns:
            The started Server instance.
        """
        port = await self.get_port()
        return await self.create_server_with_flags("-p", str(port))

    async def create_cluster(self, num=2):
        """Start ``num`` servers concurrently, each on a distinct free port."""
        ports = []
        while len(ports) < num:
            port = await self.get_port()
            # The OS may hand back a port it already gave us, since each
            # probe socket is closed before the next one is opened.
            if port not in ports:
                ports.append(port)
        await gather(
            *(self.create_server_with_flags("-p", str(port)) for port in ports)
        )

    async def run_tests(self, uri, func):
        """GET ``uri`` from every server in turn and pass each response to ``func``.

        A session cookie set by one server is stored and sent to all the
        following servers. ``func`` is called synchronously while the
        response is still open.
        """
        for server in self.servers:
            url = f"{server.host}{uri}"
            async with self.session.get(url, cookies=self.cookies) as response:
                if SESSION_KEY in response.cookies:
                    self.cookies[SESSION_KEY] = response.cookies[SESSION_KEY].value
                func(response)