import io
from be_kit.exim.utils import get_version_upload_template
import pandas as pd
from openpyxl import load_workbook
from openpyxl.worksheet.datavalidation import DataValidation
from openpyxl.utils import get_column_letter
from sqlalchemy.ext.asyncio import AsyncSession

from .enums import ContactLegalType, ContactIndustryType
from be_kit.paginations import PaginationQuery
from be_uam.user.models import User
from . import enums, schemas, repositories


async def create_contact(
    db: AsyncSession,
    contact: schemas.ContactCreate,
    request_user: User,
    _commit: bool = True,
):
    """Create a contact together with its industries and addresses.

    The contact is attached to the organization of the requesting user's group
    and ``created_by_id`` is set to the requesting user.

    Args:
        db: Active async session.
        contact: Validated creation payload; its ``industries`` and
            ``addresses`` entries become child rows of the new contact.
        request_user: User performing the creation.
        _commit: When ``True`` (default) the session is committed and the new
            object refreshed. When ``False`` every row (contact *and* children)
            is only added/flushed so the caller can commit a larger unit of
            work atomically (see ``handle_contact_upload``).

    Returns:
        The newly created contact ORM object.
    """
    group = await request_user.awaitable_attrs.group
    repo = repositories.ContactRepository(db)
    industry_repo = repositories.IndustryRepository(db)
    address_repo = repositories.AddressRepository(db)

    contact_data = contact.model_dump()
    # Treat a missing or None list as "no children" instead of crashing.
    industries = contact_data.pop("industries", None) or []
    addresses = contact_data.pop("addresses", None) or []

    obj = await repo.acreate(
        **contact_data,
        organization_id=group.organization_id,
        created_by_id=request_user.pk,
        _commit=False,
    )
    # Flush so the database assigns obj.pk before the children reference it.
    await db.flush()

    # Children are created with _commit=False as well, so the ``_commit`` flag
    # governs the whole contact (parent + children) as a single unit of work.
    for industry in industries:
        await industry_repo.acreate(**industry, contact_id=obj.pk, _commit=False)
    for address in addresses:
        await address_repo.acreate(**address, contact_id=obj.pk, _commit=False)

    if _commit:
        await db.commit()
        await db.refresh(obj)
    return obj


async def retrieve_contact(
    db: AsyncSession,
    pk: int,
    request_user: User,
):
    """Return the contact with primary key ``pk``.

    The repository is constructed with the organization id of the requesting
    user's group (presumably so the lookup is scoped to that organization).
    Missing records are handled by the repository's ``aget_or_404``.
    """
    user_group = await request_user.awaitable_attrs.group
    contact_repo = repositories.ContactRepository(
        db, organization_id=user_group.organization_id
    )
    return await contact_repo.aget_or_404(pk)


async def list_contacts(
    db: AsyncSession,
    pagination: PaginationQuery,
    filters: schemas.ContactFilter,
    ordering: list[enums.ContactOrdering] | None,
    request_user: User,
):
    """Return a filtered, ordered, paginated set of contacts.

    The repository is built with the organization id of the requesting user's
    group; ``filters``, ``pagination`` and ``ordering`` are forwarded verbatim
    to ``ContactRepository.arecords``.
    """
    user_group = await request_user.awaitable_attrs.group
    contact_repo = repositories.ContactRepository(
        db, organization_id=user_group.organization_id
    )
    return await contact_repo.arecords(filters, pagination, ordering)


async def _sync_contact_children(child_repo, contact_pk: int, items: list[dict]) -> None:
    """Make a contact's child rows (industries or addresses) match ``items``.

    ``items`` are plain dicts (``model_dump`` output). An item whose ``pk``
    belongs to an existing child of this contact updates that row; every other
    item is created as a new row. Existing children not referenced are deleted.
    """
    # Materialize once: the result is iterated more than once below.
    existing = list(await child_repo.arecords({"contact_id": contact_pk}))
    existing_pks = {record.pk for record in existing}
    # Only pks that really belong to this contact may be updated, so a payload
    # cannot modify another contact's rows by sending a foreign pk.
    kept_pks = {item.get("pk") for item in items} & existing_pks

    for record in existing:
        if record.pk not in kept_pks:
            await child_repo.adelete(record.pk)

    for item in items:
        data = dict(item)
        # pk is passed positionally to aupdate (and must not reach acreate);
        # the owning contact is fixed by contact_pk, never by the payload.
        item_pk = data.pop("pk", None)
        data.pop("contact_id", None)
        if item_pk in kept_pks:
            await child_repo.aupdate(item_pk, **data)
        else:
            await child_repo.acreate(**data, contact_id=contact_pk)


