import io

import pandas as pd
from sqlalchemy.ext.asyncio import AsyncSession

from be_kit.paginations import PaginationQuery
from be_uam.user.models import User

from be_accounting.coa.utils import list_account_mappings
from be_accounting.journal import utils as journal_utils
from be_accounting.journal import enums as journal_enums, schemas as journal_schemas
from . import enums, schemas, repositories
from be_core_services.country.repositories import CurrencyRepository
from be_core_services.country.schemas import CurrencyFilter
from .models import ProductMovementItem

# ProductCategory


async def create_product_category(
    db: AsyncSession,
    product_category: schemas.ProductCategoryCreate,
    request_user: User,
):
    """Create a product category owned by the requesting user's organization.

    The new row is stamped with the organization of ``request_user``'s group
    and with ``request_user`` as its creator.
    """
    user_group = await request_user.awaitable_attrs.group
    payload = product_category.model_dump()
    category_repo = repositories.ProductCategoryRepository(db)
    return await category_repo.acreate(
        **payload,
        organization_id=user_group.organization_id,
        created_by_id=request_user.pk,
    )


async def retrieve_product_category(
    db: AsyncSession,
    pk: int,
    request_user: User,
):
    """Return the product category ``pk`` from the requesting user's organization.

    The repository is scoped to the organization of ``request_user``'s group;
    the lookup goes through ``aget_or_404``.
    """
    user_group = await request_user.awaitable_attrs.group
    category_repo = repositories.ProductCategoryRepository(
        db, organization_id=user_group.organization_id
    )
    return await category_repo.aget_or_404(pk)


async def update_product_category(
    db: AsyncSession,
    pk: int,
    product_category: schemas.ProductCategoryUpdate,
    request_user: User,
):
    """Update the product category ``pk`` inside the requesting user's organization.

    The complete ``product_category.model_dump()`` is forwarded to ``aupdate``.
    This includes fields the caller did not set, which carry their schema
    defaults.
    """
    user_group = await request_user.awaitable_attrs.group
    category_repo = repositories.ProductCategoryRepository(
        db, organization_id=user_group.organization_id
    )
    changes = product_category.model_dump()
    return await category_repo.aupdate(pk, **changes)


async def list_product_categories(
    db: AsyncSession,
    pagination: PaginationQuery,
    filters: schemas.ProductCategoryFilter,
    ordering: list[enums.ProductCategoryOrdering] | None,
    request_user: User,
):
    """Return product categories of the user's organization via ``arecords``.

    ``filters``, ``pagination`` and ``ordering`` are passed through unchanged.
    """
    user_group = await request_user.awaitable_attrs.group
    category_repo = repositories.ProductCategoryRepository(
        db, organization_id=user_group.organization_id
    )
    return await category_repo.arecords(filters, pagination, ordering)


async def delete_product_category(db: AsyncSession, pk: int, request_user: User):
    """Delete the product category ``pk`` within the requesting user's organization."""
    user_group = await request_user.awaitable_attrs.group
    scoped_repo = repositories.ProductCategoryRepository(
        db, organization_id=user_group.organization_id
    )
    await scoped_repo.adelete(pk)


# Product


async def create_product(
    db: AsyncSession,
    product: schemas.ProductCreate,
    request_user: User,
):
    """Create a product for the requesting user's organization.

    ``organization_id`` and ``created_by_id`` are taken from ``request_user``
    and added to the dumped ``product`` payload.
    """
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    product_repo = repositories.ProductRepository(db, organization_id=organization_id)
    return await product_repo.acreate(
        **product.model_dump(),
        organization_id=organization_id,
        created_by_id=request_user.pk,
    )


async def retrieve_product(
    db: AsyncSession,
    pk: int,
    request_user: User,
):
    """Return product ``pk`` from the requesting user's organization (via ``aget_or_404``)."""
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    product_repo = repositories.ProductRepository(db, organization_id=organization_id)
    return await product_repo.aget_or_404(pk)


async def update_product(
    db: AsyncSession,
    pk: int,
    product: schemas.ProductUpdate,
    request_user: User,
):
    """Update product ``pk`` within the requesting user's organization.

    All fields from ``product.model_dump()`` are forwarded to ``aupdate``.
    """
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    product_repo = repositories.ProductRepository(db, organization_id=organization_id)
    changes = product.model_dump()
    return await product_repo.aupdate(pk, **changes)


