import tempfile
import os
from operator import ne, lt, le, gt, ge
from typing import Any
from uuid import UUID

import pandas as pd
from fastapi import HTTPException, UploadFile
from sqlalchemy import select, func
from sqlalchemy.orm import Session, aliased, selectinload
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql.expression import Select

from .databases import BaseModel, SFTPFile
from .dateutils import now
from .exceptions import HTTPNotFoundException
from .paginations import PaginationQuery
from .schemas import BaseSchema
from .sftp import AsyncSFTPClient
from . import enums


# Comparison suffixes accepted in ``<field>__<op>`` filter keys. Each value is the
# binary function from :mod:`operator`, applied as ``func(column, value)`` to build
# the SQL comparison.
FILTER_OPERATOR_MAPPING = dict(
    ne=ne,
    lt=lt,
    le=le,
    gt=gt,
    ge=ge,
)


# String-matching suffixes accepted in ``<field>__<op>`` filter keys. Each key names
# the SQLAlchemy column method of the same name, which is called as
# ``field.<method>(value)``; keys and values are kept identical on purpose.
# (``icontains``/``istartswith``/``iendswith`` are SQLAlchemy 2.0 operators.)
FILTER_SPECIAL_OPERATOR_MAPPING = {
    "contains": "contains",
    "icontains": "icontains",
    "startswith": "startswith",
    "istartswith": "istartswith",
    "endswith": "endswith",
    "iendswith": "iendswith",
}


def _filter_is_null(column):
    """Build ``column IS NULL``."""
    return column.is_(None)


def _filter_is_not_null(column):
    """Build ``column IS NOT NULL``."""
    return column.is_not(None)


# Suffixes whose SQL depends on the filter *value* rather than comparing against it:
# ``<field>__isnull=True`` -> IS NULL, ``<field>__isnull=False`` -> IS NOT NULL.
FILTER_CUSTOM_OPERATOR_MAPPING = {
    "isnull": {True: _filter_is_null, False: _filter_is_not_null},
}


