from sqlalchemy import select
from sqlalchemy.sql.expression import Select

from be_kit.repositories import BaseRepository

from ..auth.enums import AuthorizationMode
from ..permission.exceptions import InvalidPermissionActionError
from . import models


class PermissionRepository(BaseRepository):
    """Data access for ``models.Permission`` rows."""

    model = models.Permission

    async def aget_permission_by_codes(
        self, codes: list[str]
    ) -> list[models.Permission]:
        """Return the permissions whose ``code`` is in *codes*.

        Repeated entries in *codes* are tolerated: only every *distinct* code
        must match a permission.

        Raises:
            InvalidPermissionActionError: if the number of matching rows differs
                from the number of distinct codes requested (i.e. some code
                does not match an existing permission).
        """
        query: Select = select(self.model).filter(self.model.code.in_(codes))
        result = await self.session.execute(query)
        objs = list(result.scalars())
        # ``IN`` yields each matching row once, so compare against the distinct
        # codes; otherwise a repeated code would be reported as missing.
        if len(objs) != len(set(codes)):
            raise InvalidPermissionActionError(codes)
        return objs

    def populate_initial(self, permissions: list[dict]) -> list:
        """Create the permissions from *permissions* that don't exist yet.

        Each dict must contain a ``"code"`` key, which is used to look up an
        existing permission. Entries whose code already exists in the database,
        or appeared earlier in *permissions*, are skipped. The remaining keys of
        each new entry are passed to ``self.create``.

        All creations are committed together at the end. On any error,
        including a failed commit, the session is rolled back and the exception
        is re-raised.

        Unlike the ``a*`` methods, this uses ``self.session`` synchronously.
        NOTE(review): presumably called with a sync Session; confirm.

        Returns:
            The newly created permissions, in input order.
        """
        created_permissions = []
        seen_codes = set()
        try:
            for permission in permissions:
                code = permission["code"]
                # Skip repeats within this batch explicitly rather than relying
                # on session autoflush to make the earlier one visible.
                if code in seen_codes:
                    continue
                seen_codes.add(code)
                existing_permission = self.session.execute(
                    select(self.model).filter_by(code=code)
                ).scalar_one_or_none()
                if existing_permission is not None:
                    continue
                # NOTE(review): "code" is deliberately not passed to create();
                # presumably the model derives it from the other fields — confirm.
                fields = {
                    key: value for key, value in permission.items() if key != "code"
                }
                created_permission = self.create(**fields, _commit=False)
                created_permissions.append(created_permission)
            # Commit inside the try so a failing commit is also rolled back.
            self.session.commit()
        except BaseException:
            self.session.rollback()
            raise
        return created_permissions


class GroupPermissionRepository(BaseRepository):
    """Data access for ``models.GroupPermission`` (group-to-permission links)."""

    model = models.GroupPermission

    async def aget_group_permissions_by_ids(
        self, group_id: int, permission_ids: list[int]
    ):
        """Return the rows linking *group_id* to any of *permission_ids*.

        The return value is the query's scalar result (``.scalars()``), not a
        list. Callers that need to reuse it should materialize it with
        ``list()`` or ``set()``.
        """
        query: Select = select(self.model).filter(
            self.model.group_id == group_id,
            self.model.permission_id.in_(permission_ids),
        )
        result = await self.session.execute(query)
        return result.scalars()

    async def acheck_group_permission(
        self,
        group_id: int,
        codes: list[str],
        mode: AuthorizationMode = AuthorizationMode.ANY,
    ) -> bool:
        """Check whether *group_id* holds the permissions named by *codes*.

        With ``AuthorizationMode.ANY``, the result is False when the group holds
        none of them. With ``AuthorizationMode.ALL``, the result is False unless
        the group holds every one of them. Any other mode value returns True.

        Raises:
            InvalidPermissionActionError: propagated from
                ``PermissionRepository.aget_permission_by_codes`` when a code
                does not match an existing permission.
        """
        perm_repo = PermissionRepository(self.session)
        perms = await perm_repo.aget_permission_by_codes(codes)
        requested_ids = {perm.pk for perm in perms}
        group_perms = await self.aget_group_permissions_by_ids(
            group_id, list(requested_ids)
        )
        # Count distinct granted permissions: duplicate link rows for the same
        # permission must not make up for a permission the group lacks.
        granted_ids = {gp.permission_id for gp in group_perms}
        if mode == AuthorizationMode.ANY and not granted_ids:
            return False
        if mode == AuthorizationMode.ALL and len(granted_ids) != len(requested_ids):
            return False
        return True
