from decimal import Decimal
from typing import Any

from be_uam.user.schemas import UserMin
from pydantic import PositiveInt, computed_field, model_validator
from be_kit.schemas import BaseSchema

from be_inventory.product.schemas import ProductMin
from be_inventory.warehouse.schemas import WarehouseMin

from .enums import TransactionDeliveryTerm, TransactionPaymentTerm, TransactionStatus
from be_core_services.country.schemas import CurrencyMin, VATMin


class CalculationMixin:
    """Declare optional calculation fields (subtotals, discounts, tax, totals).

    Annotation-only mixin with no behaviour of its own; classes inheriting it
    (BaseTransactionMixin in this module) inherit these annotations. Every
    field is typed ``Decimal | None``.
    """

    # NOTE(review): nothing in the visible code reads or writes these fields;
    # presumably they are filled/stored elsewhere — confirm.
    untaxed_subtotal: Decimal | None
    global_discount: Decimal | None
    tax_value: Decimal | None
    discount_per_item: Decimal | None
    total_after_tax_and_discount: Decimal | None
    global_dp: Decimal | None
    grand_total: Decimal | None



class BaseTransactionMixin(CalculationMixin):
    """Declare the header fields shared by transaction read and write schemas.

    Adds payment/delivery terms, an optional transaction-level discount and
    free-text notes on top of the fields inherited from CalculationMixin.
    """

    payment_term: TransactionPaymentTerm
    # Day count; the write schema requires it when the payment term's value
    # contains "N days".
    payment_n: PositiveInt | None
    # Down-payment rate / amount. On write, a non-None rate makes the write
    # schema overwrite payment_dp with total_net * payment_dp_rate.
    payment_dp_rate: Decimal | None
    payment_dp: Decimal | None
    delivery_term: TransactionDeliveryTerm | None
    # Transaction-level ("global") discount. On write it is spread over the
    # line items; an absolute amount is split evenly across them.
    discount_rate: Decimal | None
    discount_amount: Decimal | None
    notes: str | None


class BaseTransactionItemMixin:
    """Declare the per-line values shared by item read and write schemas."""

    quantity: Decimal
    price: Decimal
    # Item-level discount, either as a rate or as an absolute amount; the
    # write schema rejects payloads that set both.
    discount_rate: Decimal | None
    discount_amount: Decimal | None
    # Multiplied with the discounted line amount to obtain the item VAT on write.
    vat_rate: Decimal
    # NOTE(review): dead field kept as a comment; the write schema exposes the
    # per-line totals as computed fields instead. Delete if no longer needed.
    # subtotal: Decimal | None


class TransactionSummaryMixin:
    """Declare the summary ``total`` used by the non-write schemas here."""

    total: Decimal
    # NOTE(review): the remaining summary fields are commented out, so schemas
    # built on this mixin only expose `total`, while the write schemas also
    # compute total_discount / total_global_discount / total_vat / total_net.
    # Confirm whether these should be re-enabled or removed.
    # total_discount: Decimal
    # total_global_discount: Decimal | None
    # total_before_vat: Decimal | None
    # total_vat: Decimal
    # total_net: Decimal