# pylint: disable=too-many-public-methods
class BaseRepository:
    """Generic CRUD repository for one SQLAlchemy model.

    Subclasses set ``model`` and may customise:

    * ``class_default_filters`` -- ``filter_by``-style filters applied to every
      ``get``/``records`` query. Each instance works on its own copy, so
      ``set_default_filters`` never leaks into the class or other instances.
    * ``class_default_ordering`` -- field names (prefix ``-`` for descending)
      used by ``records`` when no explicit ``ordering`` is given.
    * ``filter_by_organization`` -- when True, ``add_organization_filter`` must
      be overridden, and ``organization_id`` raises if no id was supplied.

    Each operation has a sync form (for ``Session``) and an ``a``-prefixed async
    form (for ``AsyncSession``); call the one matching ``session``. Write methods
    roll the session back and re-raise on any error.
    """

    model: BaseModel
    class_default_filters: dict[str, Any] = {}
    class_default_ordering: list[str] | None = None
    filter_by_organization: bool = False

    def __init__(
        self,
        session: AsyncSession | Session | None = None,
        organization_id: int | None = None,
    ):
        self._session = session
        self._organization_id = organization_id
        # Per-instance copy so merging new defaults never mutates the class dict.
        self._default_filters = self.class_default_filters.copy()

    @property
    def session(self) -> AsyncSession | Session | None:
        return self._session

    @session.setter
    def session(self, value: AsyncSession | Session):
        self._session = value

    @property
    def organization_id(self) -> int | None:
        """Organization id used for scoping queries.

        Raises:
            ValueError: if ``filter_by_organization`` is True but no id was given.
        """
        if self.filter_by_organization and self._organization_id is None:
            raise ValueError(
                "Organization ID must be provided when `filter_by_organization` is True"
            )
        return self._organization_id

    @property
    def default_filters(self) -> dict[str, Any]:
        return self._default_filters

    @default_filters.setter
    def default_filters(self, value: dict[str, Any]):
        # Assignment *merges* into the current defaults instead of replacing them.
        self._default_filters.update(value)

    def set_default_filters(self, filters: dict[str, Any]):
        """Merge ``filters`` into this instance's default filters."""
        self.default_filters = filters

    # -- shared helpers used by both the sync and async APIs -------------------

    def _commit_or_flush(self, commit: bool) -> None:
        """Commit the session, or only flush pending SQL when ``commit`` is False."""
        if commit:
            self.session.commit()
        else:
            self.session.flush()

    async def _acommit_or_flush(self, commit: bool) -> None:
        """Async counterpart of ``_commit_or_flush``."""
        if commit:
            await self.session.commit()
        else:
            await self.session.flush()

    def _build_get_query(
        self,
        pk,
        disable_organization_filter: bool,
        eager_loads: list[str] | dict[str, str] | None,
    ) -> Select:
        """Build the single-row lookup shared by ``get`` and ``aget``."""
        query: Select = select(self.model)
        if self.filter_by_organization and not disable_organization_filter:
            query = self.add_organization_filter(query)
        query = query.filter_by(pk=pk, **self.default_filters)
        if eager_loads:
            query = self.add_eager_loads(query, eager_loads)
        return query

    def _build_records_query(
        self,
        filters: dict[str, Any] | BaseSchema | None,
        ordering: list[str] | None,
        eager_loads: list[str] | dict[str, str] | None,
    ) -> Select:
        """Build the listing query shared by ``records`` and ``arecords``."""
        query: Select = select(self.model)
        if self.filter_by_organization:
            query = self.add_organization_filter(query)
        if filters:
            # add_filters merges the default filters itself.
            query = self.add_filters(query, filters)
        elif self.default_filters:
            query = query.filter_by(**self.default_filters)
        query = self.add_ordering(query, ordering)
        # Joins added by relationship filters can duplicate rows.
        query = query.distinct()
        if eager_loads:
            query = self.add_eager_loads(query, eager_loads)
        return query

    # -- sync API ---------------------------------------------------------------

    def create(self, _commit: bool = True, **kwargs):
        """Insert ``model(**kwargs)`` and return it re-read from the database.

        ``_commit=False`` only flushes, leaving the transaction open. The re-read
        skips the organization filter but still applies the default filters, so
        it returns None if the new row does not match them.
        """
        obj = self.model(**kwargs)
        try:
            self.session.add(obj)
            self._commit_or_flush(_commit)
        except BaseException:  # also roll back on cancellation/interrupts, then re-raise
            self.session.rollback()
            raise
        return self.get(obj.pk, disable_organization_filter=True)

    def bulk_create(self, objs: list[BaseModel], _commit: bool = True):
        """Add all ``objs`` in one unit of work and return the same list."""
        try:
            self.session.add_all(objs)
            self._commit_or_flush(_commit)
        except BaseException:
            self.session.rollback()
            raise
        return objs

    def get(
        self,
        pk,
        disable_organization_filter: bool = False,
        eager_loads: list[str] | dict[str, str] | None = None,
    ):
        """Return the row with primary key ``pk``, or None if it is not visible.

        Visibility means matching the default filters and, unless disabled, the
        organization filter.
        """
        result = self.session.execute(
            self._build_get_query(pk, disable_organization_filter, eager_loads)
        )
        return result.scalar()

    def get_or_404(self, pk, eager_loads: list[str] | dict[str, str] | None = None):
        """Like ``get`` but raises ``HTTPNotFoundException`` when nothing matches."""
        obj = self.get(pk, eager_loads=eager_loads)
        self.raise_if_notfound(obj)
        return obj

    def records(
        self,
        filters: dict[str, Any] | BaseSchema | None = None,
        pagination: PaginationQuery | None = None,
        ordering: list[str] | None = None,
        eager_loads: list[str] | dict[str, str] | None = None,
    ):
        """List rows matching ``filters``.

        Returns a scalar result of model instances, or, when ``pagination`` is
        given, a dict with ``total_items``, ``total_page``, ``page`` and ``items``.
        """
        query = self._build_records_query(filters, ordering, eager_loads)
        if pagination:
            return self.handle_pagination(query, pagination)
        return self.session.execute(query).scalars()

    def update(self, pk, _commit: bool = True, **kwargs):
        """Set ``kwargs`` on the visible row ``pk`` (404 if missing) and return it re-read."""
        obj = self.update_values(self.get_or_404(pk), kwargs)
        try:
            self.session.add(obj)
            self._commit_or_flush(_commit)
            self.session.refresh(obj)
        except BaseException:
            self.session.rollback()
            raise
        return self.get(obj.pk, disable_organization_filter=True)

    def delete(self, pk, _commit: bool = True):
        """Hard-delete the visible row ``pk`` (404 if missing).

        ``_commit=False`` only flushes the DELETE.
        """
        obj = self.get_or_404(pk)
        try:
            self.session.delete(obj)
            self._commit_or_flush(_commit)
        except BaseException:
            self.session.rollback()
            raise

    def soft_delete(self, pk, _commit: bool = True, **kwargs):
        """Stamp ``deleted_at`` (plus any ``kwargs``) on row ``pk`` and return it."""
        obj = self.get_or_404(pk)
        obj.deleted_at = now()
        obj = self.update_values(obj, kwargs)
        try:
            self.session.add(obj)
            self._commit_or_flush(_commit)
            self.session.refresh(obj)
        except BaseException:
            self.session.rollback()
            raise
        return obj

    # -- async API --------------------------------------------------------------

    async def acreate(self, _commit: bool = True, **kwargs):
        """Async ``create``."""
        obj = self.model(**kwargs)
        try:
            self.session.add(obj)
            await self._acommit_or_flush(_commit)
        except BaseException:  # includes asyncio.CancelledError
            await self.session.rollback()
            raise
        return await self.aget(obj.pk, disable_organization_filter=True)

    async def abulk_create(self, objs: list[BaseModel], _commit: bool = True):
        """Async ``bulk_create``."""
        try:
            self.session.add_all(objs)
            await self._acommit_or_flush(_commit)
        except BaseException:
            await self.session.rollback()
            raise
        return objs

    async def aget(
        self,
        pk,
        disable_organization_filter: bool = False,
        eager_loads: list[str] | dict[str, str] | None = None,
    ):
        """Async ``get``."""
        result = await self.session.execute(
            self._build_get_query(pk, disable_organization_filter, eager_loads)
        )
        return result.scalar()

    async def aget_or_404(
        self, pk, eager_loads: list[str] | dict[str, str] | None = None
    ):
        """Async ``get_or_404``."""
        obj = await self.aget(pk, eager_loads=eager_loads)
        self.raise_if_notfound(obj)
        return obj

    async def aget_by_ids(
        self, ids: list[Any], disable_organization_filter: bool = False
    ):
        """Return the visible rows whose primary key is in ``ids``.

        Applies the same organization scoping and default filters as ``aget``, so
        ids belonging to another organization are not returned.
        """
        query: Select = select(self.model)
        if self.filter_by_organization and not disable_organization_filter:
            query = self.add_organization_filter(query)
        query = query.filter(self.model.pk.in_(ids))
        if self.default_filters:
            query = query.filter_by(**self.default_filters)
        result = await self.session.execute(query)
        return result.scalars()

    async def arecords(
        self,
        filters: dict[str, Any] | BaseSchema | None = None,
        pagination: PaginationQuery | None = None,
        ordering: list[str] | None = None,
        eager_loads: list[str] | dict[str, str] | None = None,
    ):
        """Async ``records``."""
        query = self._build_records_query(filters, ordering, eager_loads)
        if pagination:
            return await self.ahandle_pagination(query, pagination)
        result = await self.session.execute(query)
        return result.scalars()

    async def aupdate(self, pk, _commit: bool = True, **kwargs):
        """Async ``update``."""
        obj = self.update_values(await self.aget_or_404(pk), kwargs)
        try:
            self.session.add(obj)
            await self._acommit_or_flush(_commit)
            await self.session.refresh(obj)
        except BaseException:
            await self.session.rollback()
            raise
        return await self.aget(obj.pk, disable_organization_filter=True)

    async def adelete(self, pk, _commit: bool = True):
        """Async ``delete``."""
        obj = await self.aget_or_404(pk)
        try:
            await self.session.delete(obj)
            await self._acommit_or_flush(_commit)
        except BaseException:
            await self.session.rollback()
            raise

    async def asoft_delete(self, pk, _commit: bool = True, **kwargs):
        """Async ``soft_delete``."""
        obj = await self.aget_or_404(pk)
        obj.deleted_at = now()
        obj = self.update_values(obj, kwargs)
        try:
            self.session.add(obj)
            await self._acommit_or_flush(_commit)
            await self.session.refresh(obj)
        except BaseException:
            await self.session.rollback()
            raise
        return obj

    # -- query building -----------------------------------------------------------

    def raise_if_notfound(self, obj):
        """Raise ``HTTPNotFoundException`` when ``obj`` is None."""
        if obj is None:
            raise HTTPNotFoundException

    def add_filters(self, query: Select, filters: dict[str, Any] | BaseSchema):
        """Apply ``filters``, merged on top of the default filters, to ``query``.

        A ``BaseSchema`` is dumped with ``exclude_unset`` and ``exclude_defaults``.
        Supported key forms:

        * ``"field"`` -- equality; for a relationship attribute the value is
          matched against the related row's ``pk``;
        * ``"field__op"`` -- ``op`` from one of the ``FILTER_*_OPERATOR_MAPPING``
          tables (see ``apply_operator``);
        * ``"rel___field"`` / ``"rel___rel2___field__op"`` -- filter through
          relationships, joining each related model at most once per query.
        """
        if not isinstance(filters, dict):
            filters = filters.model_dump(exclude_unset=True, exclude_defaults=True)
        if self.default_filters:
            filters = {**self.default_filters, **filters}
        # Shared across *all* keys: joining the same relationship once per key would
        # render the JOIN several times and the database rejects the repeated table.
        joined_models: set = set()
        for key, value in filters.items():
            if "___" in key:
                query = self._add_nested_filter(query, key, value, joined_models)
            elif "__" in key:
                name, op = key.split("__")
                query = self.apply_operator(query, getattr(self.model, name), op, value)
            else:
                query = query.filter(
                    self._equality_clause(getattr(self.model, key), value)
                )
        return query

    def _add_nested_filter(self, query: Select, key: str, value, joined_models: set):
        """Join along ``rel___...___field[__op]`` and filter on the final field."""
        path = key.split("___")
        op = None
        if "__" in path[-1]:
            path[-1], op = path[-1].split("__")
        field = getattr(self.model, path[0])
        for name in path[1:]:
            related_model = field.property.mapper.class_
            if related_model not in joined_models:
                query = query.join(field)
                joined_models.add(related_model)
            field = getattr(related_model, name)
        if op is not None:
            return self.apply_operator(query, field, op, value)
        return query.filter(field == value)

    @staticmethod
    def _equality_clause(model_attr, value):
        """Build ``attr == value``, or a related-``pk`` match for relationships."""
        prop = getattr(model_attr, "property", None)
        if prop is not None and hasattr(prop, "direction"):
            # has() only works for scalar relationships; collections need any().
            if prop.uselist:
                return model_attr.any(pk=value)
            return model_attr.has(pk=value)
        return model_attr == value

    def apply_operator(self, query: Select, field, op, value):
        """Filter ``query`` with ``field <op> value``.

        Raises:
            ValueError: for an unknown ``op`` or an unsupported value of a
                value-dispatched operator such as ``isnull``.
        """
        if op in FILTER_OPERATOR_MAPPING:
            return query.filter(FILTER_OPERATOR_MAPPING[op](field, value))
        if op in FILTER_SPECIAL_OPERATOR_MAPPING:
            method_name = FILTER_SPECIAL_OPERATOR_MAPPING[op]
            return query.filter(getattr(field, method_name)(value))
        if op in FILTER_CUSTOM_OPERATOR_MAPPING:
            try:
                handler = FILTER_CUSTOM_OPERATOR_MAPPING[op][value]
            except (KeyError, TypeError) as err:
                raise ValueError(
                    f"Invalid value for operator {op!r}: {value!r}"
                ) from err
            return query.filter(handler(field))
        raise ValueError(f"Invalid operator: {op}")

    def get_ordering_fields(self, fields: list[str]):
        """Translate e.g. ``["name", "-created_at"]`` into ORDER BY terms."""
        ordering = []
        for field in fields:
            if field.startswith("-"):
                ordering.append(getattr(self.model, field[1:]).desc())
            else:
                ordering.append(getattr(self.model, field))
        return ordering

    def add_ordering(self, query: Select, ordering: list[str] | None = None):
        """Order by ``ordering``, falling back to ``class_default_ordering``."""
        fields = ordering or self.class_default_ordering
        if fields:
            query = query.order_by(*self.get_ordering_fields(fields))
        return query

    def update_values(self, obj, values: dict):
        """Set each ``values`` item as an attribute on ``obj`` and return ``obj``."""
        for name, value in values.items():
            setattr(obj, name, value)
        return obj

    def _count_query(self, query: Select) -> Select:
        """Build ``SELECT count(pk)`` over ``query`` wrapped as a subquery."""
        counted = aliased(self.model, query.subquery())
        return select(func.count(counted.pk))  # pylint: disable=not-callable

    @staticmethod
    def _page_query(query: Select, pagination: PaginationQuery) -> Select:
        """Restrict ``query`` to the requested 1-based page."""
        offset = (pagination.page - 1) * pagination.page_size
        return query.limit(pagination.page_size).offset(offset)

    @staticmethod
    def _page_payload(total_items: int, pagination: PaginationQuery, items):
        """Assemble the pagination response dict."""
        # Ceiling division: a partial last page still counts as a page.
        total_page = (total_items + pagination.page_size - 1) // pagination.page_size
        return {
            "total_items": total_items,
            "total_page": total_page,
            "page": pagination.page,
            "items": items,
        }

    def handle_pagination(self, query: Select, pagination: PaginationQuery):
        """Run a count query plus the page query and return the page dict."""
        total_items = self.session.execute(self._count_query(query)).scalar()
        items = self.session.execute(self._page_query(query, pagination)).scalars()
        return self._page_payload(total_items, pagination, items)

    async def ahandle_pagination(self, query: Select, pagination: PaginationQuery):
        """Async ``handle_pagination``."""
        total_result = await self.session.execute(self._count_query(query))
        total_items = total_result.scalar()
        page_result = await self.session.execute(self._page_query(query, pagination))
        return self._page_payload(total_items, pagination, page_result.scalars())

    def add_organization_filter(self, query: Select):
        """Restrict ``query`` to ``self.organization_id``; subclasses override this.

        The base class cannot know how ``model`` links to an organization, so it
        raises ``NotImplementedError`` when ``filter_by_organization`` is True and
        otherwise returns ``query`` unchanged.
        """
        if self.filter_by_organization:
            raise NotImplementedError(
                "`add_organization_filter` method must be implemented"
            )
        return query

    def add_eager_loads(self, query: Select, eager_loads: list[str] | dict[str, str]):
        """Attach eager-loading options for the given relationship paths.

        ``eager_loads`` is either a list of dotted paths (``"a.b.c"``), all loaded
        with ``selectinload``, or a dict mapping each path to a loader name. An
        unknown loader name raises ``KeyError``.
        """
        loader_functions = {
            "selectinload": selectinload,
        }
        if isinstance(eager_loads, dict):
            items = eager_loads.items()
        else:
            items = ((path, "selectinload") for path in eager_loads)

        options = []
        for path, load_fun in items:
            names = path.split(".")
            attr = getattr(self.model, names[0])
            loader = loader_functions[load_fun](attr)
            # Chain deeper levels off the previous loader (e.g. selectinload(A.b).selectinload(B.c)).
            for name in names[1:]:
                attr = getattr(attr.property.mapper.class_, name)
                loader = getattr(loader, load_fun)(attr)
            options.append(loader)
        return query.options(*options)


