#!/usr/bin/env python3
"""
Build an abstaining fast-and-frugal tree for regional sambar / amti / dal-curry recipes.

Place this file in the same directory as `sambar_fftree.py` and run, for example:

    uv run sambar_fftree_abstension.py --max-recipes 2000
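
Other example invocations (all flags below are defined in `main`; `--no-download` assumes a
dataset already exists from an earlier run):

    uv run sambar_fftree_abstension.py --no-download --strict-family --drop-obvious-noise
    uv run sambar_fftree_abstension.py --max-recipes 1500 --no-site-search --delay 1.0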

What it does:
1. Reuses any previously downloaded dataset, checking `sambar_out/sambar_recipe_dataset.csv`
   first, then `./sambar_recipe_dataset.csv`.
2. Extends the dataset, when possible, using the same collectors from `sambar_fftree.py`
   plus extra RecipeNLG / Cookpad queries and optional public web search + JSON-LD extraction.
3. Deduplicates recipes.
4. Generates:
   - an abstaining fast-and-frugal tree with branch activation counts,
   - confidence-ranked dialect detection,
   - CSVs for the dataset, tree predictions, dialect predictions, and rule diagnostics.

Design note:
This is deliberately NOT a forced 5-state classifier. It tries to say “I know when I know,”
and leaves the overlapping middle as `mixed / abstain`.
"""
from __future__ import annotations

import argparse
import ast
import html
import json
import math
import re
import sys
import time
from collections import Counter, defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple
from urllib.parse import quote_plus, unquote, urlparse, parse_qs

import pandas as pd

try:
    import requests
except Exception:  # pragma: no cover
    requests = None

try:
    from bs4 import BeautifulSoup
except Exception:  # pragma: no cover
    BeautifulSoup = None

# Import the earlier script from the same directory. This keeps this file small and lets it
# reuse the exact feature vocabulary, downloaders, and deduplication logic already present.
try:
    import sambar_fftree as base
except Exception as exc:  # pragma: no cover
    raise SystemExit(
        "Could not import sambar_fftree.py. Put this script in the same directory as "
        "sambar_fftree.py and run again. Original error: " + repr(exc)
    ) from exc

STATES = list(base.STATES)
FEATURES = list(base.INGREDIENT_PATTERNS.keys())

# More deliberately regional queries than the first script. These are intentionally broad,
# because the abstaining tree can tolerate overlap better than a forced classifier.
EXPANDED_STATE_QUERIES: Dict[str, Sequence[str]] = {
    "Tamil Nadu": [
        "Tamil Nadu sambar recipe", "arachuvitta sambar recipe", "murungakkai sambar recipe",
        "vendakkai sambar recipe", "kalyana sambar Tamil recipe", "hotel sambar Tamil Nadu recipe",
        "tiffin sambar idli dosa Tamil recipe", "Arachuvitta kuzhambu sambar", "kathirikai sambar Tamil",
        "Kongunad sambar", "Chettinad sambar", "Thakkali sambar Tamil",
    ],
    "Kerala": [
        "Kerala sambar recipe", "Onam sadya sambar recipe", "varutharacha sambar recipe",
        "Kerala mixed vegetable sambar recipe", "Kerala coconut sambar recipe", "Mathanga sambar Kerala",
        "Kumbalanga sambar Kerala", "Vendakka sambar Kerala", "Malabar sambar recipe",
        "Cherupayar sambar Kerala", "Kerala temple sambar",
    ],
    "Karnataka": [
        "Udupi sambar recipe", "Karnataka sambar recipe", "Mysore sambar recipe", "huli recipe Karnataka",
        "Mangalore sambar recipe", "koddel recipe", "byadagi sambar recipe", "hotel style sambar Karnataka",
        "idli sambar Udupi recipe", "Karnataka sweet sambar jaggery", "Southekayi huli recipe",
    ],
    "Andhra": [
        "Andhra sambar recipe", "pappu charu recipe", "Andhra pappu charu recipe", "Telugu sambar recipe",
        "Guntur sambar recipe", "pulusu pappu recipe", "mudda pappu charu recipe", "tomato pappu charu",
        "dosakaya pappu charu", "mukkala pulusu recipe", "Andhra dal sambar",
    ],
    "Maharashtra": [
        "Maharashtrian amti recipe", "Maharashtrian aamti recipe", "amti dal recipe", "aamti dal recipe",
        "katachi amti recipe", "goda masala dal amti", "varan amti recipe", "kokum amti recipe",
        "Maharashtrian toor dal amti", "Maharashtrian sambar recipe", "ambat varan recipe",
    ],
}

# Optional site search. These sites often expose recipeIngredient in JSON-LD; if a site
# blocks scraping, the script simply skips it. Use responsibly and keep the default delay.
RECIPE_SITE_DOMAINS = [
    "hebbarskitchen.com",
    "subbuskitchen.com",
    "vegrecipesofindia.com",
    "rakskitchen.net",
    "padhuskitchen.com",
    "sharmispassions.com",
    "kannammacooks.com",
    "yummytummyaarthi.com",
    "archanaskitchen.com",
    "sailusfood.com",
    "udupi-recipes.com",
    "smithakalluraya.com",
    "madhurasrecipe.com",
    "maharashtrianrecipes.com",
]
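# For reference, the JSON-LD these domains typically embed looks roughly like the sketch
# below (schema.org Recipe; the exact shape varies per site, and `base.fetch_recipe_page`
# is assumed to handle the actual extraction):
#   <script type="application/ld+json">
#     {"@context": "https://schema.org", "@type": "Recipe",
#      "name": "...", "recipeIngredient": ["1 cup toor dal", "..."]}
#   </script>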

FAMILY_RE = re.compile(
    r"samb[ha]?r|sambaar|sam?bar|amti|aamti|pappu\s+charu|pappu\s+chaaru|pulusu|\bhuli\b|koddel|kuzhambu|saaru|varan|sadya",
    re.I,
)

OBVIOUS_NOISE_RE = re.compile(
    r"\b(chicken|mutton|beef|pork|fish|prawn|shrimp|cake|cookies?|muffin|french toast|pasta|pizza|salad|sandwich|burger|smoothie|ice cream|hotel butter)\b",
    re.I,
)


def norm(x: object) -> str:
    """Normalise a cell to lower-cased, HTML-unescaped text; lists become 'a | b | c'."""
    if x is None or (isinstance(x, float) and math.isnan(x)):
        return ""
    if isinstance(x, (list, tuple)):
        s = " | ".join(map(str, x))
    else:
        s = str(x)
        if s.strip().startswith("["):
            try:
                val = ast.literal_eval(s)
                if isinstance(val, list):
                    s = " | ".join(map(str, val))
            except Exception:
                pass
    return html.unescape(s).lower()

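# Illustrative behaviour of `norm` (example inputs, shown for documentation only):
#   norm('["Toor Dal", "Tamarind &amp; Salt"]')  -> "toor dal | tamarind & salt"
#   norm(["Toor Dal", "Salt"])                   -> "toor dal | salt"
#   norm(float("nan"))                           -> ""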

def yes(row: pd.Series, feature: str) -> bool:
    return int(row.get(feature, 0) or 0) == 1


def text_of(row: pd.Series) -> str:
    return norm(" | ".join(str(row.get(c, "")) for c in ["state", "title", "query", "source", "ingredients_text", "url"]))


def state_distribution(sub: pd.DataFrame) -> Dict[str, int]:
    counts = sub["state"].astype(str).value_counts().sort_index().to_dict()
    return {k: int(v) for k, v in counts.items()}


def dist_str(sub: pd.DataFrame) -> str:
    d = state_distribution(sub)
    return "; ".join(f"{k}:{v}" for k, v in d.items()) if d else ""


def safe_feature_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Ensure every ingredient feature in the base vocabulary exists."""
    if df.empty:
        return df
    df = base.add_features(df.copy())
    for col in FEATURES:
        if col not in df.columns:
            df[col] = 0
        df[col] = df[col].fillna(0).astype(int).clip(0, 1)
    return df


def find_existing_dataset(output_dir: Path, explicit_input: Optional[str]) -> Optional[Path]:
    candidates: List[Path] = []
    if explicit_input:
        candidates.append(Path(explicit_input))
    candidates.extend([
        output_dir / "sambar_recipe_dataset.csv",
        Path("sambar_out") / "sambar_recipe_dataset.csv",
        Path("sambar_recipe_dataset.csv"),
        Path("sambar_public_validation_dataset.csv"),
    ])
    for p in candidates:
        if p.exists() and p.is_file():
            return p
    return None


def load_existing_dataset(path: Optional[Path]) -> pd.DataFrame:
    if path is None:
        return pd.DataFrame()
    df = pd.read_csv(path)
    if "state" not in df.columns:
        raise SystemExit(f"{path} does not have a 'state' column")
    if "ingredients_text" not in df.columns:
        df["ingredients_text"] = ""
    # Drop rows with unknown labels.
    df = df[df["state"].astype(str).isin(STATES)].copy()
    return safe_feature_columns(df)


def canonical_key_for_row(row: pd.Series) -> str:
    return base.canonical_key(str(row.get("title", "")), str(row.get("ingredients_text", "")))


def dedupe(df: pd.DataFrame) -> pd.DataFrame:
    if df.empty:
        return df
    df = df.copy()
    # Canonical text dedupe first.
    keys = df.apply(canonical_key_for_row, axis=1)
    df = df.loc[~keys.duplicated()].copy()
    # URL dedupe when URL is present.
    if "url" in df.columns:
        url_key = df["url"].fillna("").astype(str).str.lower().str.replace(r"[?#].*$", "", regex=True)
        keep = (url_key == "") | (~url_key.duplicated())
        df = df.loc[keep].copy()
    return df.reset_index(drop=True)


def is_recipe_family(row: pd.Series) -> bool:
    t = text_of(row)
    return bool(FAMILY_RE.search(t))


def is_obvious_noise(row: pd.Series) -> bool:
    t = text_of(row)
    return bool(OBVIOUS_NOISE_RE.search(t))


def annotate_quality(df: pd.DataFrame) -> pd.DataFrame:
    if df.empty:
        return df
    df = df.copy()
    df["recipe_family_match"] = df.apply(is_recipe_family, axis=1).astype(int)
    df["obvious_noise_title"] = df.apply(is_obvious_noise, axis=1).astype(int)
    return df


def cap_rows(df: pd.DataFrame, max_recipes: int) -> pd.DataFrame:
    """Cap rows gently. Preserve the existing order, but avoid one state swamping the sample."""
    if len(df) <= max_recipes:
        return df.reset_index(drop=True)
    # Round-robin by state. This is less order-biased than groupby.head.
    groups = {s: g.copy().reset_index(drop=True) for s, g in df.groupby("state", sort=False)}
    out = []
    i = 0
    while len(out) < max_recipes:
        added = False
        for s in STATES:
            g = groups.get(s)
            if g is not None and i < len(g):
                out.append(g.iloc[i])
                added = True
                if len(out) >= max_recipes:
                    break
        if not added:
            break
        i += 1
    return pd.DataFrame(out).reset_index(drop=True)

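# Illustrative behaviour of `cap_rows` on a hypothetical mini-frame: with max_recipes=4 and
# rows labelled [TN, TN, TN, KL, KA], the round-robin keeps one row per state on the first
# pass (TN, KL, KA) and a second TN on the next pass, rather than the first four rows,
# which would hand Tamil Nadu three of the four slots.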

def http_get(url: str, timeout: int = 25) -> str:
    if requests is None:
        raise RuntimeError("requests is not installed")
    headers = {
        "User-Agent": "Mozilla/5.0 (compatible; sambar-recipe-research/2.0; +https://openai.com/)"
    }
    r = requests.get(url, headers=headers, timeout=timeout)
    r.raise_for_status()
    return r.text


def duckduckgo_search_urls(query: str, max_urls: int = 8) -> List[str]:
    """Return result URLs from DuckDuckGo's static HTML endpoint. Skips gracefully if blocked."""
    if requests is None or BeautifulSoup is None:
        return []
    url = "https://duckduckgo.com/html/?q=" + quote_plus(query)
    try:
        page = http_get(url)
    except Exception:
        return []
    soup = BeautifulSoup(page, "html.parser")
    urls: List[str] = []
    for a in soup.find_all("a", href=True):
        href = a["href"]
        label = a.get_text(" ", strip=True)
        if not label:
            continue
        real = href
        # DDG often uses /l/?uddg=<encoded-url>
        if "uddg=" in href:
            try:
                parsed = urlparse(href)
                qs = parse_qs(parsed.query)
                real = unquote(qs.get("uddg", [href])[0])
            except Exception:
                real = href
        if real.startswith("http") and real not in urls and not re.search(r"duckduckgo|google|bing", real):
            urls.append(real)
        if len(urls) >= max_urls:
            break
    return urls

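# Example of the DuckDuckGo redirect format handled above (illustrative URL only):
#   /l/?uddg=https%3A%2F%2Fexample.com%2Fsambar-recipe%2F&rut=...
# decodes to https://example.com/sambar-recipe/, which is the URL added to the result list.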

def fetch_url_recipe(url: str, fallback_state: str, query: str) -> Optional[Dict[str, str]]:
    try:
        title, ingred = base.fetch_recipe_page(url, fallback_title="")
    except Exception:
        title, ingred = "", ""
    if not ingred or len(norm(ingred)) < 20:
        return None
    state = base.infer_state(f"{title} {ingred} {query} {url}", fallback=fallback_state)
    if state not in STATES:
        return None
    return {
        "state": state,
        "title": title or query,
        "url": url,
        "source": "site-search-jsonld",
        "query": query,
        "ingredients_text": ingred,
    }


def collect_expanded_recipes(max_new_rows: int, per_state_target: Optional[int], delay: float, no_hf: bool, no_cookpad: bool, no_site_search: bool) -> pd.DataFrame:
    rows: List[Dict[str, str]] = []
    if max_new_rows <= 0:
        return pd.DataFrame()

    per_state = per_state_target or math.ceil(max_new_rows / max(1, len(STATES)))
    counts = Counter()

    # 1) Reuse the base collector, but ask it for a bit extra. It already knows HF + Cookpad.
    try:
        base_df = base.collect_recipes(
            max_recipes=max_new_rows,
            per_state=per_state,
            use_cookpad=not no_cookpad,
            use_hf=not no_hf,
        )
        if not base_df.empty:
            rows.extend(base_df[[c for c in base_df.columns if c in ["state", "title", "url", "source", "query", "ingredients_text"]]].to_dict("records"))
            counts.update(r["state"] for r in rows if r.get("state") in STATES)
    except Exception as exc:
        print(f"[warn] base.collect_recipes failed: {exc}", file=sys.stderr)

    # 2) Extra RecipeNLG searches with richer regional terms.
    if not no_hf:
        for state, queries in EXPANDED_STATE_QUERIES.items():
            for q in queries:
                if len(rows) >= max_new_rows and counts[state] >= per_state:
                    break
                try:
                    new_rows = base.fetch_recipenlg_search(q, state, max_rows=100)
                    rows.extend(new_rows)
                    counts.update(r["state"] for r in new_rows if r.get("state") in STATES)
                except Exception as exc:
                    print(f"[warn] RecipeNLG search failed for {q!r}: {exc}", file=sys.stderr)

    # 3) Extra Cookpad queries.
    if not no_cookpad:
        for state, queries in EXPANDED_STATE_QUERIES.items():
            for q in queries:
                if len(rows) >= max_new_rows and counts[state] >= per_state:
                    break
                try:
                    new_rows = base.fetch_cookpad_search(q, state, max_links=12, delay=delay)
                    rows.extend(new_rows)
                    counts.update(r["state"] for r in new_rows if r.get("state") in STATES)
                except Exception as exc:
                    print(f"[warn] Cookpad search failed for {q!r}: {exc}", file=sys.stderr)

    # 4) Optional static web search over known recipe domains, extracting JSON-LD where possible.
    if not no_site_search:
        seen_urls = set()
        for state, queries in EXPANDED_STATE_QUERIES.items():
            for q in queries[:6]:  # keep it polite
                if len(rows) >= max_new_rows and counts[state] >= per_state:
                    break
                domain = RECIPE_SITE_DOMAINS[len(seen_urls) % len(RECIPE_SITE_DOMAINS)] if RECIPE_SITE_DOMAINS else ""
                # Run a broad query first, then a site-restricted one for extra coverage.
                search_queries = [f"{q} recipe ingredients", f"site:{domain} {q} recipe"] if domain else [f"{q} recipe ingredients"]
                for sq in search_queries:
                    urls = duckduckgo_search_urls(sq, max_urls=6)
                    for u in urls:
                        clean = re.sub(r"[?#].*$", "", u)
                        if clean in seen_urls:
                            continue
                        seen_urls.add(clean)
                        time.sleep(delay)
                        row = fetch_url_recipe(u, state, q)
                        if row is not None:
                            rows.append(row)
                            counts[row["state"]] += 1
                            if len(rows) >= max_new_rows:
                                break
                    if len(rows) >= max_new_rows:
                        break

    df = pd.DataFrame(rows)
    if df.empty:
        return df
    return safe_feature_columns(df)


# --------------------------
# Abstaining tree definitions
# --------------------------

@dataclass(frozen=True)
class Rule:
    name: str
    predicted_state: str
    test: Callable[[pd.Series], bool]
    explanation: str


EXPERT_RULES: List[Rule] = [
    Rule(
        name="kokum or goda masala",
        predicted_state="Maharashtra",
        test=lambda r: yes(r, "kokum") or yes(r, "goda masala"),
        explanation="amti ingredient vocabulary: kokum / goda masala",
    ),
    Rule(
        name="sesame/gingelly oil",
        predicted_state="Tamil Nadu",
        test=lambda r: yes(r, "sesame/gingelly oil"),
        explanation="gingelly oil is a strong Tamil-style clue in this feature space",
    ),
    Rule(
        name="coconut oil + shallots",
        predicted_state="Kerala",
        test=lambda r: yes(r, "coconut oil") and yes(r, "shallots"),
        explanation="Kerala-style tempering smell: coconut oil plus shallots",
    ),
    Rule(
        name="moong dal + no mustard seeds",
        predicted_state="Andhra",
        test=lambda r: yes(r, "moong dal") and not yes(r, "mustard seeds"),
        explanation="pappu/charu-like signal: moong dal without the usual mustard tempering",
    ),
    Rule(
        name="sambar powder + tamarind",
        predicted_state="Tamil Nadu",
        test=lambda r: yes(r, "sambar powder") and yes(r, "tamarind"),
        explanation="mainstream tiffin/hotel sambar signal; often Tamil in scraped recipes",
    ),
    Rule(
        name="byadagi chillies",
        predicted_state="Karnataka",
        test=lambda r: yes(r, "byadagi chillies"),
        explanation="Byadagi is a Karnataka-ish masala clue",
    ),
    Rule(
        name="grated coconut + jaggery",
        predicted_state="Karnataka",
        test=lambda r: yes(r, "grated coconut") and yes(r, "jaggery"),
        explanation="sweet-coconut sambar/huli axis; useful but not perfectly clean",
    ),
]
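# Note on EXPERT_RULES ordering: `apply_abstention_tree` walks this list sequentially and
# removes every row claimed by an earlier rule before testing the next one, so the
# highest-precision cues should come first; rows that no rule claims end up in the
# mixed/abstain leaf.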

DIALECT_RULES: List[Rule] = [
    # Maharashtra
    Rule("kokum", "Maharashtra", lambda r: yes(r, "kokum"), "kokum"),
    Rule("goda masala", "Maharashtra", lambda r: yes(r, "goda masala"), "goda masala"),
    # Tamil Nadu
    Rule("sesame/gingelly oil", "Tamil Nadu", lambda r: yes(r, "sesame/gingelly oil"), "sesame/gingelly oil"),
    Rule("sambar powder + tamarind", "Tamil Nadu", lambda r: yes(r, "sambar powder") and yes(r, "tamarind"), "sambar powder + tamarind"),
    # Kerala
    Rule("coconut oil + shallots", "Kerala", lambda r: yes(r, "coconut oil") and yes(r, "shallots"), "coconut oil + shallots"),
    Rule("grated coconut + no jaggery", "Kerala", lambda r: yes(r, "grated coconut") and not yes(r, "jaggery"), "grated coconut without jaggery"),
    # Andhra
    Rule("moong dal + no mustard seeds", "Andhra", lambda r: yes(r, "moong dal") and not yes(r, "mustard seeds"), "moong dal without mustard seeds"),
    Rule("garlic + green chilli", "Andhra", lambda r: yes(r, "garlic") and yes(r, "green chilli"), "garlic + green chilli"),
    # Karnataka
    Rule("byadagi chillies", "Karnataka", lambda r: yes(r, "byadagi chillies"), "Byadagi chillies"),
    Rule("grated coconut + jaggery", "Karnataka", lambda r: yes(r, "grated coconut") and yes(r, "jaggery"), "grated coconut + jaggery"),
]


def apply_abstention_tree(df: pd.DataFrame, rules: Sequence[Rule] = EXPERT_RULES) -> Tuple[pd.DataFrame, pd.DataFrame, str]:
    """Sequentially apply high-precision rules. Anything left is abstained."""
    pred_rows: List[Dict[str, object]] = []
    branch_rows: List[Dict[str, object]] = []
    lines: List[str] = []

    remaining = df.copy()
    remaining_idx = list(remaining.index)
    indent = ""
    path_parts: List[str] = []

    for rule in rules:
        tested_n = len(remaining)
        if tested_n == 0:
            lines.append(f"{indent}{rule.name}? [tested=0, yes=0, no=0]")
            lines.append(f"{indent}  yes:")
            lines.append(f"{indent}    → {rule.predicted_state} [n=0, correct=0, actual=]")
            lines.append(f"{indent}  no:")
            indent += "    "
            path_parts.append(f"{rule.name}=no")
            continue

        mask = remaining.apply(rule.test, axis=1).astype(bool)
        yes_sub = remaining.loc[mask].copy()
        no_sub = remaining.loc[~mask].copy()
        correct = int((yes_sub["state"].astype(str) == rule.predicted_state).sum()) if len(yes_sub) else 0
        accuracy = correct / len(yes_sub) if len(yes_sub) else None
        actual = dist_str(yes_sub)
        current_path = " > ".join(path_parts + [f"{rule.name}=yes"])

        lines.append(f"{indent}{rule.name}? [tested={tested_n}, yes={len(yes_sub)}, no={len(no_sub)}]")
        lines.append(f"{indent}  yes:")
        lines.append(f"{indent}    → {rule.predicted_state} [n={len(yes_sub)}, correct={correct}, actual={actual}]")
        lines.append(f"{indent}  no:")

        branch_rows.append({
            "path": current_path,
            "test": rule.name,
            "predicted_state": rule.predicted_state,
            "n": len(yes_sub),
            "correct": correct,
            "accuracy": round(accuracy, 4) if accuracy is not None else "",
            "actual_distribution": actual,
            "explanation": rule.explanation,
        })

        for idx, row in yes_sub.iterrows():
            pred_rows.append({
                "row_index": idx,
                "abstention_pred": rule.predicted_state,
                "abstention_correct": int(str(row["state"]) == rule.predicted_state),
                "abstention_path": current_path,
                "abstention_rule": rule.name,
                "abstained": 0,
            })

        remaining = no_sub
        indent += "    "
        path_parts.append(f"{rule.name}=no")

    # Final abstention leaf.
    abstain_path = " > ".join(path_parts + ["mixed/abstain"])
    best_state = None
    best_count = 0
    best_share = None
    if len(remaining):
        counts = remaining["state"].astype(str).value_counts()
        best_state = str(counts.index[0])
        best_count = int(counts.iloc[0])
        best_share = best_count / len(remaining)
    actual = dist_str(remaining)
    lines.append(f"{indent}→ mixed / abstain [n={len(remaining)}, best_guess={best_state or ''}, best_guess_correct_if_forced={best_count}, actual={actual}]")
    branch_rows.append({
        "path": abstain_path,
        "test": "mixed / abstain",
        "predicted_state": "mixed / abstain",
        "n": len(remaining),
        "correct": "",
        "accuracy": round(best_share, 4) if best_share is not None else "",
        "actual_distribution": actual,
        "explanation": "overlapping shared sambar/dal-curry grammar; no high-confidence rule fired",
    })
    for idx, row in remaining.iterrows():
        pred_rows.append({
            "row_index": idx,
            "abstention_pred": "mixed / abstain",
            "abstention_correct": "",
            "abstention_path": abstain_path,
            "abstention_rule": "mixed / abstain",
            "abstained": 1,
        })

    pred_df = pd.DataFrame(pred_rows).set_index("row_index")
    out = df.join(pred_df, how="left")
    branches = pd.DataFrame(branch_rows)
    return out, branches, "\n".join(lines)


def rule_diagnostics(df: pd.DataFrame, rules: Sequence[Rule] = DIALECT_RULES) -> pd.DataFrame:
    rows: List[Dict[str, object]] = []
    for rule in rules:
        mask = df.apply(rule.test, axis=1).astype(bool)
        sub = df.loc[mask]
        n = len(sub)
        correct = int((sub["state"].astype(str) == rule.predicted_state).sum()) if n else 0
        precision = correct / n if n else 0.0
        # Recall relative to the target state's rows.
        target_total = int((df["state"].astype(str) == rule.predicted_state).sum())
        recall = correct / target_total if target_total else 0.0
        rows.append({
            "rule": rule.name,
            "dialect": rule.predicted_state,
            "n_matched": n,
            "correct": correct,
            "precision": round(precision, 4),
            "recall_within_label": round(recall, 4),
            "actual_distribution": dist_str(sub),
            "explanation": rule.explanation,
        })
    return pd.DataFrame(rows).sort_values(["precision", "n_matched"], ascending=[False, False]).reset_index(drop=True)


def confidence_from_rule_precision(diag: pd.DataFrame) -> Dict[str, float]:
    """Empirical precision with light shrinkage so tiny rules don't look absurdly certain."""
    out: Dict[str, float] = {}
    for _, r in diag.iterrows():
        n = int(r["n_matched"])
        correct = int(r["correct"])
        # Prior says a random 5-way guess is ~0.2. Strength 8 keeps tiny rules humble.
        conf = (correct + 8 * 0.20) / (n + 8) if n else 0.20
        out[str(r["rule"])] = float(conf)
    return out

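# Worked example of the shrinkage above (illustrative counts, not taken from the data):
# a rule with n_matched=2 and correct=2 scores (2 + 8*0.20) / (2 + 8) = 0.36, while one with
# n_matched=50 and correct=45 scores (45 + 1.6) / 58 ≈ 0.80, so large rules keep their
# empirical precision and tiny rules are pulled toward the 0.20 five-way-guess prior.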

def dialect_scores_for_row(row: pd.Series, diag_conf: Dict[str, float], rules: Sequence[Rule] = DIALECT_RULES) -> Dict[str, object]:
    scores: Dict[str, float] = defaultdict(float)
    evidence: Dict[str, List[str]] = defaultdict(list)
    # Initialise every state at zero so the score dict always lists all five states;
    # only fired rules add weight, so rows with no evidence keep zero everywhere.
    for state in STATES:
        scores[state] = 0.0
    for rule in rules:
        try:
            fired = bool(rule.test(row))
        except Exception:
            fired = False
        if fired:
            conf = diag_conf.get(rule.name, 0.20)
            # Sum the precision-derived rule confidences; a heuristic score, not a calibrated probability.
            scores[rule.predicted_state] += conf
            evidence[rule.predicted_state].append(f"{rule.name} ({conf:.0%})")
    total = sum(scores.values())
    if total <= 0:
        ranked = [("mixed / abstain", 1.0)]
        top = "mixed / abstain"
        confidence = 0.0
    else:
        ranked_states = sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))
        ranked = [(s, v / total) for s, v in ranked_states if v > 0]
        top = ranked[0][0]
        confidence = ranked[0][1]
    return {
        "top_dialect": top,
        "top_confidence": round(float(confidence), 4),
        "ranked_dialects": json.dumps([(s, round(float(v), 4)) for s, v in ranked], ensure_ascii=False),
        "evidence": "; ".join(f"{s}: " + ", ".join(es) for s, es in evidence.items()),
        "num_evidence_rules": sum(len(v) for v in evidence.values()),
    }

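# Illustrative outcome of `dialect_scores_for_row` (hypothetical confidences): if a Kerala
# rule with confidence 0.70 and a Karnataka rule with confidence 0.50 both fire, the scores
# normalise to Kerala 0.70 / 1.20 ≈ 0.58 and Karnataka ≈ 0.42, so top_dialect is "Kerala"
# with top_confidence ≈ 0.58; if no rule fires, the row is ranked "mixed / abstain" with
# top_confidence 0.0.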

def apply_dialect_detection(df: pd.DataFrame, diag: pd.DataFrame) -> pd.DataFrame:
    conf = confidence_from_rule_precision(diag)
    rows = [dialect_scores_for_row(row, conf) for _, row in df.iterrows()]
    pred = pd.DataFrame(rows, index=df.index)
    out = pd.concat([df.reset_index(drop=True), pred.reset_index(drop=True)], axis=1)
    out["dialect_top1_correct"] = (out["state"].astype(str) == out["top_dialect"].astype(str)).astype(object)
    # If no rule fired and top is abstain, don't count it as a wrong confident prediction.
    out.loc[out["top_dialect"] == "mixed / abstain", "dialect_top1_correct"] = ""
    return out


def write_report(
    path: Path,
    df: pd.DataFrame,
    abstain_pred: pd.DataFrame,
    branches: pd.DataFrame,
    tree_text: str,
    diag: pd.DataFrame,
    dialect_pred: pd.DataFrame,
    existing_path: Optional[Path],
):
    n = len(df)
    covered = abstain_pred[abstain_pred["abstained"].astype(int) == 0]
    abstained = abstain_pred[abstain_pred["abstained"].astype(int) == 1]
    covered_correct = int(pd.to_numeric(covered["abstention_correct"], errors="coerce").fillna(0).sum())
    covered_acc = covered_correct / len(covered) if len(covered) else 0.0
    coverage = len(covered) / n if n else 0.0
    majority_state = df["state"].value_counts().idxmax() if n else ""
    majority_count = int(df["state"].value_counts().max()) if n else 0
    majority_acc = majority_count / n if n else 0.0
    forced_abstain_correct = 0
    forced_state = ""
    forced_share = 0.0
    if len(abstained):
        c = abstained["state"].value_counts()
        forced_state = str(c.index[0])
        forced_abstain_correct = int(c.iloc[0])
        forced_share = forced_abstain_correct / len(abstained)
    forced_total_correct = covered_correct + forced_abstain_correct
    forced_acc = forced_total_correct / n if n else 0.0

    family_n = int(df.get("recipe_family_match", pd.Series(dtype=int)).sum()) if "recipe_family_match" in df.columns else 0
    noise_n = int(df.get("obvious_noise_title", pd.Series(dtype=int)).sum()) if "obvious_noise_title" in df.columns else 0

    dialect_non_abstain = dialect_pred[dialect_pred["top_dialect"] != "mixed / abstain"].copy()
    dialect_correct = pd.to_numeric(dialect_non_abstain["dialect_top1_correct"], errors="coerce").fillna(0).sum() if len(dialect_non_abstain) else 0
    dialect_acc = dialect_correct / len(dialect_non_abstain) if len(dialect_non_abstain) else 0.0

    state_summary = df.groupby("state").size().reset_index(name="recipes")
    branch_table = branches.copy()

    with path.open("w", encoding="utf-8") as f:
        f.write("# Sambar abstaining fast-and-frugal tree\n\n")
        f.write(f"Input reused: `{existing_path}`  \n" if existing_path else "Input reused: none; downloaded from scratch  \n")
        f.write(f"Recipes after dedupe/cap: **{n}**  \n")
        f.write(f"Majority baseline: **{majority_state} = {majority_count}/{n} = {majority_acc:.1%}**  \n")
        f.write(f"Recipe-family title/query matches: **{family_n}/{n} = {(family_n/n if n else 0):.1%}**  \n")
        f.write(f"Obvious noisy title matches: **{noise_n}/{n} = {(noise_n/n if n else 0):.1%}**  \n")
        f.write("\n")
        f.write("## Abstaining tree summary\n\n")
        f.write(f"Covered by high-confidence leaves: **{len(covered)}/{n} = {coverage:.1%}**  \n")
        f.write(f"Accuracy on covered leaves: **{covered_correct}/{len(covered)} = {covered_acc:.1%}**  \n")
        f.write(f"Abstained: **{len(abstained)}/{n} = {(len(abstained)/n if n else 0):.1%}**  \n")
        if len(abstained):
            f.write(f"If abstentions are forced to the abstention majority label `{forced_state}`: **{forced_abstain_correct}/{len(abstained)} = {forced_share:.1%}**  \n")
            f.write(f"Forced overall accuracy: **{forced_total_correct}/{n} = {forced_acc:.1%}**  \n")
        f.write("\n## Decision tree with branch activation counts\n\n")
        f.write("```text\n")
        f.write(tree_text)
        f.write("\n```\n\n")
        f.write("## Confidence-ranked dialect detection\n\n")
        f.write("This is not a calibrated probability model. It ranks dialect/state signals by the empirical precision of the rules that fired.\n\n")
        f.write(f"Rows with at least one dialect signal: **{len(dialect_non_abstain)}/{n} = {(len(dialect_non_abstain)/n if n else 0):.1%}**  \n")
        if len(dialect_non_abstain):
            f.write(f"Top-1 accuracy where at least one signal fired: **{int(dialect_correct)}/{len(dialect_non_abstain)} = {dialect_acc:.1%}**  \n")
        f.write("\n### Rule diagnostics\n\n")
        f.write(diag.to_markdown(index=False))
        f.write("\n\n### State counts\n\n")
        f.write(state_summary.to_markdown(index=False))
        f.write("\n\n### Branch counts\n\n")
        f.write(branch_table.to_markdown(index=False))
        f.write("\n")


def main(argv: Optional[Sequence[str]] = None) -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", help="Existing recipe CSV. Defaults to sambar_out/sambar_recipe_dataset.csv, then ./sambar_recipe_dataset.csv")
    ap.add_argument("--max-recipes", type=int, default=2000, help="Cap after deduplication")
    ap.add_argument("--per-state", type=int, default=None, help="Target rows per state during download; default is max_recipes/5")
    ap.add_argument("--output-dir", default="sambar_out")
    ap.add_argument("--no-download", action="store_true", help="Only replay on existing downloaded data")
    ap.add_argument("--no-hf", action="store_true", help="Skip Hugging Face sources")
    ap.add_argument("--no-cookpad", action="store_true", help="Skip Cookpad")
    ap.add_argument("--no-site-search", action="store_true", help="Skip public web/site search expansion")
    ap.add_argument("--strict-family", action="store_true", help="Only train/report on rows with sambar/amti/pappu-charu/pulusu/huli/kuzhambu family matches")
    ap.add_argument("--drop-obvious-noise", action="store_true", help="Drop rows whose title looks obviously unrelated/non-veg/western")
    ap.add_argument("--delay", type=float, default=0.6, help="Delay between public web/detail-page requests")
    args = ap.parse_args(argv)

    outdir = Path(args.output_dir)
    outdir.mkdir(parents=True, exist_ok=True)

    existing_path = find_existing_dataset(outdir, args.input)
    existing_df = load_existing_dataset(existing_path)
    if not existing_df.empty:
        print(f"Reused {existing_path}: {len(existing_df)} rows")
    else:
        print("No existing dataset found; will try to download from scratch")

    frames = [existing_df] if not existing_df.empty else []

    if not args.no_download:
        # Ask for extra because dedupe and balancing can remove a lot.
        target_new = max(args.max_recipes - len(existing_df), math.ceil(args.max_recipes * 0.35))
        print(f"Trying to collect up to ~{target_new} additional candidate rows...")
        new_df = collect_expanded_recipes(
            max_new_rows=target_new,
            per_state_target=args.per_state,
            delay=args.delay,
            no_hf=args.no_hf,
            no_cookpad=args.no_cookpad,
            no_site_search=args.no_site_search,
        )
        if not new_df.empty:
            print(f"Collected {len(new_df)} candidate new rows before dedupe")
            frames.append(new_df)
        else:
            print("No new rows collected; continuing with existing data")

    if not frames:
        raise SystemExit("No recipe data available. Run the earlier sambar_fftree.py first, or enable internet access.")

    df = pd.concat(frames, ignore_index=True, sort=False)
    df = safe_feature_columns(df)
    df = dedupe(df)
    df = annotate_quality(df)

    if args.drop_obvious_noise:
        before = len(df)
        df = df[df["obvious_noise_title"].astype(int) == 0].copy()
        print(f"Dropped obvious-noise rows: {before - len(df)}")
    if args.strict_family:
        before = len(df)
        df = df[df["recipe_family_match"].astype(int) == 1].copy()
        print(f"Kept strict recipe-family rows: {len(df)}/{before}")

    df = cap_rows(df, args.max_recipes)
    df = df.reset_index(drop=True)

    abstain_pred, branches, tree = apply_abstention_tree(df)
    diag = rule_diagnostics(df)
    dialect_pred = apply_dialect_detection(df, diag)

    # Write outputs. Names intentionally parallel the earlier script.
    df.to_csv(outdir / "sambar_recipe_dataset.csv", index=False)
    abstain_pred.to_csv(outdir / "sambar_fftree_abstention_predictions.csv", index=False)
    branches.to_csv(outdir / "sambar_fftree_abstention_branch_counts.csv", index=False)
    diag.to_csv(outdir / "sambar_dialect_rule_diagnostics.csv", index=False)
    dialect_pred.to_csv(outdir / "sambar_dialect_predictions.csv", index=False)
    write_report(outdir / "sambar_fftree_abstention_report.md", df, abstain_pred, branches, tree, diag, dialect_pred, existing_path)

    n = len(df)
    covered = abstain_pred[abstain_pred["abstained"].astype(int) == 0]
    covered_correct = int(pd.to_numeric(covered["abstention_correct"], errors="coerce").fillna(0).sum())
    abstained_n = int((abstain_pred["abstained"].astype(int) == 1).sum())
    print(f"Wrote {outdir}")
    print(f"Recipes: {n}")
    if len(covered):
        print(f"Covered accuracy: {covered_correct}/{len(covered)} = {covered_correct/len(covered):.1%}")
    else:
        print("Covered accuracy: no covered rows")
    print(f"Coverage: {len(covered)}/{n} = {(len(covered)/n if n else 0):.1%}; Abstained: {abstained_n}/{n} = {(abstained_n/n if n else 0):.1%}")
    print("Tree:\n" + tree)
    print("\nConfidence-ranked dialect detection written to:")
    print(f"  {outdir / 'sambar_dialect_predictions.csv'}")
    print(f"  {outdir / 'sambar_dialect_rule_diagnostics.csv'}")
    print(f"  {outdir / 'sambar_fftree_abstention_report.md'}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
