import importlib.resources as pkg_resources
import random
import secrets
from datetime import timedelta
from typing import Any

import jwt
from celery import Task
from fastapi import Depends, HTTPException, Request
from fastapi.security import OAuth2PasswordBearer as _OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import EmailStr
from sqlalchemy.ext.asyncio import AsyncSession

from be_kit import dateutils
from be_kit.caches import AsyncCache
from be_kit.mailing import EmailClient
from be_kit.exceptions import HTTPNotAuthenticatedException, HTTPUnauthorizedException

from be_uam import organization

from ..databases import get_async_session
from ..permission.repositories import (
    GroupPermissionRepository,
)
from ..user.models import User
from ..user.repositories import AuthUserRepository, InitialUserRepository
from ..user.utils import verify_password
from ..organization.repositories import InitialOrganizationRepository
from ..redis import cache_config, get_async_cache
from ..settings import settings
from ..tasks import task_handler
from ..audit_log.enums import AuditLogAction
from ..audit_log.utils import audit_log_action
from .enums import AuthorizationMode


# Signing algorithm for every JWT this module issues (create_token) and the
# only algorithm accepted when decoding (verify_token). Both directions use
# settings.secret_key, i.e. this is a shared-secret (HMAC) scheme.
ALGORITHM = "HS256"


class OAuth2SessionOrToken(_OAuth2PasswordBearer):
    """OAuth2 bearer scheme that also accepts the access token from a cookie.

    The access cookie named ``settings.jwt.access_cookie_name`` takes
    precedence; if it is missing or empty, the ``Authorization: Bearer <token>``
    header is used instead.
    """

    async def __call__(self, request: Request) -> str | None:
        """Extract the raw token from ``request``.

        Returns:
            The cookie value or bearer token, or None when no credential is
            present and ``auto_error`` is disabled.

        Raises:
            HTTPNotAuthenticatedException: no credential was sent and
                ``auto_error`` is enabled.
        """
        session = request.cookies.get(settings.jwt.access_cookie_name)
        # Treat an empty cookie (e.g. one cleared by setting it to "") as
        # absent so it cannot shadow a valid Authorization header.
        if session:
            return session

        authorization = request.headers.get("Authorization")
        scheme, token = get_authorization_scheme_param(authorization)
        if not authorization or scheme.lower() != "bearer":
            if self.auto_error:
                raise HTTPNotAuthenticatedException
            return None
        return token


# Dependency for endpoints that require a credential: when no cookie or bearer
# header is sent, OAuth2SessionOrToken raises HTTPNotAuthenticatedException.
bearer_scheme = OAuth2SessionOrToken(tokenUrl="/auth/login/", auto_error=True)
# Dependency for endpoints that also serve anonymous callers: a missing
# credential yields None instead of raising.
anonable_bearer_scheme = OAuth2SessionOrToken(tokenUrl="/auth/login/", auto_error=False)


async def create_token(data: dict, expired_at: timedelta):
    """Sign ``data`` as a JWT that expires ``expired_at`` from now.

    The caller's dict is left untouched; the ``exp`` claim is added to a
    shallow copy before signing with ``settings.secret_key``.
    """
    payload = {**data, "exp": dateutils.now() + expired_at}
    return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)


async def create_access_token(data: dict):
    """Return a signed access JWT valid for ``settings.jwt.access_token_expire`` days."""
    lifetime = timedelta(days=settings.jwt.access_token_expire)
    token = await create_token(data, lifetime)
    return token


async def create_refresh_token(data: dict):
    """Return a signed refresh JWT valid for ``settings.jwt.refresh_token_expire`` days."""
    lifetime = timedelta(days=settings.jwt.refresh_token_expire)
    token = await create_token(data, lifetime)
    return token


async def retrieve_auth_user(db: AsyncSession, pk: int) -> User:
    """Load the user with primary key ``pk`` through ``AuthUserRepository``.

    Delegates to ``aget_or_404``; whatever that raises for a missing row
    propagates to the caller.
    """
    return await AuthUserRepository(db).aget_or_404(pk)