# pylint: enable=too-many-public-methods


class FileMixin:
    """Repository mixin storing the files of ``SFTPFile`` fields on an SFTP server.

    Must precede, in the MRO, a repository class that provides ``session``,
    ``acreate``, ``aupdate``, ``aget_or_404`` and ``adelete`` (they are reached via
    ``self``/``super()``). Remote files live at ``<path_prefix>/<pk>/<filename>``,
    or ``<pk>/<filename>`` when the field has no prefix. Call ``enable_sftp`` (or
    assign ``sftp``) before using any file method.
    """

    @property
    def sftp(self) -> AsyncSFTPClient:
        return self._sftp

    @sftp.setter
    def sftp(self, value: AsyncSFTPClient):
        self._sftp = value

    def enable_sftp(self, sftp: AsyncSFTPClient):
        self._sftp = sftp

    @staticmethod
    def _remote_path(path_prefix, pk, filename) -> str:
        """Build the remote location of ``filename`` for row ``pk``."""
        if path_prefix:
            return f"{path_prefix}/{pk}/{filename}"
        return f"{pk}/{filename}"

    @staticmethod
    def _safe_filename(filename: str | None) -> str:
        """Return the last path component of a client-supplied filename.

        Raises:
            HTTPException: 400 if the filename is missing or has no usable name.
        """
        # Uploaded names are untrusted: drop any directory part (POSIX or Windows
        # separators) so "../x" cannot escape the per-row remote directory.
        name = os.path.basename((filename or "").replace("\\", "/"))
        if name in ("", ".", ".."):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid upload filename: {filename!r}",
            )
        return name

    async def aupload_and_create(
        self,
        file_field: str,
        file: UploadFile,
        _commit=True,
        **kwargs,
    ):
        """Create a row, upload ``file`` as its ``file_field`` and return the row.

        The row is flushed first, because its ``pk`` is part of the remote path.
        If the format check or the upload fails, the session is rolled back so the
        half-created row does not stay pending, and the error is re-raised.

        Raises:
            HTTPException: 400 for a missing/invalid filename, or an extension not
                listed in the field's non-empty ``accepted_formats``.
        """
        filename = self._safe_filename(file.filename)
        file_ext = filename.split(".")[-1].lower()
        kwargs[file_field] = filename
        obj = await super().acreate(_commit=False, **kwargs)
        try:
            file_obj: SFTPFile = getattr(obj, file_field)
            if (
                len(file_obj.accepted_formats) > 0
                and file_ext not in file_obj.accepted_formats
            ):
                raise HTTPException(
                    status_code=400,
                    detail=(
                        f"File format '{file_ext}' is not accepted. "
                        f"Accepted formats: {file_obj.accepted_formats}"
                    ),
                )
            remote_path = self._remote_path(file_obj.path_prefix, obj.pk, filename)
            await self.sftp.aupload(file, remote_path, file_obj.max_file_size)
        except BaseException:  # includes cancellation; undo the flushed insert
            await self.session.rollback()
            raise
        return await super().aupdate(obj.pk, _commit=_commit, **kwargs)

    async def adownload_file(self, pk, file_field: str, local_path: str):
        """Download the ``file_field`` file of row ``pk`` to ``local_path``."""
        obj = await self.aget_or_404(pk)
        file_obj: SFTPFile = getattr(obj, file_field)
        remote_path = self._remote_path(file_obj.path_prefix, obj.pk, file_obj)
        await self.sftp.adownload(remote_path, local_path)

    async def astream_file(self, pk, file_field: str):
        """Yield the ``file_field`` file of row ``pk`` chunk by chunk."""
        obj = await self.aget_or_404(pk)
        file_obj: SFTPFile = getattr(obj, file_field)
        remote_path = self._remote_path(file_obj.path_prefix, obj.pk, file_obj)
        async for chunk in self.sftp.astream(remote_path):
            yield chunk

    async def aload_df(self, pk, file_field: str):
        """Download a CSV/Excel file of row ``pk`` and parse it with pandas.

        Raises:
            ValueError: if the extension is not csv, xlsx or xls. This is checked
                before anything is downloaded.
        """
        obj = await self.aget_or_404(pk)
        file_obj: SFTPFile = getattr(obj, file_field)
        # Use only the last path component locally so the file stays inside tmpdir.
        filename = os.path.basename(str(file_obj))
        file_ext = filename.split(".")[-1].lower()
        if file_ext == "csv":
            reader = pd.read_csv
        elif file_ext in ("xlsx", "xls"):
            reader = pd.read_excel
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")
        with tempfile.TemporaryDirectory() as tmpdir:
            local_path = os.path.join(tmpdir, filename)
            await self.adownload_file(pk, file_field, local_path)
            # TemporaryDirectory deletes the downloaded file when the block exits.
            return reader(local_path)

    async def adelete(self, pk):
        """Delete row ``pk``, then the remote files of its ``SFTPFile`` fields.

        The database row is deleted first. If that fails, nothing is removed from
        the SFTP server, so no surviving row points at a deleted file. A failure
        while removing files afterwards can leave orphaned remote files.
        """
        obj = await self.aget_or_404(pk)
        remote_paths = []
        for attr_name in dir(self.model):
            attr = getattr(self.model, attr_name)
            if not isinstance(attr, SFTPFile):
                continue
            file_obj = getattr(obj, attr_name, None)
            if file_obj:
                remote_paths.append(
                    self._remote_path(attr.path_prefix, obj.pk, file_obj)
                )
        # Paths are built up front: obj's attributes may be expired once the
        # delete has been committed.
        await super().adelete(pk)
        for remote_path in remote_paths:
            await self.sftp.adelete(remote_path)