async def update_contact(
    db: AsyncSession,
    pk: int,
    contact: schemas.ContactUpdate,
    request_user: User,
):
    """Update a contact and synchronise its industries and addresses.

    Scalar fields are written through ``ContactRepository.aupdate`` (with
    ``last_modified_by_id`` set to the requesting user). Each child list in the
    payload replaces the contact's current children: items carrying the pk of
    an existing child update it, items without one are created, and children
    absent from the list are deleted. A child list that is missing or ``None``
    leaves those children untouched.

    Returns:
        The refreshed contact ORM object after commit.
    """
    group = await request_user.awaitable_attrs.group
    repo = repositories.ContactRepository(db, organization_id=group.organization_id)

    # model_dump() turns nested child models into dicts, which is why the
    # sync helper reads "pk" with item.get(...) rather than getattr(...).
    contact_data = contact.model_dump()
    industries = contact_data.pop("industries", None)
    addresses = contact_data.pop("addresses", None)
    obj = await repo.aupdate(pk, **contact_data, last_modified_by_id=request_user.pk)

    if industries is not None:
        await _sync_contact_children(
            repositories.IndustryRepository(db), obj.pk, industries
        )
    if addresses is not None:
        await _sync_contact_children(
            repositories.AddressRepository(db), obj.pk, addresses
        )

    await db.commit()
    await db.refresh(obj)
    return obj


async def delete_contact(db: AsyncSession, pk: int, request_user: User):
    """Delete the contact ``pk`` via ``ContactRepository.adelete``.

    The repository is constructed with the organization id of the requesting
    user's group; nothing is returned.
    """
    user_group = await request_user.awaitable_attrs.group
    await repositories.ContactRepository(
        db, organization_id=user_group.organization_id
    ).adelete(pk)


_UPLOAD_SHEET_NAMES = ("Contact", "Industry", "Address")


def _clean_cell(value):
    """Return ``None`` for an empty spreadsheet cell, otherwise ``value`` unchanged.

    pandas reports empty cells as NaN/NaT rather than ``None``; without this,
    optional schema fields would receive NaN (and ``str(NaN)`` is ``"nan"``).
    """
    if value is None:
        return None
    try:
        return None if pd.isna(value) else value
    except (TypeError, ValueError):
        # pd.isna on a list-like returns an array whose truth value is ambiguous.
        return value


def _cell_as_text(value):
    """Render an identifier-like cell (tax id, phone) as text, or ``None`` if empty.

    A numeric column that also contains blank cells is read as float, so
    ``12345`` comes back as ``12345.0``; integral floats are turned into
    ``int`` before ``str`` to avoid the spurious ``.0``.
    """
    value = _clean_cell(value)
    if value is None:
        return None
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    return str(value)


def _rows_by_contact_name(frame: pd.DataFrame, sheet_name: str) -> dict:
    """Group a child sheet's rows (as dicts) by their ``contact_name`` cell.

    A sheet without rows - including one without even a header row, as written
    from an empty DataFrame - yields an empty mapping.

    Raises:
        ValueError: the sheet has rows but no ``contact_name`` column.
    """
    if frame.empty:
        return {}
    if "contact_name" not in frame.columns:
        raise ValueError(
            f"Invalid template format. Sheet '{sheet_name}' has no 'contact_name' column."
        )
    grouped: dict = {}
    for record in frame.to_dict(orient="records"):
        contact_name = _clean_cell(record["contact_name"])
        if contact_name is not None:
            grouped.setdefault(contact_name, []).append(record)
    return grouped


