import uuid
from datetime import datetime
from decimal import Decimal
from urllib.parse import quote

from sqlalchemy import (
    BigInteger,
    DateTime,
    Engine,
    Enum,
    Float,
    MetaData,
    NullPool,
    String,
    create_engine,
)
from sqlalchemy.ext.asyncio import AsyncAttrs, AsyncEngine, AsyncSession
from sqlalchemy.ext.asyncio import async_sessionmaker as sa_async_sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.orm import sessionmaker as sa_sync_sessionmaker
from sqlalchemy.types import TypeDecorator, CHAR, DECIMAL

from .settings import CoreSettings
from . import enums

# pylint: disable=import-outside-toplevel,unused-import,wrong-import-order
import aiomysql

# pylint: enable=import-outside-toplevel,unused-import,wrong-import-order


def get_sessionmaker(
    conn_string: str, is_async: bool = False, is_testing: bool = False
):
    """Create an engine and a matching sessionmaker for ``conn_string``.

    Args:
        conn_string: SQLAlchemy database URL.
        is_async: build the async pair (``create_async_engine`` /
            ``async_sessionmaker``) instead of the sync pair
            (``create_engine`` / ``sessionmaker``).
        is_testing: create the engine with ``NullPool``; otherwise the
            engine's default pool is used with ``pool_recycle=1800`` and
            ``pool_pre_ping=True``.

    Returns:
        An ``(engine, sessionmaker)`` tuple. The sessionmaker is configured
        with ``expire_on_commit=False`` and ``autoflush=False``.
    """
    if is_testing:
        pool_options = {"poolclass": NullPool}
    else:
        pool_options = {"pool_recycle": 1800, "pool_pre_ping": True}

    if is_async:
        engine_factory, session_factory_cls = (
            create_async_engine,
            sa_async_sessionmaker,
        )
    else:
        engine_factory, session_factory_cls = create_engine, sa_sync_sessionmaker

    engine = engine_factory(conn_string, **pool_options)
    session_factory = session_factory_cls(
        engine, expire_on_commit=False, autoflush=False
    )
    return engine, session_factory


class DBConfig:
    """Hold the sync and async SQLAlchemy engines/sessionmakers for one database.

    Both connection strings are built from the ``db_*`` fields of
    ``settings``. The sync one uses the ``mysql+pymysql`` driver and the
    async one uses ``mysql+aiomysql``. ``settings.test_mode`` is forwarded
    to ``get_sessionmaker`` as ``is_testing``.
    """

    def __init__(self, settings: CoreSettings):
        # Percent-encode the credentials: a raw "@", "/", ":" or "%" in the
        # username or password would otherwise be interpreted as URL syntax
        # and break parsing of the connection string. SQLAlchemy decodes
        # these fields again when it parses the URL.
        username = quote(str(settings.db_username), safe="")
        password = quote(str(settings.db_password), safe="")
        base_conn_string = (
            f"{username}:{password}@"
            f"{settings.db_host}:{settings.db_port}/{settings.db_database}"
        )
        self._conn_string = f"mysql+pymysql://{base_conn_string}"
        self._async_conn_string = f"mysql+aiomysql://{base_conn_string}"

        self._engine, self._sessionmaker = get_sessionmaker(
            self._conn_string, is_async=False, is_testing=settings.test_mode
        )
        self._async_engine, self._async_sessionmaker = get_sessionmaker(
            self._async_conn_string, is_async=True, is_testing=settings.test_mode
        )

    @property
    def conn_string(self) -> str:
        """Sync (``mysql+pymysql``) connection URL, including credentials."""
        return self._conn_string

    @property
    def async_conn_string(self) -> str:
        """Async (``mysql+aiomysql``) connection URL, including credentials."""
        return self._async_conn_string

    @property
    def engine(self) -> Engine:
        """Sync engine created from ``conn_string``."""
        return self._engine

    # pylint: disable=unsubscriptable-object
    @property
    def sessionmaker(self) -> sa_sync_sessionmaker[Session]:
        """Sessionmaker bound to the sync engine."""
        return self._sessionmaker

    # pylint: enable=unsubscriptable-object

    @property
    def async_engine(self) -> AsyncEngine:
        """Async engine created from ``async_conn_string``."""
        return self._async_engine

    @property
    def async_sessionmaker(self) -> sa_async_sessionmaker[AsyncSession]:
        """Sessionmaker bound to the async engine."""
        return self._async_sessionmaker


class BaseModelMixin:
    """Mixin declaring a ``BigInteger`` primary-key column named ``pk``."""

    pk: Mapped[int] = mapped_column(BigInteger(), primary_key=True)


class BaseModel(BaseModelMixin, AsyncAttrs, DeclarativeBase):
    """Declarative base class for the ORM models.

    Mixes in ``BaseModelMixin`` (the ``pk`` column) and ``AsyncAttrs``, and
    uses a ``MetaData`` with an explicit naming convention for indexes,
    unique/check/foreign-key constraints and primary keys.
    """

    metadata = MetaData(
        naming_convention=dict(
            ix="ix_%(column_0_label)s",
            uq="uq_%(table_name)s_%(column_0_name)s",
            ck="ck_%(table_name)s_%(constraint_name)s",
            fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            pk="pk_%(table_name)s",
        )
    )


