from datetime import date
from decimal import Decimal
from sqlalchemy import Select, select

from be_kit.repositories import BaseRepository
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import date
from . import models, enums


class ExpenseCategoryRepository(BaseRepository):
    """Repository for ``models.ExpenseCategory`` rows."""

    model = models.ExpenseCategory
    # Presumably tells BaseRepository to apply add_organization_filter to its
    # queries -- confirm against be_kit.
    filter_by_organization = True

    def add_organization_filter(self, query: Select):
        """Return ``query`` narrowed to rows of this repository's organization."""
        return query.filter_by(organization_id=self.organization_id)


class ExpenseRepository(BaseRepository):
    """Repository for ``models.Expense`` rows.

    On creation it resolves the expense category and assigns a date-based
    identifier drawn from a per-organization daily counter.
    """

    model = models.Expense
    # Presumably tells BaseRepository to apply add_organization_filter to its
    # queries -- confirm against be_kit.
    filter_by_organization = True

    def add_organization_filter(self, query: Select):
        """Return ``query`` narrowed to rows of this repository's organization."""
        return query.filter_by(organization_id=self.organization_id)

    async def acreate(self, _commit: bool = True, **kwargs):
        """Create an expense, resolving its category and generating its identifier.

        ``kwargs`` must contain ``category_id`` (replaced by the loaded
        ``category`` object) and ``organization_id`` (used to pick the
        counter); a missing key raises ``KeyError``. The remaining keyword
        arguments, plus ``category`` and ``identifier``, are forwarded to the
        parent ``acreate`` together with ``_commit``.
        """
        categories = ExpenseCategoryRepository(self.session, self.organization_id)
        category_id = kwargs.pop("category_id")
        # aget_or_404 presumably raises a not-found error for an unknown id
        # within this organization -- confirm against be_kit.
        kwargs["category"] = await categories.aget_or_404(category_id)

        # The same date is used for the counter lookup and the identifier
        # so both always agree.
        issued_on = date.today()
        sequence = await self.aget_counter(issued_on, kwargs["organization_id"])
        kwargs["identifier"] = self.get_entry_by_id(issued_on, sequence)

        return await super().acreate(_commit, **kwargs)

    async def aget_counter(self, dt: date, organization_id: int):
        """Advance and return the expense counter for an organization and date.

        Looks up the ``models.ExpenseCounter`` row matching
        ``organization_id`` and ``dt``. If none exists, a new row is created
        with ``counter=1``; otherwise the existing counter is incremented by
        one. The change is flushed to the session (not committed here) and
        the resulting counter value is returned.

        NOTE(review): this method takes no row lock around the
        read-increment-write; confirm whether concurrent callers are
        serialized elsewhere (e.g. by a unique constraint or isolation level).
        """
        stmt = select(models.ExpenseCounter).filter_by(
            organization_id=organization_id, counter_date=dt
        )
        counter_row = (await self.session.execute(stmt)).scalar()

        if counter_row is None:
            # First expense for this organization on this date: numbering
            # begins at 1 rather than 0.
            counter_row = models.ExpenseCounter(
                organization_id=organization_id, counter_date=dt, counter=1
            )
        else:
            counter_row.counter += 1

        self.session.add(counter_row)
        await self.session.flush()
        return counter_row.counter

    def get_entry_by_id(self, dt: date, counter: int):
        """Build an expense identifier from a date and a counter value.

        The result is ``"P"`` followed by the date as ``YYMMDD`` and the
        counter zero-padded to at least three digits, e.g.
        ``date(2024, 3, 5)`` with counter ``7`` gives ``"P240305007"``.
        """
        date_stamp = dt.strftime("%y%m%d")
        return f"P{date_stamp}{counter:03d}"