async def handle_contact_upload(
    db: AsyncSession, contact_upload: schemas.ContactUpload, request_user: User
):
    """Import contacts (with industries and addresses) from an uploaded workbook.

    The workbook must contain the sheets ``Contact``, ``Industry`` and
    ``Address``; child rows are linked to contacts through their
    ``contact_name`` column. Every row is validated into a schema before any
    database write, then all contacts are created with ``_commit=False`` and
    committed once at the end.

    Raises:
        ValueError: a required sheet (or a child sheet's ``contact_name``
            column) is missing.
    """
    raw = await contact_upload.file.read()
    # Parse the workbook once; passing raw bytes straight to read_excel is
    # deprecated, so wrap them in a file-like buffer.
    with pd.ExcelFile(io.BytesIO(raw)) as workbook:
        if any(name not in workbook.sheet_names for name in _UPLOAD_SHEET_NAMES):
            raise ValueError("Invalid template format. Missing required sheets.")
        contact_data, industry_data, address_data = [
            workbook.parse(sheet_name=name) for name in _UPLOAD_SHEET_NAMES
        ]

    # Group child rows once instead of filtering each sheet per contact.
    industries_by_contact = _rows_by_contact_name(industry_data, "Industry")
    addresses_by_contact = _rows_by_contact_name(address_data, "Address")

    contacts = []
    # Entirely blank rows (e.g. formatted but unused) are not contacts.
    for record in contact_data.dropna(how="all").to_dict(orient="records"):
        name = _clean_cell(record["name"])
        industries = [
            schemas.IndustryCreate(industry_type=_clean_cell(row["industry_type"]))
            for row in industries_by_contact.get(name, [])
        ]
        addresses = [
            schemas.AddressCreate(
                address_type=_clean_cell(row["address_type"]),
                name=_clean_cell(row["name"]),
                email=_clean_cell(row.get("email")),
                address=_clean_cell(row["address"]),
                phone=_cell_as_text(row.get("phone")),
            )
            for row in addresses_by_contact.get(name, [])
        ]
        contacts.append(
            schemas.ContactCreate(
                name=name,
                legal_type=_clean_cell(record.get("legal_type")),
                tax_id=_cell_as_text(record.get("tax_id")),
                industries=industries,
                addresses=addresses,
                tag_ids=[],
            )
        )

    for contact in contacts:
        await create_contact(db, contact, request_user, _commit=False)

    await db.commit()


_CONTACT_EXPORT_COLUMNS = ["name", "legal_type", "tax_id"]
_INDUSTRY_EXPORT_COLUMNS = ["contact_name", "industry_type"]
_ADDRESS_EXPORT_COLUMNS = [
    "contact_name",
    "address_type",
    "name",
    "email",
    "address",
    "phone",
]


def _excel_cell(value):
    """Unwrap enum members to their ``.value`` so the Excel writer gets plain scalars.

    The upload path and the template dropdowns both work with enum *values*,
    so exporting values keeps the downloaded file re-importable.
    """
    from enum import Enum

    return value.value if isinstance(value, Enum) else value