async def list_products(
    db: AsyncSession,
    pagination: PaginationQuery,
    filters: schemas.ProductFilter,
    ordering: list[enums.ProductOrdering] | None,
    request_user: User,
):
    """List products of the user's organization, optionally narrowed to a warehouse.

    ``filters`` is dumped with ``exclude_unset=True`` and ``None`` values are
    dropped. A ``warehouse_id`` filter is rewritten into conditions on the
    product's stocks: same warehouse and ``available > 0``.

    When that filter is present, each returned product also gets an
    ``available_warehouses`` attribute. It lists every warehouse (as a
    ``WarehouseMin`` dict) where the product has positive available stock.

    Returns:
        Whatever ``ProductRepository.arecords`` returns (indexed as
        ``objs["items"]`` below).
    """
    group = await request_user.awaitable_attrs.group
    organization_id = group.organization_id
    repo = repositories.ProductRepository(db, organization_id=organization_id)

    # Only explicitly-set, non-None filters reach the repository.
    filters_dict = {
        key: value
        for key, value in filters.model_dump(exclude_unset=True).items()
        if value is not None
    }

    warehouse_filter = filters_dict.pop("warehouse_id", None)
    if warehouse_filter is not None:
        # Presumably the repository's related-field lookup syntax
        # (stocks -> warehouse_id / available) — confirm against be_kit.
        filters_dict["stocks___warehouse_id"] = warehouse_filter
        filters_dict["stocks___available__gt"] = 0

    objs = await repo.arecords(filters_dict, pagination, ordering)

    # available_warehouses is only attached when a warehouse filter was given.
    if warehouse_filter is None:
        return objs

    stock_repo = repositories.ProductStockRepository(db, organization_id=organization_id)
    for product in objs["items"]:
        # NOTE(review): one stock query per product on the page (N+1).
        stocks = await stock_repo.arecords(
            {"product_id": product.pk, "available__gt": 0}
        )
        stock_items = stocks["items"] if isinstance(stocks, dict) else list(stocks)
        product.available_warehouses = [
            schemas.WarehouseMin.model_validate(stock.warehouse).model_dump()
            for stock in stock_items
        ]
    return objs


async def delete_product(db: AsyncSession, pk: int, request_user: User):
    """Delete product ``pk`` within the requesting user's organization."""
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    product_repo = repositories.ProductRepository(db, organization_id=organization_id)
    await product_repo.adelete(pk)


async def handle_product_upload(
    db: AsyncSession,
    product_upload: schemas.ProductUpload,
    request_user: User,
):
    """Create products from the rows of an uploaded Excel workbook.

    The first sheet (``pd.read_excel`` default) must contain every column in
    ``upload_columns``. For each row, this function:

    - validates the row with ``schemas.ProductValidator``;
    - resolves the ``category`` name and ``currency_iso_code`` to primary keys;
    - creates a product for the requesting user's organization.

    Returns:
        One ``schemas.ProductMin`` dict per created product, in row order.

    Raises:
        ValueError: if required columns are missing, or a row names a
            category or currency that cannot be found. Products from earlier
            rows have already been passed to ``acreate`` at that point.
        Whatever ``ProductValidator.model_validate`` raises for an invalid
        row (presumably pydantic's ``ValidationError``).
    """
    group = await request_user.awaitable_attrs.group
    repo = repositories.ProductRepository(db, organization_id=group.organization_id)
    category_repo = repositories.ProductCategoryRepository(
        db, organization_id=group.organization_id
    )
    currency_repo = CurrencyRepository(db)

    content = await product_upload.file.read()
    df = pd.read_excel(io.BytesIO(content))

    upload_columns = (
        "name",
        "category",
        "description",
        "product_type",
        "product_unit",
        "price",
        "currency_iso_code",
    )
    missing = set(upload_columns) - set(df.columns)
    if missing:
        # sorted() keeps the message deterministic; set iteration order is not.
        raise ValueError(f"Missing required columns: {', '.join(sorted(missing))}")

    # Resolve each distinct category name / ISO code once instead of per row.
    category_pks: dict[str, int] = {}
    currency_pks: dict[str, int] = {}

    created_products = []
    for _, row in df.iterrows():
        # pandas reports empty cells as float NaN; pass them as None so the
        # validator treats them as missing rather than as a numeric value.
        validated = schemas.ProductValidator.model_validate(
            {col: None if pd.isna(row[col]) else row[col] for col in upload_columns}
        )

        if validated.category not in category_pks:
            category_page = await category_repo.arecords(
                schemas.ProductCategoryFilter(name=validated.category),
                PaginationQuery(limit=1, offset=0),
                None,
            )
            categories = list(category_page["items"])
            if not categories:
                raise ValueError(f"Category '{validated.category}' not found")
            category_pks[validated.category] = categories[0].pk

        if validated.currency_iso_code not in currency_pks:
            currency_page = await currency_repo.arecords(
                CurrencyFilter(iso=validated.currency_iso_code),
                PaginationQuery(limit=1, offset=0),
                None,
            )
            currencies = list(currency_page["items"])
            if not currencies:
                raise ValueError(
                    f"Currency '{validated.currency_iso_code}' not found"
                )
            currency_pks[validated.currency_iso_code] = currencies[0].pk

        obj = await repo.acreate(
            name=validated.name,
            description=validated.description,
            category_id=category_pks[validated.category],
            product_type=validated.product_type,
            product_unit=validated.product_unit,
            price=validated.price,
            currency_id=currency_pks[validated.currency_iso_code],
            organization_id=group.organization_id,
            created_by_id=request_user.pk,
        )
        created_products.append(schemas.ProductMin.model_validate(obj).model_dump())

    return created_products


