from sqlalchemy import Select

from be_kit.repositories import BaseRepository
from . import models


class TagRepository(BaseRepository):
    """Repository for ``models.Tag`` rows.

    ``filter_by_organization`` is presumably read by ``BaseRepository`` to
    decide whether ``add_organization_filter`` is applied to its queries —
    TODO confirm against the base class.
    """

    model = models.Tag
    filter_by_organization = True

    def add_organization_filter(self, query: Select):
        """Return *query* narrowed to rows whose ``organization_id`` equals
        ``self.organization_id``."""
        return query.filter_by(organization_id=self.organization_id)


class TagRepositoryMixin:
    """Mixin letting ``acreate``/``aupdate`` accept a ``tag_ids`` keyword.

    Meant to appear before a repository base class in the MRO, so that
    ``super()`` resolves to that base's ``acreate``/``aupdate``.

    When ``tag_ids`` is supplied:

    1. It is removed from the keyword arguments.
    2. The remaining arguments are forwarded with ``_commit=False``.
    3. The matching ``models.Tag`` rows are assigned to ``obj.tags``.
    4. The session is committed only if ``_commit`` is true.

    Without ``tag_ids`` the call is forwarded to the base unchanged.

    Assumes the concrete class provides ``self.session`` (an async session
    exposing ``execute``/``commit``) — TODO confirm against BaseRepository.
    """

    async def _aload_tags(self, tag_ids):
        """Return the ``models.Tag`` rows whose ``pk`` is in *tag_ids*.

        An empty *tag_ids* yields ``[]`` without a database round-trip (an
        empty ``IN`` matches nothing anyway).

        NOTE(review): unlike ``TagRepository.add_organization_filter``, this
        lookup is not scoped by organization. IDs that match no row are also
        silently dropped. Confirm both are intended.
        """
        if not tag_ids:
            return []
        result = await self.session.execute(
            Select(models.Tag).filter(models.Tag.pk.in_(tag_ids))
        )
        return result.scalars().all()

    async def acreate(self, _commit=True, **kwargs):
        """Create an object, optionally attaching tags listed in ``tag_ids``.

        Args:
            _commit: Commit the session once the object (and its tags) are set.
            **kwargs: Field values forwarded to the base ``acreate``. An
                optional ``tag_ids`` entry is consumed here. A falsy value
                (``None`` or empty) leaves ``obj.tags`` untouched.

        Returns:
            The object returned by the base ``acreate``.
        """
        if "tag_ids" not in kwargs:
            # Forward _commit by keyword, as the tag branch below does, so this
            # works regardless of where _commit sits in the base signature.
            return await super().acreate(_commit=_commit, **kwargs)

        tag_ids = kwargs.pop("tag_ids")
        # Defer the base commit so a single commit below covers both the new
        # object and its tag assignment.
        obj = await super().acreate(_commit=False, **kwargs)
        if tag_ids:
            obj.tags = await self._aload_tags(tag_ids)
        if _commit:
            # NOTE(review): obj is not refreshed after this commit. If the
            # session expires on commit, later attribute access may need a
            # reload — confirm against BaseRepository's own commit path.
            await self.session.commit()
        return obj

    async def aupdate(self, pk: int, _commit=True, **kwargs):
        """Update object *pk*, optionally replacing its tags from ``tag_ids``.

        Args:
            pk: Primary key of the object to update, forwarded to the base.
            _commit: Commit the session once the update (and tags) are applied.
            **kwargs: Field values forwarded to the base ``aupdate``. An
                optional ``tag_ids`` entry is consumed here. ``None`` leaves
                ``obj.tags`` untouched. Any other value replaces it, and an
                empty list clears the tags.

        Returns:
            The object returned by the base ``aupdate``.
        """
        if "tag_ids" not in kwargs:
            # Keyword forwarding, consistent with the tag branch below.
            return await super().aupdate(pk, _commit=_commit, **kwargs)

        tag_ids = kwargs.pop("tag_ids")
        obj = await super().aupdate(pk, _commit=False, **kwargs)
        if tag_ids is not None:
            # NOTE(review): replacing a collection on a persistent object may
            # make SQLAlchemy load the old collection. Under AsyncSession that
            # must already be loaded (e.g. eager loading), or it raises —
            # confirm the relationship's loading strategy.
            obj.tags = await self._aload_tags(tag_ids)
        if _commit:
            await self.session.commit()
        return obj