async def verify_token(db: AsyncSession, token: str) -> User:
    """Decode ``token`` and return the user named by its ``sub`` claim.

    Raises:
        HTTPUnauthorizedException: the token is invalid (bad signature,
            malformed, expired, or failing any other claim check), it has no
            ``sub`` claim, or no user exists for that subject.
        HTTPException: any non-404 HTTP error raised while loading the user
            is re-raised unchanged.
    """
    try:
        decoded_jwt = jwt.decode(token, settings.secret_key, algorithms=[ALGORITHM])
    except jwt.InvalidTokenError as exc:
        # InvalidTokenError is PyJWT's base class for ExpiredSignatureError,
        # DecodeError and every claim-validation failure (nbf, iat, sub, ...),
        # so none of them can escape as a 500.
        raise HTTPUnauthorizedException from exc

    subject = decoded_jwt.get("sub")
    if subject is None:
        # Without a subject there is no user to look up; don't hand None to
        # the repository as a primary key.
        raise HTTPUnauthorizedException

    try:
        return await retrieve_auth_user(db, subject)
    except HTTPException as exc:
        # A token for a user that no longer exists is an auth failure, not a
        # "not found" response.
        if exc.status_code == 404:
            raise HTTPUnauthorizedException from exc
        raise


async def check_permission(
    db: AsyncSession,
    cache: AsyncCache,
    request_user: User,
    permissions: list[str],
    mode: AuthorizationMode,
):
    """Raise unless ``request_user`` is allowed ``permissions`` under ``mode``.

    The user must be verified and carry a truthy OTP-verified flag in the
    cache. Superusers are then allowed outright; everyone else is checked
    against their group via ``GroupPermissionRepository``.

    Raises:
        HTTPUnauthorizedException: any of the above checks fails.
    """
    otp_verified = await cache.aget("uam", "auth", "otp_verified", request_user.email)
    if request_user.is_verified and otp_verified:
        if request_user.is_superuser:
            return
        allowed = await GroupPermissionRepository(db).acheck_group_permission(
            request_user.group_id, permissions, mode
        )
        if allowed:
            return
    raise HTTPUnauthorizedException


async def get_request_user(
    token: str = Depends(bearer_scheme),
    session: AsyncSession = Depends(get_async_session),
):
    """FastAPI dependency returning the authenticated, active, unlocked user.

    Raises:
        HTTPUnauthorizedException: the token is rejected by ``verify_token``
            or the account is inactive or locked.
    """
    user = await verify_token(session, token)
    usable = user.is_active and not user.is_locked
    if not usable:
        raise HTTPUnauthorizedException
    return user


async def get_request_user_or_anon(
    token: str | None = Depends(anonable_bearer_scheme),
    session: AsyncSession = Depends(get_async_session),
):
    """FastAPI dependency returning the request user, or None for anonymous calls.

    A falsy token means "anonymous". A token that is present but invalid, or
    that belongs to an inactive or locked account, is still rejected.

    Raises:
        HTTPUnauthorizedException: see above.
    """
    if not token:
        return None
    user = await verify_token(session, token)
    if user.is_active and not user.is_locked:
        return user
    raise HTTPUnauthorizedException


@task_handler.async_task(bind=False)
async def send_otp_email(
    email: EmailStr,
    last_name: str,
):
    """Generate a 6-digit one-time password, cache it, and email it to ``email``.

    The code is stored as a string under ("uam", "auth", "otp", email) with a
    60-second TTL, which is the key ``verify_otp`` reads back.
    """
    # Use `secrets`, not `random`: the OTP is an authentication secret, and
    # the `random` module's output is predictable. Range stays 100000-999999.
    otp = str(100000 + secrets.randbelow(900000))

    cache = cache_config.get_async_cache()
    try:
        await cache.aset(otp, "uam", "auth", "otp", email, ttl=60)
    finally:
        # Release the cache connection even if the write fails.
        await cache.aclose()

    email_client = EmailClient(
        settings=settings.mailing,
        template_dir=pkg_resources.files("be_uam.templates").joinpath("email"),
        test_mode=settings.test_mode,
        debug=settings.debug,
    )
    message = await email_client.get_templated_message(
        template_name="otp.html",
        subject="Symbolix OTP Verification",
        to=email,
        body={
            # NOTE(review): the template's "full_name" is filled with the last
            # name only — confirm that is intended.
            "full_name": last_name,
            "otp_code": otp,
            "year": dateutils.now().year,
        },
    )
    await email_client.send_message(message)


# Type-only annotation (it does not rebind the name at runtime). Presumably
# task_handler.async_task returns a Celery Task, so this lets type checkers
# accept calls like send_otp_email.delay(...) — TODO confirm.
send_otp_email: Task