async def handle_product_download(
    db: AsyncSession,
    request_user: User,
):
    """Export the organization's products as an in-memory Excel workbook.

    The single sheet is named ``product``. Its columns are the same set that
    ``handle_product_upload`` requires.

    Returns:
        An ``io.BytesIO`` rewound to offset 0 and holding the workbook bytes
        (written with the ``openpyxl`` engine).
    """
    group = await request_user.awaitable_attrs.group
    repo = repositories.ProductRepository(db, organization_id=group.organization_id)
    # NOTE(review): arecords() is called with no filters/pagination and its
    # result is iterated directly; confirm it yields the full product list.
    products = await repo.arecords()

    columns = [
        "name",
        "category",
        "description",
        "product_type",
        "product_unit",
        "price",
        "currency_iso_code",
    ]
    product_rows = [
        {
            "name": product.name,
            "category": product.category.name,
            "description": product.description,
            "product_type": product.product_type,
            "product_unit": product.product_unit,
            "price": product.price,
            "currency_iso_code": product.currency.iso,
        }
        for product in products
    ]

    # Explicit columns keep the header row even when there are no products.
    df = pd.DataFrame(product_rows, columns=columns)
    output = io.BytesIO()
    with pd.ExcelWriter(output, engine="openpyxl") as writer:
        df.to_excel(writer, sheet_name="product", index=False)
    output.seek(0)
    return output


# ProductStock


async def create_product_stock(
    db: AsyncSession,
    product_stock: schemas.ProductStockCreate,
    request_user: User,
):
    """Create a product stock record for the requesting user's organization.

    The repository is built without an organization scope. Ownership comes
    from the explicit ``organization_id`` passed to ``acreate``.
    """
    user_group = await request_user.awaitable_attrs.group
    payload = product_stock.model_dump()
    stock_repo = repositories.ProductStockRepository(db)
    return await stock_repo.acreate(
        **payload,
        organization_id=user_group.organization_id,
        created_by_id=request_user.pk,
    )


async def retrieve_product_stock(
    db: AsyncSession,
    pk: int,
    request_user: User,
):
    """Return product stock ``pk`` from the user's organization (via ``aget_or_404``)."""
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    stock_repo = repositories.ProductStockRepository(
        db, organization_id=organization_id
    )
    return await stock_repo.aget_or_404(pk)


async def update_product_stock(
    db: AsyncSession,
    pk: int,
    product_stock: schemas.ProductStockUpdate,
    request_user: User,
):
    """Update product stock ``pk`` within the requesting user's organization.

    All fields of ``product_stock.model_dump()`` are forwarded to ``aupdate``.
    """
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    stock_repo = repositories.ProductStockRepository(
        db, organization_id=organization_id
    )
    changes = product_stock.model_dump()
    return await stock_repo.aupdate(pk, **changes)


async def list_product_stocks(
    db: AsyncSession,
    pagination: PaginationQuery,
    filters: schemas.ProductStockFilter,
    ordering: list[enums.ProductStockOrdering] | None,
    request_user: User,
):
    """Return product stocks of the user's organization via ``arecords``.

    ``filters``, ``pagination`` and ``ordering`` are passed through unchanged.
    """
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    stock_repo = repositories.ProductStockRepository(
        db, organization_id=organization_id
    )
    return await stock_repo.arecords(filters, pagination, ordering)