class TransactionWriteMixin(BaseTransactionMixin, BaseSchema):
    """Shared input schema for creating or updating a transaction.

    Checks payment-term specific requirements, spreads a transaction-level
    discount over the line items and exposes the aggregated item totals.
    """

    currency_id: PositiveInt
    items: list["TransactionItemWriteMixin"]

    @model_validator(mode="before")
    @classmethod
    def validate_data(cls, data: Any) -> Any:
        """Validate payment terms and inject the global discount into items.

        Runs on the raw payload. Raises ValueError (reported by pydantic as a
        validation error) when a payment term's companion field is missing.
        Missing keys and non-numeric values are not rejected here; they are
        left for the regular field validation to report.
        """
        if not isinstance(data, dict):
            # e.g. an already-built model instance: nothing to pre-process.
            return data

        def as_decimal(value: Any) -> Decimal | None:
            # Raw payloads may carry numbers as str (pydantic serialises
            # Decimal to JSON strings), int or float; anything unparsable is
            # skipped here and rejected later by field validation.
            if value is None:
                return None
            try:
                return Decimal(str(value))
            except (ArithmeticError, ValueError, TypeError):
                return None

        payment_term = data.get("payment_term")
        if (
            isinstance(payment_term, str)
            and "N days" in payment_term
            and data.get("payment_n") is None
        ):
            raise ValueError(
                "payment_n is required for payment terms using number of days"
            )
        # A down payment needs a rate (payment_dp is then derived from it in
        # data_postprocess) or an explicit amount. Requiring both made a
        # submitted payment_dp pointless whenever a rate was given.
        if (
            payment_term == TransactionPaymentTerm.DP
            and data.get("payment_dp_rate") is None
            and data.get("payment_dp") is None
        ):
            raise ValueError(
                "payment_dp_rate or payment_dp is required for payment term "
                "Down Payment"
            )

        items = data.get("items")
        if not isinstance(items, (list, tuple)) or not items:
            # Nothing to spread a discount over (this also avoids dividing by
            # zero); field validation reports missing/invalid items.
            return data

        global_rate = as_decimal(data.get("discount_rate"))
        global_amount = as_decimal(data.get("discount_amount"))
        # NOTE(review): unlike the item schema, a transaction carrying both
        # discount_rate and discount_amount is not rejected; the item then
        # recomputes the amount from the rate. Confirm this is intended.
        for item in items:
            if not isinstance(item, dict):
                continue
            # Only values derived here may reach the item schema; drop any
            # client-supplied private keys first.
            item.pop("_global_discount_rate", None)
            item.pop("_global_discount_amount", None)
            if global_rate is not None:
                item["_global_discount_rate"] = Decimal("%.4f" % global_rate)
            if global_amount is not None:
                # The absolute discount is split evenly across all items.
                item["_global_discount_amount"] = Decimal(
                    "%.4f" % (global_amount / len(items))
                )
        return data

    @model_validator(mode="after")
    def data_postprocess(self):
        """Derive payment_dp from payment_dp_rate when a rate is given."""
        if self.payment_dp_rate is not None:
            self.payment_dp = self.total_net * self.payment_dp_rate
        return self

    @computed_field
    @property
    def total(self) -> Decimal:
        return Decimal("%.4f" % sum(item.total for item in self.items))

    @computed_field
    @property
    def total_discount(self) -> Decimal:
        return Decimal("%.4f" % sum(item.total_discount for item in self.items))

    @computed_field
    @property
    def total_global_discount(self) -> Decimal:
        return Decimal("%.4f" % sum(item.total_global_discount for item in self.items))

    @computed_field
    @property
    def total_vat(self) -> Decimal:
        return Decimal("%.4f" % sum(item.total_vat for item in self.items))

    @computed_field
    @property
    def total_net(self) -> Decimal:
        return Decimal("%.4f" % sum(item.total_net for item in self.items))

    @computed_field
    @property
    def total_net_after_dp(self) -> Decimal:
        return self.total_net - (self.payment_dp or Decimal(0))