class GUID(TypeDecorator):
    """UUID column stored in a ``CHAR(36)`` column.

    On MySQL the value is written in its 32-character hex form
    (``UUID.hex``, no hyphens). On other dialects it is written as the
    36-character hyphenated string. Bound values may be ``uuid.UUID``
    instances or UUID strings. Loaded values are always ``uuid.UUID``.
    """

    impl = CHAR(36)
    # The type holds no per-instance state, so statements using it are safe
    # to cache. Without this flag SQLAlchemy warns and refuses to cache them.
    cache_ok = True

    @staticmethod
    def _coerce_uuid(value):
        """Return *value* as ``uuid.UUID``, parsing it from a string if needed."""
        if isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(str(value))

    def process_bind_param(self, value, dialect):
        if value is None:
            return value
        value = self._coerce_uuid(value)
        if dialect.name == "mysql":
            return value.hex
        return str(value)

    def process_result_value(self, value, dialect):
        if value is None:
            return value
        if isinstance(value, uuid.UUID):
            return value
        if dialect.name == "mysql":
            return uuid.UUID(hex=value)
        return uuid.UUID(value)

    def process_literal_param(self, value, dialect):
        # Render inline literals in exactly the format process_bind_param
        # stores. Otherwise inline comparisons against MySQL rows (hex form)
        # would never match.
        return self.process_bind_param(value, dialect)

    @property
    def python_type(self):
        return uuid.UUID


class _BoundedDecimalField(TypeDecorator):
    """Shared logic for fixed-point DECIMAL columns with an inclusive value range.

    Concrete subclasses define:

    * ``impl``: the ``DECIMAL(precision, scale)`` column type;
    * ``cache_ok = True``: kept on every concrete subclass, as each
      original field class declared it itself;
    * ``_lower_bound`` / ``_upper_bound``: inclusive limits as ``Decimal``;
    * ``_range_error``: message of the ``ValueError`` raised on violation.
    """

    _lower_bound: Decimal
    _upper_bound: Decimal
    _range_error: str

    def _to_bounded_decimal(self, value):
        """Convert *value* to ``Decimal`` and check it against the bounds.

        Raises:
            ValueError: if the value is NaN/infinite or lies outside the
                inclusive ``[_lower_bound, _upper_bound]`` range.
        """
        value = Decimal(value)
        # Reject non-finite values before comparing: an ordering comparison
        # with a Decimal NaN raises InvalidOperation instead of returning
        # False, which would hide the descriptive error below.
        if not value.is_finite() or not (
            self._lower_bound <= value <= self._upper_bound
        ):
            raise ValueError(self._range_error)
        return value

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        return self._to_bounded_decimal(value)

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        return Decimal(value)

    def process_literal_param(self, value, dialect):
        if value is None:
            return None
        # Validate inline literals the same way as bound parameters, so
        # literal rendering cannot bypass the range check.
        return str(self._to_bounded_decimal(value))

    @property
    def python_type(self):
        return Decimal


class RatioField(_BoundedDecimalField):
    """``DECIMAL(16, 15)`` column restricted to the inclusive range [-1, 1]."""

    impl = DECIMAL(16, 15)
    cache_ok = True
    _lower_bound = Decimal("-1")
    _upper_bound = Decimal("1")
    _range_error = "RatioField value must be between -1 and 1."


class PercentageField(_BoundedDecimalField):
    """``DECIMAL(21, 15)`` column: 6 integer and 15 fractional digits.

    The bounds are the largest magnitudes that precision/scale can hold.
    """

    impl = DECIMAL(21, 15)
    cache_ok = True
    _lower_bound = Decimal("-999999.999999999999999")
    _upper_bound = Decimal("999999.999999999999999")
    _range_error = (
        "PercentageField value must be between "
        "-999,999.999999999999999 and 999,999.999999999999999."
    )


class NominalField(_BoundedDecimalField):
    """``DECIMAL(30, 15)`` column: 15 integer and 15 fractional digits.

    The bounds are the largest magnitudes that precision/scale can hold.
    """

    impl = DECIMAL(30, 15)
    cache_ok = True
    _lower_bound = Decimal("-999999999999999.999999999999999")
    _upper_bound = Decimal("999999999999999.999999999999999")
    _range_error = (
        "NominalField value must be between "
        "-999,999,999,999,999.999999999999999 and "
        "999,999,999,999,999.999999999999999."
    )


class SFTPFile(str):
    def __new__(
        cls,
        value: str,
        path_prefix: str,
        accepted_formats: list[str],
        max_file_size: int,
    ):
        obj = str.__new__(cls, value)
        obj._path_prefix = path_prefix
        obj._accepted_formats = accepted_formats
        obj._max_file_size = max_file_size
        return obj

    @property
    def path_prefix(self) -> str:
        return self._path_prefix

    @property
    def accepted_formats(self) -> list[str]:
        return self._accepted_formats

    @property
    def max_file_size(self) -> int:
        return self._max_file_size

    def __repr__(self) -> str:
        return f"<SFTPFile value={str(self)!r}>"


