from decimal import Decimal
from typing import Any

from pydantic import PositiveFloat, PositiveInt, computed_field, model_validator
from be_procurement.transaction.enums import (
    TransactionDeliveryTerm,
    TransactionPaymentTerm,
    TransactionStatus,
)
from be_procurement.transaction.schemas import (
    BaseTransactionItemMixin,
    TransactionCreateMixin,
    TransactionFilterMixin,
    TransactionItemMixin,
    TransactionMinMixin,
    TransactionMixin,
    TransactionSummaryMixin,
    TransactionUpdateMixin,
    CalculationMixin
)
from be_core_services.country.schemas import CurrencyMin, VATMin
from be_inventory.product.schemas import ProductMin
from be_kit.paginations import BasePaginatedResponse
from be_kit.schemas import BaseORMSchema, BaseSchema
from be_uam.kit.schemas import MetadataMixin
from be_uam.user.schemas import UserMin

from ..order.schemas import OrderItem, OrderMin
from ..vendor.schemas import VendorMin
from . import enums


class BaseInvoice(CalculationMixin):
    # Header-level fields shared by the invoice read and write schemas.
    # Every field accepts None but has no default, so each one must be
    # supplied explicitly (possibly as null).
    payment_term: TransactionPaymentTerm | None
    payment_n: PositiveInt | None
    # Multiplied with the invoice net total by InvoiceWriteMixin to derive
    # `payment_dp`.
    payment_dp_rate: Decimal | None
    payment_dp: Decimal | None
    delivery_term: TransactionDeliveryTerm | None
    # Invoice-wide discount, either as a rate or as an absolute amount.
    discount_rate: Decimal | None
    discount_amount: Decimal | None
    notes: str | None


class InvoiceIdMixin:
    # Plain mixin (no base class) that contributes the `invoice_id` string
    # field to the schemas inheriting from it.
    invoice_id: str


class InvoiceStatusMixin:
    # Plain mixin (no base class) that contributes the invoice `status`
    # field, typed with this package's InvoiceStatus enum.
    status: enums.InvoiceStatus


class InvoiceWriteMixin(BaseInvoice, BaseSchema):
    # Write-side invoice payload: the BaseInvoice header fields plus the order
    # reference and the line items. Aggregate totals are computed fields
    # summed over `items`.
    # Documentation in this class uses `#` comments on purpose: pydantic
    # copies class and computed-field docstrings into the JSON schema.
    order_id: PositiveInt
    items: list["InvoiceItemWrite"]

    @staticmethod
    def _to_decimal(raw: Any) -> Decimal | None:
        """Return ``raw`` as a Decimal, or None if it is None or not numeric."""
        if raw is None:
            return None
        try:
            # Going through str() accepts int, float, Decimal and numeric
            # strings alike, without float representation noise.
            return Decimal(str(raw))
        except ArithmeticError:  # decimal.InvalidOperation
            return None

    @model_validator(mode="before")
    @classmethod
    def validate_data(cls, data: dict[str, Any]) -> dict[str, Any]:
        """Push the invoice-level discount down onto every raw item.

        Each item dict receives ``_global_discount_rate`` (the invoice rate
        rounded to 4 places) and/or ``_global_discount_amount`` (the invoice
        amount split evenly across the items, rounded to 4 places). The item
        schema consumes these keys.

        Missing or malformed input is passed through untouched so that the
        regular field validation reports it, instead of this hook raising
        ``KeyError``/``TypeError``. The caller's dicts are never mutated.

        Args:
            data: Raw payload as received by the model.

        Returns:
            The payload, with item dicts copied and extended when an
            invoice-level discount is present.
        """
        if not isinstance(data, dict):
            return data
        raw_items = data.get("items")
        if not isinstance(raw_items, (list, tuple)) or not raw_items:
            return data

        injected: dict[str, Decimal] = {}
        rate = cls._to_decimal(data.get("discount_rate"))
        if rate is not None:
            injected["_global_discount_rate"] = Decimal("%.4f" % rate)
        amount = cls._to_decimal(data.get("discount_amount"))
        if amount is not None:
            # NOTE: the split is even per line (not weighted by line value),
            # and the rounded shares may not add up exactly to the amount.
            injected["_global_discount_amount"] = Decimal(
                "%.4f" % (amount / len(raw_items))
            )
        if not injected:
            return data

        return {
            **data,
            "items": [
                {**item, **injected} if isinstance(item, dict) else item
                for item in raw_items
            ],
        }

    @model_validator(mode="after")
    def data_postprocess(self):
        """Derive ``payment_dp`` from ``payment_dp_rate`` when a rate is given."""
        if self.payment_dp_rate is not None:
            self.payment_dp = self.total_net * self.payment_dp_rate
        return self

    # Sum of the items' gross totals, rounded to 4 places.
    @computed_field
    @property
    def total(self) -> Decimal:
        return Decimal("%.4f" % sum(item.total for item in self.items))

    # Sum of the item-level discounts, rounded to 4 places.
    @computed_field
    @property
    def total_discount(self) -> Decimal:
        return Decimal("%.4f" % sum(item.total_discount for item in self.items))

    # Sum of the items' shares of the invoice-level discount.
    @computed_field
    @property
    def total_global_discount(self) -> Decimal:
        return Decimal("%.4f" % sum(item.total_global_discount for item in self.items))

    # Sum of the items' VAT amounts, rounded to 4 places.
    @computed_field
    @property
    def total_vat(self) -> Decimal:
        return Decimal("%.4f" % sum(item.total_vat for item in self.items))

    # Sum of the items' net totals, rounded to 4 places.
    @computed_field
    @property
    def total_net(self) -> Decimal:
        return Decimal("%.4f" % sum(item.total_net for item in self.items))

    # Net total minus `payment_dp`; a missing `payment_dp` counts as zero.
    @computed_field
    @property
    def total_net_after_dp(self) -> Decimal:
        return self.total_net - (self.payment_dp or Decimal(0))


