from datetime import date
from decimal import Decimal
from sqlalchemy import Select, select

from be_kit.repositories import BaseRepository
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import date
from . import models, enums


class ProductCategoryRepository(BaseRepository):
    """Repository for ``ProductCategory`` rows, filtered by organization."""

    model = models.ProductCategory
    filter_by_organization = True

    def add_organization_filter(self, query: Select):
        """Return ``query`` narrowed to rows whose ``organization_id`` matches this repository's."""
        return query.filter_by(organization_id=self.organization_id)


class ProductRepository(BaseRepository):
    """Repository for ``Product`` rows; assigns a dated identifier on creation."""

    model = models.Product
    filter_by_organization = True

    def add_organization_filter(self, query: Select):
        """Return ``query`` narrowed to rows whose ``organization_id`` matches this repository's."""
        return query.filter_by(organization_id=self.organization_id)

    async def acreate(self, _commit: bool = True, **kwargs):
        """Create a product.

        ``kwargs["category_id"]`` is replaced by the resolved ``category``
        object (looked up with ``aget_or_404`` in a ``ProductCategoryRepository``
        bound to this repository's session and organization). An
        ``identifier`` built from today's date and the organization's daily
        counter is added before delegating to the base ``acreate``.

        Raises:
            KeyError: if ``category_id`` or ``organization_id`` is absent
                from ``kwargs``.
            Whatever ``aget_or_404`` raises when the category is not found.
        """
        categories = ProductCategoryRepository(self.session, self.organization_id)
        category_id = kwargs.pop("category_id")
        kwargs["category"] = await categories.aget_or_404(category_id)

        created_on = date.today()
        sequence = await self.aget_counter(created_on, kwargs["organization_id"])
        kwargs["identifier"] = self.get_entry_by_id(created_on, sequence)
        return await super().acreate(_commit, **kwargs)

    async def aget_counter(self, dt: date, organization_id: int):
        """Advance and return the ``ProductCounter`` value for (organization, ``dt``).

        The first call for a given organization/date creates the row with
        ``counter=1``; later calls increment the existing row. The session is
        flushed before the new value is returned.
        """
        lookup = select(models.ProductCounter).filter_by(
            organization_id=organization_id, counter_date=dt
        )
        row = (await self.session.execute(lookup)).scalar()
        if row is not None:
            row.counter += 1
        else:
            # Counters are 1-based.
            row = models.ProductCounter(
                organization_id=organization_id, counter_date=dt, counter=1
            )
        self.session.add(row)
        await self.session.flush()
        return row.counter

    def get_entry_by_id(self, dt: date, counter: int):
        """Build a product identifier: ``"P"`` + ``YYMMDD`` + counter padded to at least 3 digits."""
        stamp = dt.strftime("%y%m%d")
        return f"P{stamp}{counter:03d}"