async def delete_product_stock(db: AsyncSession, pk: int, request_user: User):
    """Delete product stock ``pk`` within the requesting user's organization."""
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    stock_repo = repositories.ProductStockRepository(
        db, organization_id=organization_id
    )
    await stock_repo.adelete(pk)


# ProductMovement


async def create_product_movement(
    db: AsyncSession,
    product_movement: schemas.ProductMovementCreate,
    request_user: User,
):
    """Create a product movement owned by the requesting user's organization.

    The creator and organization come from ``request_user``. The repository
    itself is built without an organization scope.
    """
    user_group = await request_user.awaitable_attrs.group
    payload = product_movement.model_dump()
    movement_repo = repositories.ProductMovementRepository(db)
    return await movement_repo.acreate(
        **payload,
        created_by_id=request_user.pk,
        organization_id=user_group.organization_id,
    )


async def retrieve_product_movement(
    db: AsyncSession,
    pk: int,
    request_user: User,
):
    """Return product movement ``pk`` from the user's organization (via ``aget_or_404``)."""
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    movement_repo = repositories.ProductMovementRepository(
        db, organization_id=organization_id
    )
    return await movement_repo.aget_or_404(pk)


async def update_product_movement(
    db: AsyncSession,
    pk: int,
    product_movement: schemas.ProductMovementUpdate,
    request_user: User,
):
    """Update product movement ``pk`` within the requesting user's organization.

    Nested line items come out of ``model_dump()`` as dicts. They are rebuilt
    as ``ProductMovementItem`` instances before being handed to ``aupdate``.
    """
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    movement_repo = repositories.ProductMovementRepository(
        db, organization_id=organization_id
    )
    changes = product_movement.model_dump()
    if "items" in changes:
        # NOTE(review): if the update schema allows items=None, iterating it
        # here raises TypeError — confirm the field is never None.
        rebuilt_items = []
        for entry in changes["items"]:
            if isinstance(entry, dict):
                entry = ProductMovementItem(**entry)
            rebuilt_items.append(entry)
        changes["items"] = rebuilt_items
    return await movement_repo.aupdate(pk, **changes)


async def _get_inventory_account_mapping(db: AsyncSession, request_user: User):
    """Return the organization's account mapping after checking the accounts movements post to."""
    # NOTE(review): elsewhere in this module PaginationQuery is built with
    # limit/offset and read via page/page_size; confirm page/size is valid here.
    pagination = PaginationQuery(page=1, size=1)
    mappings = await list_account_mappings(db, pagination, request_user)
    items = list(mappings["items"]) if isinstance(mappings, dict) else list(mappings.items)
    mapping = items[0] if items else None
    if (
        not mapping
        or not mapping.cost_of_revenue_account_id
        or not mapping.default_inventory_account_id
    ):
        raise Exception("Account mapping for Cost of Revenue or Inventory is not set.")
    if not mapping.goods_received_invoice_received_account_id:
        raise Exception("Account mapping for Goods Received Invoice Received is not set.")
    return mapping


def _build_movement_entry(obj, mapping) -> journal_schemas.EntryCreate:
    """Build the journal entry for a completed IN/OUT movement, valuing lines at quantity * price."""
    if obj.movement_type == enums.ProductMovementType.OUT:
        # Stock leaves: debit Cost of Revenue, credit Inventory.
        debit_account_id = mapping.cost_of_revenue_account_id
        credit_account_id = mapping.default_inventory_account_id
        debit_note = "Cost of {} moved out"
        credit_note = "Inventory reduction for {}"
        entry_description = f"COGS for Product Movement {obj.identifier}"
    elif obj.movement_type == enums.ProductMovementType.IN:
        # Stock arrives: debit Inventory, credit Goods Received / Invoice Received.
        debit_account_id = mapping.default_inventory_account_id
        credit_account_id = mapping.goods_received_invoice_received_account_id
        debit_note = "Inventory addition for {}"
        credit_note = "Reversal of Cost of {} moved in"
        entry_description = f"Inventory Input for Product Movement {obj.identifier}"
    else:
        # Previously this fell through to an UnboundLocalError on `entry`.
        # NOTE(review): confirm whether other movement types (e.g. transfers)
        # should complete without any journal posting instead.
        raise Exception(
            f"Journal posting is not supported for movement type {obj.movement_type}."
        )

    entry_items = []
    for item in obj.items:
        # Use product price as cost
        cost = item.quantity * (item.product.price or 0)
        if cost <= 0:
            raise Exception(f"Invalid cost for product {item.product.name}.")
        entry_items.append({
            "account_id": debit_account_id,
            "amount": cost,
            "side": journal_enums.EntrySide.DEBIT,
            "description": debit_note.format(item.product.name),
        })
        entry_items.append({
            "account_id": credit_account_id,
            "amount": cost,
            "side": journal_enums.EntrySide.CREDIT,
            "description": credit_note.format(item.product.name),
        })

    return journal_schemas.EntryCreate(
        date=obj.last_modified_at,
        description=entry_description,
        entry_type=journal_enums.EntryTypeEnum.PRODUCT,
        items=entry_items,
        currency_id=mapping.default_inventory_account.currency_id,
    )