class TransactionItemWriteMixin(BaseTransactionItemMixin, BaseSchema):
    """Shared input schema for a transaction line item.

    Derives the item discount amount from ``discount_rate``, applies the
    item's share of a transaction-level ("global") discount, and exposes the
    resulting line totals rounded to 4 decimal places.
    """

    vat_id: "PositiveInt | None"
    product_id: PositiveInt
    warehouse_id: PositiveInt

    # The item's share of the transaction-level discount. The parent write
    # schema injects it into the raw item payload under these key names.
    # Declared as pydantic private attributes so they live on each instance:
    # storing them on the class shared them between all items, and every new
    # validation overwrote the global discount seen by earlier items.
    _global_discount_rate: Decimal | None = None
    _global_discount_amount: Decimal | None = None

    @model_validator(mode="wrap")
    @classmethod
    def validate_data(cls, data: Any, handler: Any) -> Any:
        """Validate an item payload and attach its global-discount share.

        Raises ValueError when both discount_rate and discount_amount are set.
        The injected ``_global_discount_*`` keys are popped from the raw dict
        before field validation, stored on the built instance, and then
        data_postprocess derives the discount amounts.
        """
        is_raw = isinstance(data, dict)
        global_rate = global_amount = None
        if is_raw:
            if (
                data.get("discount_rate") is not None
                and data.get("discount_amount") is not None
            ):
                raise ValueError("Cannot have both discount_rate and discount_amount")
            # pop (not get) so the injected keys are consumed and never reach
            # field validation.
            global_rate = data.pop("_global_discount_rate", None)
            global_amount = data.pop("_global_discount_amount", None)

        instance = handler(data)
        if is_raw:
            instance._global_discount_rate = global_rate
            instance._global_discount_amount = global_amount
        return instance.data_postprocess()

    def data_postprocess(self):
        """Fill in the derived discount amounts and return ``self``.

        Not registered as a validator: validate_data calls it once the
        private global-discount state is attached to the instance. Running it
        again is harmless because it only recomputes from the stored rates.
        """
        if self.discount_rate is not None:
            self.discount_amount = self.total * self.discount_rate

        if self._global_discount_rate is not None:
            # The global rate applies to the amount left after the item's own
            # discount.
            if self.discount_amount is None:
                self._global_discount_amount = self.total * self._global_discount_rate
            else:
                self._global_discount_amount = (
                    self.total - self.discount_amount
                ) * self._global_discount_rate
        return self

    @computed_field
    @property
    def total(self) -> Decimal:
        return Decimal("%.4f" % (self.quantity * self.price))

    @computed_field
    @property
    def total_discount(self) -> Decimal:
        return (
            Decimal("%.4f" % self.discount_amount)
            if self.discount_amount
            else Decimal(0)
        )

    @computed_field
    @property
    def total_global_discount(self) -> Decimal:
        return (
            Decimal("%.4f" % self._global_discount_amount)
            if self._global_discount_amount
            else Decimal(0)
        )

    @computed_field
    @property
    def total_vat(self) -> Decimal:
        if self.vat_rate is not None:
            return Decimal(
                "%.4f"
                % (
                    (self.total - self.total_discount - self.total_global_discount)
                    * self.vat_rate
                )
            )
        return Decimal(0)

    @computed_field
    @property
    def total_net(self) -> Decimal:
        return Decimal(
            "%.4f"
            % (
                self.total
                - self.total_discount
                - self.total_global_discount
                + self.total_vat
            )
        )


class TransactionCreateMixin(TransactionWriteMixin, BaseSchema):
    # Transaction create payload; currently identical to TransactionWriteMixin.
    # BaseSchema is already a base of TransactionWriteMixin, so listing it
    # again does not change the MRO.
    pass


class TransactionUpdateMixin(TransactionWriteMixin, BaseSchema):
    # Transaction update payload; currently identical to TransactionWriteMixin.
    # BaseSchema is already a base of TransactionWriteMixin, so listing it
    # again does not change the MRO.
    pass


class TransactionItemMixin(
    TransactionSummaryMixin, BaseTransactionItemMixin, BaseSchema
):
    # Line-item representation with a summary `total` and nested VAT /
    # product / warehouse objects in place of the *_id fields used on write.
    vat: VATMin | None
    product: ProductMin
    warehouse: WarehouseMin


class TransactionMixin(TransactionSummaryMixin, BaseTransactionMixin, BaseSchema):
    # Full transaction representation: header fields, status, and nested
    # currency / items / approver objects in place of the write-side ids.
    status: TransactionStatus
    currency: CurrencyMin
    items: list[TransactionItemMixin]
    # Presumably None until the transaction is approved — confirm.
    approved_by: UserMin | None
    # Mirrors the write schema's computed total_net - payment_dp.
    total_net_after_dp: Decimal | None


class TransactionFilterMixin:
    """Declare optional transaction filter fields.

    Every field defaults to None. Presumably None means "do not filter on
    this attribute"; confirm against the query code that consumes them.
    """

    status: TransactionStatus | None = None
    payment_term: TransactionPaymentTerm | None = None
    delivery_term: TransactionDeliveryTerm | None = None
    currency_id: PositiveInt | None = None


class TransactionMinMixin(TransactionSummaryMixin, BaseTransactionMixin):
    """Declare the fields of a compact transaction representation.

    Re-declares payment_term as optional (BaseTransactionMixin declares it
    required) and repeats delivery_term with the type it already has.
    """

    # NOTE(review): unlike TransactionMixin this class does not inherit
    # BaseSchema; presumably concrete schemas combine it with BaseSchema
    # elsewhere — confirm.
    status: TransactionStatus
    payment_term: TransactionPaymentTerm | None
    delivery_term: TransactionDeliveryTerm | None
    currency: CurrencyMin
    total_net_after_dp: Decimal | None