class ProductStockRepository(BaseRepository):
    """Repository for per-(product, warehouse) stock levels.

    Each ``ProductStock`` row carries ``available``, ``reserved`` and
    ``locked`` quantities. Every public mutator below loads the row for
    (``product_id``, ``warehouse_id``, ``self.organization_id``), adjusts it,
    flushes the session (it does not commit) and returns the updated row.

    NOTE(review): rows are read, modified in Python and written back without
    a row lock (e.g. ``SELECT ... FOR UPDATE``), so concurrent transactions
    touching the same row may lose updates -- confirm callers serialize this.
    NOTE(review): ``quantity`` is never checked for sign; a negative value
    reverses the direction of every operation.
    """

    model = models.ProductStock
    filter_by_organization = True

    def add_organization_filter(self, query: Select):
        """Return ``query`` narrowed to rows whose ``organization_id`` matches this repository's."""
        query = query.filter_by(organization_id=self.organization_id)
        return query

    async def _aget_stock(
        self, product_id: int, warehouse_id: int, not_found_message: str
    ) -> models.ProductStock:
        """Load the stock row for (product, warehouse) in this organization.

        Raises:
            ValueError: with ``not_found_message`` if no such row exists.
        """
        query = select(self.model).filter_by(
            product_id=product_id,
            warehouse_id=warehouse_id,
            organization_id=self.organization_id,
        )
        result = await self.session.execute(query)
        stock = result.scalar_one_or_none()
        if stock is None:
            raise ValueError(not_found_message)
        return stock

    async def _asave(self, stock: models.ProductStock) -> models.ProductStock:
        """Stage ``stock`` in the session, flush it, and return it."""
        self.session.add(stock)
        await self.session.flush()
        return stock

    async def add_reservation(self, product_id: int, warehouse_id: int, quantity: Decimal):
        """Increase ``reserved`` by ``quantity``.

        Unlike ``move_available_to_reserved``, this does not decrease
        ``available`` and does not check any limit.

        Raises:
            ValueError: if the stock row does not exist.
        """
        stock = await self._aget_stock(
            product_id, warehouse_id, "ProductStock not found for reservation"
        )
        stock.reserved += quantity
        return await self._asave(stock)

    async def move_reservation_to_locked(self, product_id: int, warehouse_id: int, quantity: Decimal):
        """Move ``quantity`` from ``reserved`` to ``locked``.

        Raises:
            ValueError: if the stock row does not exist or ``reserved`` is
                smaller than ``quantity``.
        """
        stock = await self._aget_stock(
            product_id, warehouse_id, "ProductStock not found for reservation"
        )
        if stock.reserved < quantity:
            raise ValueError("Not enough reserved stock to move")
        stock.reserved -= quantity
        stock.locked += quantity
        return await self._asave(stock)

    async def move_available_to_reserved(self, product_id: int, warehouse_id: int, quantity: Decimal):
        """Move ``quantity`` from ``available`` to ``reserved``.

        Raises:
            ValueError: if the stock row does not exist or ``available`` is
                smaller than ``quantity``.
        """
        stock = await self._aget_stock(
            product_id, warehouse_id, "ProductStock not found for reservation"
        )
        if stock.available < quantity:
            raise ValueError("Not enough available stock to reserve")
        stock.available -= quantity
        stock.reserved += quantity
        return await self._asave(stock)

    async def add_to_available(self, product_id: int, warehouse_id: int, quantity: Decimal):
        """Increase ``available`` by ``quantity``.

        Raises:
            ValueError: if the stock row does not exist.
        """
        stock = await self._aget_stock(
            product_id, warehouse_id, "ProductStock not found for available stock update"
        )
        stock.available += quantity
        return await self._asave(stock)

    async def remove_from_available(self, product_id: int, warehouse_id: int, quantity: Decimal):
        """Decrease ``available`` by ``quantity``.

        Raises:
            ValueError: if the stock row does not exist or ``available`` is
                smaller than ``quantity``.
        """
        stock = await self._aget_stock(
            product_id, warehouse_id, "ProductStock not found for available stock update"
        )
        if stock.available < quantity:
            raise ValueError("Not enough available stock to remove")
        stock.available -= quantity
        return await self._asave(stock)

    async def set_stock(self, product_id: int, warehouse_id: int, quantity: Decimal):
        """Overwrite ``available`` with ``quantity``; ``reserved``/``locked`` are untouched.

        Raises:
            ValueError: if the stock row does not exist.
        """
        stock = await self._aget_stock(
            product_id, warehouse_id, "ProductStock not found for available stock update"
        )
        stock.available = quantity
        return await self._asave(stock)