async def handle_contact_download(
    db: AsyncSession,
    request_user: User,
):
    """Export the organization's contacts to an in-memory ``.xlsx`` workbook.

    The workbook has three sheets - ``Contact``, ``Industry`` and ``Address`` -
    in the layout read by ``handle_contact_upload``: child rows reference their
    contact through ``contact_name``. Every sheet always carries its header row,
    even when it has no data rows, so an export can be uploaded again.

    Returns:
        A ``BytesIO`` positioned at offset 0.
    """
    group = await request_user.awaitable_attrs.group
    repo = repositories.ContactRepository(db, organization_id=group.organization_id)
    contacts = await repo.arecords()

    industry_repo = repositories.IndustryRepository(db)
    address_repo = repositories.AddressRepository(db)

    contact_rows = []
    industry_rows = []
    address_rows = []

    for contact in contacts:
        contact_rows.append(
            {
                "name": contact.name,
                "legal_type": _excel_cell(getattr(contact, "legal_type", None)),
                "tax_id": _excel_cell(getattr(contact, "tax_id", None)),
            }
        )

        # NOTE(review): two queries per contact (N+1); batch if exports get large.
        industries = await industry_repo.arecords({"contact_id": contact.pk})
        for industry in industries:
            industry_rows.append(
                {
                    "contact_name": contact.name,
                    "industry_type": _excel_cell(
                        getattr(industry, "industry_type", None)
                    ),
                }
            )

        addresses = await address_repo.arecords({"contact_id": contact.pk})
        for address in addresses:
            address_rows.append(
                {
                    "contact_name": contact.name,
                    **{
                        field: _excel_cell(getattr(address, field, None))
                        for field in _ADDRESS_EXPORT_COLUMNS[1:]
                    },
                }
            )

    # Explicit columns keep the header row even when a list is empty; an empty
    # DataFrame built from [] would be written without any header.
    contact_df = pd.DataFrame(contact_rows, columns=_CONTACT_EXPORT_COLUMNS)
    industry_df = pd.DataFrame(industry_rows, columns=_INDUSTRY_EXPORT_COLUMNS)
    address_df = pd.DataFrame(address_rows, columns=_ADDRESS_EXPORT_COLUMNS)

    output = io.BytesIO()
    with pd.ExcelWriter(output, engine="xlsxwriter") as writer:
        contact_df.to_excel(writer, sheet_name="Contact", index=False)
        industry_df.to_excel(writer, sheet_name="Industry", index=False)
        address_df.to_excel(writer, sheet_name="Address", index=False)
    output.seek(0)
    return output


async def generate_vendor_customer_upload_template(
    schema_validator,
    sheet_name: str,
    filename: str | None = None,
):
    """Build an upload-template workbook with dropdowns for enum columns.

    The header layout comes from ``get_version_upload_template``. For each of
    the ``legal_type`` and ``industry_type`` columns that is present in the
    header, a list-type data validation is attached to the whole column
    (rows 2 onward). Its allowed values are the enum values written to a hidden
    ``dropdown_lists`` sheet (column A for legal types, column B for industry
    types). Columns absent from the template are skipped instead of raising.

    Args:
        schema_validator: Passed through to ``get_version_upload_template``.
        sheet_name: Name of the visible data sheet.
        filename: Download filename; defaults to ``"<sheet_name> Template.xlsx"``.

    Returns:
        ``(BytesIO positioned at 0, filename)``.
    """
    # Last row index an .xlsx worksheet can hold; validation spans to it.
    max_excel_row = 1048576

    template_df = await get_version_upload_template(schema_validator)

    buffer = io.BytesIO()
    with pd.ExcelWriter(buffer, engine="openpyxl") as writer:
        template_df.to_excel(writer, sheet_name=sheet_name, index=False)
    buffer.seek(0)

    wb = load_workbook(buffer)
    ws = wb[sheet_name]
    header = [cell.value for cell in ws[1]]

    # (template column, allowed values); list position i -> hidden column i.
    dropdowns = (
        ("legal_type", [e.value for e in ContactLegalType]),
        ("industry_type", [e.value for e in ContactIndustryType]),
    )

    if "dropdown_lists" in wb.sheetnames:
        ws_hidden = wb["dropdown_lists"]
    else:
        ws_hidden = wb.create_sheet("dropdown_lists")
    ws_hidden.sheet_state = "hidden"

    for list_index, (column_name, values) in enumerate(dropdowns, start=1):
        # An empty list would produce an invalid range like $A$1:$A$0.
        if column_name not in header or not values:
            continue

        list_col = get_column_letter(list_index)
        for row_index, value in enumerate(values, start=1):
            ws_hidden[f"{list_col}{row_index}"] = value

        target_col = get_column_letter(header.index(column_name) + 1)
        validation = DataValidation(
            type="list",
            formula1=f"dropdown_lists!${list_col}$1:${list_col}${len(values)}",
        )
        ws.add_data_validation(validation)
        validation.add(f"{target_col}2:{target_col}{max_excel_row}")

    output = io.BytesIO()
    wb.save(output)
    output.seek(0)

    return output, filename or f"{sheet_name} Template.xlsx"

