from datetime import date
from sqlalchemy import Select, select, delete, func

from be_kit.repositories import BaseRepository

from . import models


class RequisitionRepository(BaseRepository):
    """Repository for ``models.Requisition`` records.

    On creation it assigns each record a ``requisition_id`` built from
    today's date and a per-organization, per-day counter stored in
    ``models.RequisitionCounter``.
    """

    model = models.Requisition
    option_fields = [
        "pk",
        "requisition_id",
        "status",
        "payment_term",
        "delivery_term",
        "currency",
        "vendor",
    ]
    # Presumably read by BaseRepository to decide whether
    # ``add_organization_filter`` is applied -- confirm against be_kit.
    filter_by_organization = True

    def add_organization_filter(self, query: Select):
        """Return *query* narrowed to rows of ``self.organization_id``."""
        return query.filter_by(organization_id=self.organization_id)

    async def acreate(self, _commit: bool = True, **kwargs):
        """Create a requisition, generating its ``requisition_id`` first.

        Args:
            _commit: Forwarded unchanged to ``BaseRepository.acreate``. When
                falsy, an explicit ``pk`` (current maximum + 1) is injected
                into ``kwargs`` before delegating.
            **kwargs: Field values for the new record. Must contain
                ``organization_id``; any ``requisition_id`` supplied by the
                caller is overwritten.

        Returns:
            Whatever ``BaseRepository.acreate`` returns.

        Raises:
            KeyError: If ``organization_id`` is missing from ``kwargs``.
        """
        issued_on = date.today()
        sequence = await self.aget_counter(issued_on, kwargs["organization_id"])
        kwargs["requisition_id"] = self.get_requisition_id(issued_on, sequence)

        if not _commit:
            # NOTE(review): max(pk) + 1 is computed without any lock visible
            # here -- confirm callers serialize uncommitted creates.
            highest = await self.session.execute(select(func.max(self.model.pk)))
            kwargs["pk"] = (highest.scalar() or 0) + 1

        return await super().acreate(_commit, **kwargs)

    async def aget_counter(self, dt: date, organization_id: int):
        """Advance and return the counter for *organization_id* on date *dt*.

        If no counter row exists for the pair, one is created with value 0
        and 0 is returned; otherwise the stored value is incremented by one
        and the new value is returned. The session is flushed (not
        committed) before returning.
        """
        stmt = select(models.RequisitionCounter).filter_by(
            organization_id=organization_id, counter_date=dt
        )
        counter_row = (await self.session.execute(stmt)).scalar()

        if counter_row is not None:
            # NOTE(review): read-then-increment with no row lock visible
            # here -- confirm concurrent callers cannot get the same value.
            counter_row.counter += 1
        else:
            counter_row = models.RequisitionCounter(
                organization_id=organization_id, counter_date=dt, counter=0
            )

        self.session.add(counter_row)
        await self.session.flush()
        return counter_row.counter

    def get_requisition_id(self, dt: date, counter: int):
        """Format an identifier as ``PR`` + ``YYMMDD`` + zero-padded counter.

        The counter is padded to at least three digits, e.g. date 2024-01-05
        with counter 7 gives ``"PR240105007"``.
        """
        date_stamp = dt.strftime('%y%m%d')
        return f"PR{date_stamp}{counter:03d}"


class RequisitionItemRepository(BaseRepository):
    """Repository for ``models.RequisitionItem`` records."""

    model = models.RequisitionItem

    async def acreate(self, _commit: bool = True, **kwargs):
        """Create a requisition item via ``BaseRepository.acreate``.

        Args:
            _commit: Forwarded unchanged to the base implementation. When
                falsy, an explicit ``pk`` (current maximum + 1) is injected
                into ``kwargs`` before delegating.
            **kwargs: Field values for the new record.

        Returns:
            Whatever ``BaseRepository.acreate`` returns.
        """
        if _commit:
            return await super().acreate(_commit, **kwargs)

        # NOTE(review): max(pk) + 1 is computed without any lock visible
        # here -- confirm callers serialize uncommitted creates.
        highest = await self.session.execute(select(func.max(self.model.pk)))
        kwargs["pk"] = (highest.scalar() or 0) + 1
        return await super().acreate(_commit, **kwargs)

    async def adelete_by_requisition_id(self, requisition_id: int):
        """Delete every item whose ``requisition_id`` equals the given value.

        Only executes the DELETE statement on the session; no commit is
        issued here.
        """
        await self.session.execute(
            delete(self.model).filter_by(requisition_id=requisition_id)
        )