async def update_product_movement_status(
    db: AsyncSession,
    pk: int,
    status_update: schemas.ProductMovementStatusUpdate,
    request_user: User,
):
    """Change a product movement's status, move stock, and post accounting on completion.

    Steps:
    1. Load the movement (``aget_or_404``) and reject a no-op status change.
    2. Persist the new status, then call ``repo.move_stock`` for every item
       with the *new* status.
    3. On the APPROVED -> COMPLETED transition, record a journal entry. Each
       line is valued at ``quantity * product.price``. The new entry's pk is
       stored on the movement.
    4. Commit the session and refresh the movement.

    Raises:
        ValueError: if the requested status equals the current one.
        Exception: if account mappings are missing, an item's cost is not
            positive, or the completed movement is neither IN nor OUT.
    """
    group = await request_user.awaitable_attrs.group
    repo = repositories.ProductMovementRepository(
        db, organization_id=group.organization_id
    )
    product_stock_repo = repositories.ProductStockRepository(
        db, organization_id=group.organization_id
    )
    # Capture the status before the update to detect the transition below.
    current_obj = await repo.aget_or_404(pk)
    prev_status = current_obj.status

    # Prevent updating to the same status
    if status_update.status == prev_status:
        raise ValueError("The status is already set to the specified value.")

    obj = await repo.aupdate(pk, **status_update.model_dump())

    # Warehouses belong to the movement, not the item, so resolve them once.
    source_warehouse_id = (
        obj.source_warehouse.pk if obj.source_warehouse is not None else None
    )
    destination_warehouse_id = (
        obj.destination_warehouse.pk if obj.destination_warehouse is not None else None
    )
    for item in obj.items:
        await repo.move_stock(
            status=obj.status,
            movement_type=obj.movement_type,
            product_id=item.product_id,
            source_warehouse_id=source_warehouse_id,
            destination_warehouse_id=destination_warehouse_id,
            quantity=item.quantity,
            product_stock_repo=product_stock_repo,
        )

    if (
        prev_status == enums.ProductMovementStatus.APPROVED
        and obj.status == enums.ProductMovementStatus.COMPLETED
    ):
        mapping = await _get_inventory_account_mapping(db, request_user)
        entry = _build_movement_entry(obj, mapping)
        journal_entry = await journal_utils.record_entry(db, entry, request_user)

        await repo.aupdate(
            obj.pk,
            journal_entry_id=journal_entry.pk,
            last_modified_by_id=request_user.pk,
        )

    # AsyncSession.commit/refresh are coroutines: without await they never run.
    await db.commit()
    await db.refresh(obj)
    return obj


async def approve_product_movement(
    db: AsyncSession,
    pk: int,
    request_user: User,
):
    """Move stock for every item of movement ``pk``, then mark it APPROVED.

    ``move_stock`` receives the movement's status as loaded, i.e. before it
    is switched to APPROVED. ``request_user`` is recorded as the approver.
    """
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    movement_repo = repositories.ProductMovementRepository(
        db, organization_id=organization_id
    )
    stock_repo = repositories.ProductStockRepository(
        db, organization_id=organization_id
    )

    movement = await movement_repo.aget_or_404(pk)
    for line in movement.items:
        source = movement.source_warehouse
        destination = movement.destination_warehouse
        await movement_repo.move_stock(
            status=movement.status,
            movement_type=movement.movement_type,
            product_id=line.product_id,
            source_warehouse_id=source.pk if source is not None else None,
            destination_warehouse_id=destination.pk if destination else None,
            quantity=line.quantity,
            product_stock_repo=stock_repo,
        )

    # NOTE(review): update_product_movement_status passes the post-update
    # status to move_stock, while this path passes the pre-approval one. Also,
    # nothing here prevents approving an already-approved movement. Confirm
    # both are intended.
    return await movement_repo.aupdate(
        pk, status=enums.ProductMovementStatus.APPROVED, approver_id=request_user.pk
    )


