from typing import Type, Union
from datetime import date
from dateutil.rrule import rrule, MONTHLY

import pandas as pd
from pydantic import BaseModel


def get_rows_with_duplicated_data(df: pd.DataFrame):
    """Return the index labels of every row that has an exact duplicate.

    Every occurrence is reported, not only the repeats after the first one.
    A pair of identical rows therefore contributes both of its labels.
    Labels come back in the frame's row order.
    """
    # keep=False marks all members of each duplicate group as True.
    is_duplicate = df.duplicated(keep=False).to_numpy()
    return df.index[is_duplicate].tolist()


def validate_column_names(
    df: pd.DataFrame, schema: Union[Type[BaseModel], list[str]]
) -> list[tuple]:
    """Compare the columns of *df* against an expected schema.

    Args:
        df: DataFrame whose column names are checked.
        schema: Either a Pydantic model class, whose ``model_fields`` names
            are the expected columns, or a plain list of expected column names.

    Returns:
        A list of ``(column, "missing")`` tuples for expected columns absent
        from *df*, followed by ``(column, "extra")`` tuples for columns of
        *df* that the schema does not declare. Missing columns follow schema
        order and extra columns follow *df*'s column order. Each column
        appears at most once. The list is empty when the columns match.

    Raises:
        TypeError: If *schema* is neither a BaseModel subclass nor a list.
    """
    if isinstance(schema, type) and issubclass(schema, BaseModel):
        expected = list(schema.model_fields.keys())
    elif isinstance(schema, list):
        expected = schema
    else:
        raise TypeError("schema must be a Pydantic BaseModel or a list of strings")

    expected_set = set(expected)
    actual_set = set(df.columns)
    # Iterating sets of str yields a hash-seed-dependent order that can differ
    # between runs. dict.fromkeys de-duplicates while keeping first-seen order,
    # so the report is deterministic.
    missing_columns = [
        (col, "missing") for col in dict.fromkeys(expected) if col not in actual_set
    ]
    extra_columns = [
        (col, "extra") for col in dict.fromkeys(df.columns) if col not in expected_set
    ]
    return missing_columns + extra_columns


def validate_period_completeness(
    df: pd.DataFrame,
    period_col: str,
    period_start: date,
    period_end: date,
) -> list:
    """Check that *period_col* holds exactly the monthly periods of a range.

    Expected periods are generated with ``rrule(MONTHLY)`` from
    *period_start* through *period_end*, inclusive. They are compared as
    ``datetime.date`` values. A ``datetime64`` column is compared on its
    calendar date, because datetime64 values never equal ``date`` objects.

    Args:
        df: DataFrame to check.
        period_col: Name of the column holding each row's period.
        period_start: First expected period.
        period_end: Last allowed period (inclusive bound).

    Returns:
        ``(None, "missing period", p)`` for each expected period ``p`` that
        no row carries (in chronological order). These are followed by
        ``(idx, "extra period", p)`` for each row (in frame order) whose
        period ``p`` is not expected, including rows with a null period.
    """
    # NOTE(review): rrule MONTHLY skips months that lack period_start's day
    # of month (e.g. a start on the 31st skips February). Presumably periods
    # are month starts; confirm with callers.
    expected_periods = [
        d.date() for d in rrule(MONTHLY, dtstart=period_start, until=period_end)
    ]
    expected_set = set(expected_periods)

    periods = df[period_col]
    if pd.api.types.is_datetime64_any_dtype(periods):
        # Normalize so the values can match the datetime.date expectations.
        periods = periods.dt.date

    present = set(periods.tolist())
    missing_periods = [
        (None, "missing period", p) for p in expected_periods if p not in present
    ]

    # A single isin mask replaces a per-period `==` filter. That filter was
    # O(n*k) and also dropped null periods, because `x == None` / `x == NaN`
    # is always False.
    unexpected = ~periods.isin(expected_set).to_numpy()
    extra_periods = [
        (idx, "extra period", p) for idx, p in periods[unexpected].items()
    ]
    return missing_periods + extra_periods