class SFTPFileField(TypeDecorator):
    """``String(512)`` column whose loaded values are wrapped in ``SFTPFile``.

    The constructor options (``path_prefix``, ``accepted_formats``,
    ``max_file_size``) are kept on the type and attached to every non-NULL
    value read back from the database.
    """

    impl = String(512)
    cache_ok = True
    # NOTE(review): the options are stored under underscore-prefixed names.
    # As far as I know, SQLAlchemy builds a type's cache key only from public
    # attributes named after __init__ parameters, so these options are
    # probably not part of it. Confirm before using differently-configured
    # instances of this type outside column definitions (e.g. in cast()).

    def __init__(
        self,
        *args,
        path_prefix: str | None = None,
        accepted_formats: list[str] | None = None,
        max_file_size: int = 100 * 1024 * 1024,
        **kwargs,
    ):
        """Configure the field.

        Args:
            path_prefix: prefix attached to loaded ``SFTPFile`` values.
            accepted_formats: formats attached to loaded values. Defaults to
                a new empty list per field.
            max_file_size: size limit attached to loaded values (default
                100 MiB, i.e. ``100 * 1024 * 1024``).
            *args, **kwargs: forwarded to ``TypeDecorator.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._path_prefix = path_prefix
        # Fresh list per instance: a literal ``[]`` default would be one list
        # object shared by every field declared without accepted_formats.
        self._accepted_formats = [] if accepted_formats is None else accepted_formats
        self._max_file_size = max_file_size

    def process_bind_param(self, value, dialect):
        return str(value) if value is not None else None

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        return SFTPFile(
            value, self._path_prefix, self._accepted_formats, self._max_file_size
        )

    def process_literal_param(self, value, dialect):
        return str(value) if value is not None else None

    @property
    def python_type(self):
        return SFTPFile

    @property
    def path_prefix(self) -> str | None:
        return self._path_prefix

    @property
    def accepted_formats(self) -> list[str]:
        return self._accepted_formats

    @property
    def max_file_size(self) -> int:
        return self._max_file_size


class SFTPFileMixin:
    """Mixin with a helper that wraps raw column values in ``SFTPFile``."""

    def _wrap_sftpfile(self, key, value):
        """Return *value* as an ``SFTPFile`` configured from column *key*.

        The column type is looked up in ``self.__table__.c[key]`` first, so
        an unknown *key* raises ``KeyError`` even when *value* is ``None``.
        ``None`` and existing ``SFTPFile`` instances are returned unchanged.
        Any other value is wrapped using the column type's ``path_prefix``,
        ``accepted_formats`` and ``max_file_size``.
        """
        field_type = self.__table__.c[key].type
        if value is not None and not isinstance(value, SFTPFile):
            return SFTPFile(
                value,
                field_type.path_prefix,
                field_type.accepted_formats,
                field_type.max_file_size,
            )
        return value


class SoftDeleteMixin:
    """Mixin adding a nullable ``deleted_at`` timestamp column."""

    @declared_attr
    @classmethod
    def deleted_at(cls) -> Mapped[datetime | None]:
        """Nullable ``DateTime(timezone=True)`` column that defaults to ``None``."""
        return mapped_column(
            DateTime(timezone=True),
            nullable=True,
            default=None,
        )


class TaskMixin:
    """Mixin adding task columns: id, status, message, progress, start/end times.

    The columns are declared in this order: ``task_id``, ``status``,
    ``message``, ``progress``, ``started_at``, ``ended_at``.
    """

    @declared_attr
    @classmethod
    def task_id(cls) -> Mapped[uuid.UUID | None]:
        """Nullable ``GUID`` column that defaults to a fresh ``uuid.uuid4()``."""
        return mapped_column(GUID, default=uuid.uuid4, nullable=True)

    @declared_attr
    @classmethod
    def status(cls) -> Mapped[enums.TaskStatus]:
        """``Enum(enums.TaskStatus)`` column that defaults to ``QUEUED``."""
        return mapped_column(
            Enum(enums.TaskStatus),
            default=enums.TaskStatus.QUEUED,
        )

    @declared_attr
    @classmethod
    def message(cls) -> Mapped[str | None]:
        """Nullable ``String(255)`` column that defaults to ``None``."""
        return mapped_column(String(255), default=None, nullable=True)

    @declared_attr
    @classmethod
    def progress(cls) -> Mapped[float | None]:
        """Nullable ``Float`` column that defaults to ``None``."""
        return mapped_column(Float, default=None, nullable=True)

    @declared_attr
    @classmethod
    def started_at(cls) -> Mapped[datetime | None]:
        """Nullable ``DateTime(timezone=True)`` column."""
        return mapped_column(DateTime(timezone=True), nullable=True)

    @declared_attr
    @classmethod
    def ended_at(cls) -> Mapped[datetime | None]:
        """Nullable ``DateTime(timezone=True)`` column."""
        return mapped_column(DateTime(timezone=True), nullable=True)