class InvoiceItemWrite(BaseTransactionItemMixin, BaseSchema):
    # Write-side invoice line. Besides its own fields it carries this line's
    # share of the invoice-level discount, which arrives in the raw payload
    # under the `_global_discount_rate` / `_global_discount_amount` keys.
    # Documentation in this class uses `#` comments on purpose: pydantic
    # copies class and computed-field docstrings into the JSON schema.
    order_item_id: PositiveInt
    quantity: Decimal

    # Kept per instance (pydantic private attributes). Storing them on the
    # class would let the validation of one item overwrite the discount share
    # that other, already validated items read later.
    _global_discount_rate: Decimal | None = None
    _global_discount_amount: Decimal | None = None

    @model_validator(mode="wrap")
    @classmethod
    def _bind_global_discount(cls, data: Any, handler: Any) -> Any:
        """Move the injected global-discount keys onto the validated instance.

        The keys are stripped from a copy of the raw dict, the remaining
        payload is validated normally through ``handler``, and the stripped
        values are then stored on the resulting instance.

        Args:
            data: Raw item payload, or any other input accepted by the model.
            handler: Pydantic callable that runs the regular validation.

        Returns:
            The validated item.
        """
        if not isinstance(data, dict):
            # e.g. an already validated instance: keep whatever it carries.
            return handler(data)
        payload = dict(data)  # do not mutate the caller's dict
        rate = payload.pop("_global_discount_rate", None)
        amount = payload.pop("_global_discount_amount", None)
        item = handler(payload)
        item._global_discount_rate = rate
        item._global_discount_amount = amount
        return item

    @model_validator(mode="before")
    @classmethod
    def validate_data(cls, data: dict[str, Any]) -> dict[str, Any]:
        """Reject payloads that set both an item discount rate and amount.

        Missing keys are tolerated here so that the regular field validation
        reports them instead of a ``KeyError`` escaping from this hook.

        Raises:
            ValueError: If both ``discount_rate`` and ``discount_amount`` are
                non-null.
        """
        if not isinstance(data, dict):
            return data
        if (
            data.get("discount_rate") is not None
            and data.get("discount_amount") is not None
        ):
            raise ValueError("Cannot have both discount_rate and discount_amount")
        return data

    @model_validator(mode="after")
    def data_postprocess(self):
        """Derive ``discount_amount`` from ``discount_rate`` when a rate is given."""
        if self.discount_rate is not None:
            self.discount_amount = self.total * self.discount_rate
        return self

    # Gross line total: quantity * price, rounded to 4 places.
    @computed_field
    @property
    def total(self) -> Decimal:
        return Decimal("%.4f" % (self.quantity * self.price))

    # Item-level discount, rounded to 4 places; zero when unset.
    @computed_field
    @property
    def total_discount(self) -> Decimal:
        return (
            Decimal("%.4f" % self.discount_amount)
            if self.discount_amount
            else Decimal(0)
        )

    # This line's share of the invoice-level discount. A global rate takes
    # precedence over a global amount and is applied to the line total after
    # the item-level discount has been subtracted.
    @computed_field
    @property
    def total_global_discount(self) -> Decimal:
        rate = self._global_discount_rate
        if rate is None:
            share = self._global_discount_amount
        elif self.discount_amount is None:
            share = self.total * rate
        else:
            share = (self.total - self.discount_amount) * rate
        return Decimal("%.4f" % share) if share else Decimal(0)

    # VAT on the discounted line total; zero when no VAT rate is set.
    @computed_field
    @property
    def total_vat(self) -> Decimal:
        if self.vat_rate is not None:
            return Decimal(
                "%.4f"
                % (
                    (self.total - self.total_discount - self.total_global_discount)
                    * self.vat_rate
                )
            )
        return Decimal(0)

    # Net line total: gross minus both discounts plus VAT, rounded to 4 places.
    @computed_field
    @property
    def total_net(self) -> Decimal:
        return Decimal(
            "%.4f"
            % (
                self.total
                - self.total_discount
                - self.total_global_discount
                + self.total_vat
            )
        )