async def set_otp_verified_cache(
    cache: AsyncCache,
    email: EmailStr,
    status: bool = True,
):
    """Store ``status`` as the OTP-verified flag for ``email`` and return it.

    No TTL is passed, so the entry's lifetime is whatever the cache applies
    by default.
    """
    key_parts = ("uam", "auth", "otp_verified", email)
    await cache.aset(status, *key_parts)
    return status


async def verify_otp(cache: AsyncCache, email: EmailStr, otp: str) -> bool:
    """Check ``otp`` against the code cached for ``email``.

    On a match, the cached code is deleted so it cannot be reused, and the
    user's OTP-verified flag is set to True.

    Returns:
        True on a match; False if no code is cached or it does not match.
    """
    if not await cache.aexists("uam", "auth", "otp", email):
        return False
    cached_otp = await cache.aget("uam", "auth", "otp", email)
    # Constant-time comparison so response timing does not reveal how much of
    # the code was right. compare_digest rejects non-ASCII str, so compare
    # encoded bytes; the isinstance guard keeps non-string values a mismatch,
    # as plain `==` against a str would be.
    if not isinstance(cached_otp, str) or not secrets.compare_digest(
        cached_otp.encode(), str(otp).encode()
    ):
        return False
    await cache.adelete("uam", "auth", "otp", email)
    await set_otp_verified_cache(cache, email)
    return True


async def _record_failed_login(
    repo: AuthUserRepository, cache: AsyncCache, pk: int
) -> None:
    """Count one failed password attempt for user ``pk``; lock the account at the limit.

    The counter lives under ("auth", "login", "fail", pk) with a TTL of
    ``settings.auth_fail_ttl``. When it reaches ``settings.auth_fail_attempts``
    the account is locked and the counter is cleared.
    """
    if await cache.aexists("auth", "login", "fail", pk):
        count = await cache.aget("auth", "login", "fail", pk) + 1
    else:
        count = 1
    if count >= settings.auth_fail_attempts:
        await repo.aupdate(pk, is_locked=True)
        # Clear the counter and do NOT store it again. A counter left at the
        # limit would re-lock the account on the first failure after an
        # admin unlocks it.
        await cache.adelete("auth", "login", "fail", pk)
        return
    await cache.aset(count, "auth", "login", "fail", pk, ttl=settings.auth_fail_ttl)


async def authenticate(
    db: AsyncSession, cache: AsyncCache, username: str, password: str
) -> User:
    """Verify ``username``/``password`` and return the logged-in user.

    On success: stamps ``last_login_at``, resets the user's OTP-verified flag
    to False (so ``check_permission`` rejects them until the OTP is verified
    again), and writes a LOGIN audit entry when the user has a group.

    On a wrong password: increments the failed-login counter, locking the
    account once it reaches ``settings.auth_fail_attempts``.

    Raises:
        HTTPUnauthorizedException: unknown user, inactive or locked account,
            or wrong password.
    """
    repo = AuthUserRepository(db)
    obj = await repo.aget_by_username(username)
    if obj is None or not obj.is_active or obj.is_locked:
        raise HTTPUnauthorizedException
    if not verify_password(password, obj.password):
        await _record_failed_login(repo, cache, obj.pk)
        raise HTTPUnauthorizedException

    obj = await repo.aupdate(obj.pk, last_login_at=dateutils.now())
    await set_otp_verified_cache(cache, obj.email, False)
    # NOTE(review): the failed-login counter is not reset on success, so
    # failures accumulate across successful logins within auth_fail_ttl —
    # confirm that is the intended lockout policy.
    if obj.group_id is not None:
        await audit_log_action(
            db,
            app="uam",
            module="auth",
            submodule="login",
            action=AuditLogAction.LOGIN,
            description=f"User {obj.standard_full_name} logged in.",
            created_by=obj,
            before=None,
            after=None,
            _commit=True,
        )
    return obj


def authorize(permission: list[str], mode: AuthorizationMode = AuthorizationMode.ANY):
    """Build a FastAPI dependency that enforces ``permission`` under ``mode``.

    The returned coroutine resolves the request user, DB session and cache
    through ``Depends``, then delegates to ``check_permission``. That call
    raises HTTPUnauthorizedException on failure; on success the dependency
    returns None.
    """

    async def dependency(
        request_user: User = Depends(get_request_user),
        session: AsyncSession = Depends(get_async_session),
        cache: AsyncCache = Depends(get_async_cache),
    ):
        await check_permission(
            db=session,
            cache=cache,
            request_user=request_user,
            permissions=permission,
            mode=mode,
        )

    return dependency