class TaskMixin:
    def update_progress(
        self, pk, message: str | None = None, progress: float | None = None
    ):
        update_kwargs = {"last_modified_at": now()}
        if message is not None:
            update_kwargs["message"] = message
        if progress is not None:
            update_kwargs["progress"] = progress
        return super().update(pk, **update_kwargs)

    def mark_running(self, pk, task_id: UUID | str):
        if isinstance(task_id, str):
            task_id = UUID(task_id)
        return super().update(
            pk,
            status=enums.TaskStatus.RUNNING,
            message="Running",
            task_id=task_id,
            started_at=now(),
        )

    def mark_success(self, pk):
        return super().update(
            pk,
            status=enums.TaskStatus.SUCCESS,
            message="Completed",
            progress=1,
            ended_at=now(),
        )

    def mark_failed(self, pk, exc: Exception | str):
        message = str(exc) if not isinstance(exc, str) else exc
        return super().update(
            pk,
            status=enums.TaskStatus.FAILED,
            message=f"Failed: {message}",
            ended_at=now(),
        )

    def mark_cancelling(self, pk):
        return super().update(
            pk, status=enums.TaskStatus.CANCELLING, message="Cancelling"
        )

    def mark_cancelled(self, pk):
        return super().update(
            pk, status=enums.TaskStatus.CANCELLED, message="Cancelled", ended_at=now()
        )

    def mark_stopping(self, pk):
        return super().update(pk, status=enums.TaskStatus.STOPPING, message="Stopping")

    def mark_stopped(self, pk):
        return super().update(
            pk, status=enums.TaskStatus.STOPPED, message="Stopped", ended_at=now()
        )

    def mark_timeout(self, pk):
        return super().update(
            pk, status=enums.TaskStatus.TIMEOUT, message="Timeout", ended_at=now()
        )

    def mark_unknown(self, pk, exc: Exception | str):
        message = str(exc) if not isinstance(exc, str) else exc
        return super().update(
            pk,
            status=enums.TaskStatus.UNKNOWN,
            message=f"Unknown: {message}",
            ended_at=now(),
        )

    async def aupdate_progress(
        self, pk, message: str | None = None, progress: float | None = None
    ):
        update_kwargs = {"last_modified_at": now()}
        if message is not None:
            update_kwargs["message"] = message
        if progress is not None:
            update_kwargs["progress"] = progress
        return await super().aupdate(pk, **update_kwargs)

    async def amark_running(self, pk, task_id: UUID | str):
        if isinstance(task_id, str):
            task_id = UUID(task_id)
        return await super().aupdate(
            pk,
            status=enums.TaskStatus.RUNNING,
            message="Running",
            task_id=task_id,
            started_at=now(),
        )

    async def amark_success(self, pk):
        return await super().aupdate(
            pk,
            status=enums.TaskStatus.SUCCESS,
            message="Completed",
            progress=1,
            ended_at=now(),
        )

    async def amark_failed(self, pk, exc: Exception | str):
        message = str(exc) if not isinstance(exc, str) else exc
        return await super().aupdate(
            pk,
            status=enums.TaskStatus.FAILED,
            message=f"Failed: {message}",
            ended_at=now(),
        )

    async def amark_cancelling(self, pk):
        return await super().aupdate(
            pk, status=enums.TaskStatus.CANCELLING, message="Cancelling"
        )

    async def amark_cancelled(self, pk):
        return await super().aupdate(
            pk, status=enums.TaskStatus.CANCELLED, message="Cancelled", ended_at=now()
        )

    async def amark_stopping(self, pk):
        return await super().aupdate(
            pk, status=enums.TaskStatus.STOPPING, message="Stopping"
        )

    async def amark_stopped(self, pk):
        return await super().aupdate(
            pk, status=enums.TaskStatus.STOPPED, message="Stopped", ended_at=now()
        )

    async def amark_timeout(self, pk):
        return await super().aupdate(
            pk, status=enums.TaskStatus.TIMEOUT, message="Timeout", ended_at=now()
        )

    async def amark_unknown(self, pk, exc: Exception | str):
        message = str(exc) if not isinstance(exc, str) else exc
        return await super().aupdate(
            pk,
            status=enums.TaskStatus.UNKNOWN,
            message=f"Unknown: {message}",
            ended_at=now(),
        )