class InvoiceCreate(InvoiceWriteMixin, BaseSchema):
    # Invoice payload schema (for creation, judging by the name). Adds
    # nothing of its own: all fields, validators and computed totals come
    # from InvoiceWriteMixin.
    pass


class InvoiceUpdate(InvoiceWriteMixin, BaseSchema):
    # Invoice payload schema (for updates, judging by the name). Adds
    # nothing of its own: all fields, validators and computed totals come
    # from InvoiceWriteMixin.
    pass


class InvoiceItem(TransactionSummaryMixin, BaseTransactionItemMixin, BaseORMSchema):
    # Read-side invoice line: the fields of the two transaction mixins plus
    # the related order item and `total_paid`.
    # NOTE(review): `total_paid` is presumably the amount already paid
    # against this line; confirm against the ORM model.
    order_item: OrderItem
    total_paid: Decimal


class Invoice(
    TransactionSummaryMixin,
    InvoiceStatusMixin,
    InvoiceIdMixin,
    BaseInvoice,
    MetadataMixin,
    BaseORMSchema,
):
    # Detailed invoice read schema: the mixin fields plus the related order,
    # the paid amount and the full list of invoice lines.
    order: OrderMin
    total_paid: Decimal
    items: list[InvoiceItem]
    # Nullable on the read side, unlike the always-computed write-side value.
    total_net_after_dp: Decimal | None


class InvoiceList(
    TransactionSummaryMixin,
    InvoiceStatusMixin,
    InvoiceIdMixin,
    BaseInvoice,
    MetadataMixin,
    BaseORMSchema,
):
    # Invoice read schema used as the element type of PaginatedInvoice. Same
    # header data as the detailed schema but without the `items` list.
    order: OrderMin
    total_paid: Decimal
    total_net_after_dp: Decimal | None


class PaginatedInvoice(BasePaginatedResponse):
    # Paginated response whose `items` are InvoiceList entries.
    items: list[InvoiceList]


class InvoiceFilter(BaseSchema):
    # Optional invoice filters; every field defaults to None.
    # NOTE(review): the `a___b__lookup` field names look like relation paths
    # plus lookup operators consumed by a filtering layer outside this file;
    # confirm there before renaming any field.
    status: enums.InvoiceStatus | None = None
    order_id: PositiveInt | None = None
    invoice_id__icontains: str | None = None
    order___quotation___quotation_id__icontains: str | None = None
    order___order_id__icontains: str | None = None
    order___vendor___contact___name__icontains: str | None = None


class InvoiceMin(InvoiceStatusMixin, InvoiceIdMixin, BaseInvoice, BaseORMSchema):
    # Compact invoice read schema: status, invoice id and header fields, with
    # a nullable order reference and a nullable net-after-down-payment total.
    order: OrderMin | None
    total_net_after_dp: Decimal | None


class InvoiceOpt(InvoiceMin):
    # Same shape as InvoiceMin under a separate name; adds no fields. Used as
    # the element type of PaginatedInvoiceOpt.
    pass


class PaginatedInvoiceOpt(BasePaginatedResponse):
    # Paginated response whose `items` are InvoiceOpt entries.
    items: list[InvoiceOpt]


# Rebuild the pydantic schemas of these imported models at import time.
# NOTE(review): presumably they contain forward references that can only be
# resolved once the modules imported above are loaded; confirm before
# removing or reordering these calls.
VendorMin.model_rebuild()
ProductMin.model_rebuild()
UserMin.model_rebuild()
CurrencyMin.model_rebuild()
VATMin.model_rebuild()