class ProductMovementRepository(BaseRepository):
    """Repository for ``ProductMovement`` rows and the stock changes they drive."""

    model = models.ProductMovement
    filter_by_organization = True

    def add_organization_filter(self, query: Select):
        """Return ``query`` narrowed to rows whose ``organization_id`` matches this repository's."""
        query = query.filter_by(organization_id=self.organization_id)
        return query

    async def acreate(self, _commit: bool = True, **kwargs):
        """Create a movement with a generated identifier and its items.

        ``kwargs["items"]`` (optional, defaults to no items) may mix plain
        dicts and ``ProductMovementItem`` instances; dicts are converted to
        ``ProductMovementItem(**item)`` before delegating to the base
        ``acreate``.

        Raises:
            KeyError: if ``organization_id`` is absent from ``kwargs``.
        """
        today = date.today()
        organization_id = kwargs["organization_id"]
        counter = await self.aget_counter(today, organization_id)
        kwargs["identifier"] = self.get_entry_by_id(today, counter, organization_id)
        items_data = kwargs.pop("items", [])
        kwargs["items"] = [
            models.ProductMovementItem(**item) if isinstance(item, dict) else item
            for item in items_data
        ]
        return await super().acreate(_commit, **kwargs)

    async def aget_counter(self, dt: date, organization_id: int):
        """Advance and return the ``ProductCounter`` value for (organization, ``dt``).

        The first call for a given organization/date creates the row with
        ``counter=1``; later calls increment it. The session is flushed
        before the value is returned.

        NOTE(review): this reads and writes the same ``models.ProductCounter``
        row that ``ProductRepository.aget_counter`` uses, so product and
        movement creations on the same day advance one shared counter --
        confirm a dedicated movement counter is not intended.
        """
        query = select(models.ProductCounter).filter_by(
            organization_id=organization_id, counter_date=dt
        )
        obj = await self.session.execute(query)
        obj = obj.scalar()
        if obj is None:
            obj = models.ProductCounter(
                organization_id=organization_id, counter_date=dt, counter=1  # Start at 1, not 0
            )
            self.session.add(obj)
        else:
            obj.counter += 1
            self.session.add(obj)
        await self.session.flush()
        return obj.counter

    async def move_stock(
        self,
        movement_type: str,
        status: str,
        product_id: int,
        source_warehouse_id: int | None,
        destination_warehouse_id: int | None,
        quantity: Decimal,
        product_stock_repo: "ProductStockRepository",
    ):
        """Apply the stock change implied by one movement line.

        Dispatch on ``movement_type``:

        * OUT        -- move reserved -> locked in the source warehouse.
        * IN         -- add to available in the destination warehouse.
        * RESERVED   -- move available -> reserved in the source warehouse.
        * SCRAP      -- remove from available in the destination warehouse.
        * ADJUSTMENT -- overwrite available in the destination warehouse.
        * INTERNAL   -- remove from available in the source warehouse, then
          add to available in the destination warehouse.

        Any other ``movement_type`` is a no-op. ``status`` is currently unused.

        Raises:
            ValueError: propagated from ``product_stock_repo`` when a stock
                row is missing or holds too little quantity.
        """
        # NOTE(review): this short-circuit applies to every movement type, not
        # only INTERNAL transfers -- e.g. an IN or ADJUSTMENT recorded with the
        # same warehouse on both sides (or both None) changes nothing. Confirm.
        if source_warehouse_id == destination_warehouse_id:
            return

        if movement_type == enums.ProductMovementType.OUT:
            # Move stock from reserved to locked in the source warehouse.
            await product_stock_repo.move_reservation_to_locked(
                product_id=product_id,
                warehouse_id=source_warehouse_id,
                quantity=quantity,
            )

        elif movement_type == enums.ProductMovementType.IN:
            # Increase available stock in the destination warehouse.
            await product_stock_repo.add_to_available(
                product_id=product_id,
                warehouse_id=destination_warehouse_id,
                quantity=quantity,
            )

        elif movement_type == enums.ProductMovementType.RESERVED:
            # Move stock from available to reserved in the source warehouse.
            await product_stock_repo.move_available_to_reserved(
                product_id=product_id,
                warehouse_id=source_warehouse_id,
                quantity=quantity,
            )

        elif movement_type == enums.ProductMovementType.SCRAP:
            # Remove scrapped stock from the destination warehouse's available
            # quantity. NOTE(review): confirm scrap is meant to be booked
            # against the destination rather than the source warehouse.
            await product_stock_repo.remove_from_available(
                product_id=product_id,
                warehouse_id=destination_warehouse_id,
                quantity=quantity,
            )

        elif movement_type == enums.ProductMovementType.ADJUSTMENT:
            # Overwrite the destination warehouse's available quantity.
            await product_stock_repo.set_stock(
                product_id=product_id,
                warehouse_id=destination_warehouse_id,
                quantity=quantity,
            )

        elif movement_type == enums.ProductMovementType.INTERNAL:
            # Decrease stock in the source warehouse...
            await product_stock_repo.remove_from_available(
                product_id=product_id,
                warehouse_id=source_warehouse_id,
                quantity=quantity,
            )
            # ...and increase it in the destination warehouse.
            await product_stock_repo.add_to_available(
                product_id=product_id,
                warehouse_id=destination_warehouse_id,
                quantity=quantity,
            )

    def get_entry_by_id(self, dt: date, counter: int, organization_id: int):
        """Build a movement identifier: ``"PM"`` + ``YYMMDD`` + organization id + counter.

        NOTE(review): organization id and counter are concatenated with no
        separator or padding, so identifiers can collide across organizations
        (org 1 / counter 12 and org 11 / counter 2 both yield ``PM<date>112``).
        This is only safe if identifiers need to be unique per organization.
        """
        return f"PM{dt.strftime('%y%m%d')}{organization_id}{counter}"