async def list_product_movements(
    db: AsyncSession,
    pagination: PaginationQuery,
    filters: schemas.ProductMovementFilter,
    ordering: list[enums.ProductMovementOrdering] | None,
    request_user: User,
):
    """Return product movements of the user's organization via ``arecords``.

    ``filters``, ``pagination`` and ``ordering`` are passed through unchanged.
    """
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    movement_repo = repositories.ProductMovementRepository(
        db, organization_id=organization_id
    )
    return await movement_repo.arecords(filters, pagination, ordering)


async def delete_product_movement(db: AsyncSession, pk: int, request_user: User):
    """Delete product movement ``pk`` within the requesting user's organization."""
    organization_id = (await request_user.awaitable_attrs.group).organization_id
    movement_repo = repositories.ProductMovementRepository(
        db, organization_id=organization_id
    )
    await movement_repo.adelete(pk)

async def list_product_movement_items(
    db: AsyncSession,
    pagination: PaginationQuery,
    filters: schemas.ProductMovementItemFilter = None,
    ordering: list[enums.ProductMovementOrdering] | None = None,
    request_user: User = None,
):
    """Return a page of movement line items, flattened to one view per item.

    Every matching movement is loaded (``arecords`` with no pagination) and
    each of its items becomes a ``ProductMovementItemView``. That view
    carries the parent movement's fields (including the movement's ``pk``).
    The flattened list is then sliced in memory using ``pagination.page``
    (1-based) and ``pagination.page_size``.

    ``request_user`` must be supplied despite its ``None`` default; its group
    is awaited unconditionally.
    """
    user_group = await request_user.awaitable_attrs.group
    movement_repo = repositories.ProductMovementRepository(
        db, organization_id=user_group.organization_id
    )
    movements = await movement_repo.arecords(filters, None, ordering)
    if hasattr(movements, "all"):
        movements = movements.all()

    flattened = [
        schemas.ProductMovementItemView(
            pk=movement.pk,
            status=movement.status,
            identifier=movement.identifier,
            movement_type=movement.movement_type,
            note=movement.note,
            source_warehouse=movement.source_warehouse,
            destination_warehouse=movement.destination_warehouse,
            approver=movement.approver,
            product=line.product,
            quantity=line.quantity,
            created_at=movement.created_at,
            last_modified_at=movement.last_modified_at,
            created_by=movement.created_by,
            last_modified_by=movement.last_modified_by,
        )
        for movement in movements
        for line in movement.items
    ]

    size = pagination.page_size
    offset = (pagination.page - 1) * size
    total_items = len(flattened)
    return schemas.PaginatedProductMovementItemView(
        items=flattened[offset:offset + size],
        total_items=total_items,
        total_page=(total_items + size - 1) // size,
        page=pagination.page,
    )

# Lookup utils (currency ISO codes, product category names)

async def get_all_currency_iso_codes(db, request_user: User):
    """Return the ``iso`` code of every currency returned by ``CurrencyRepository.arecords()``."""
    user_group = await request_user.awaitable_attrs.group
    # NOTE(review): handle_product_upload builds CurrencyRepository(db) without
    # organization_id; confirm the repository accepts this keyword.
    currency_repo = CurrencyRepository(db, organization_id=user_group.organization_id)
    currencies = await currency_repo.arecords()
    return [currency.iso for currency in currencies]

async def get_all_product_category_names(db: AsyncSession, request_user: User):
    """Return the name of every product category in the user's organization."""
    user_group = await request_user.awaitable_attrs.group
    category_repo = repositories.ProductCategoryRepository(
        db, organization_id=user_group.organization_id
    )
    records = await category_repo.arecords()
    # arecords() may hand back a result object exposing .all(); unwrap it.
    rows = records.all() if hasattr(records, "all") else records
    return [category.name for category in rows]
