#!/usr/bin/env python3
"""
Build a fast-and-frugal decision tree for regional sambar/amti classification.

Typical use with internet access:
    python sambar_fftree.py --max-recipes 300 --output-dir sambar_out

Replay on an existing CSV:
    python sambar_fftree.py --input sambar_public_validation_dataset.csv --output-dir sambar_out

Outputs:
    sambar_recipe_dataset.csv       recipe rows + ingredient features
    sambar_fftree_predictions.csv   predictions, paths, correctness
    sambar_fftree_branch_counts.csv leaf/path counts and actual distributions
    sambar_fftree_node_counts.csv   internal node activation counts
    sambar_fftree_report.md         readable tree + summary

Notes:
- Labels are weak: they are inferred from recipe title/query/source text. For
  example, “Udupi sambar” maps to Karnataka, “Pappu Charu” to Andhra, and
  “Amti/Aamti” to Maharashtra.
- The scraper uses public pages/API endpoints. Please respect robots.txt, site terms, and rate limits.
"""
from __future__ import annotations

import argparse
import ast
import html
import json
import math
import re
import sys
import time
from collections import Counter, defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from urllib.parse import quote, urljoin

import numpy as np
import pandas as pd

try:
    import requests
except Exception:  # pragma: no cover
    requests = None

try:
    from bs4 import BeautifulSoup
except Exception:  # pragma: no cover
    BeautifulSoup = None

try:
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.model_selection import StratifiedKFold, cross_val_score
except Exception as exc:  # pragma: no cover
    raise SystemExit("Install scikit-learn first: pip install scikit-learn") from exc

STATES = ["Andhra", "Karnataka", "Kerala", "Maharashtra", "Tamil Nadu"]

# Broad-but-explicit ingredient vocabulary. Add aliases freely.
INGREDIENT_PATTERNS: Dict[str, Sequence[str]] = {
    "toor dal": [r"\btoor\b", r"\btuvar\b", r"\btur\b", r"\bthoor\b", r"\bthuvar", r"\barhar\b", r"pigeon peas?", r"split pigeon"],
    "masoor dal": [r"\bmasoor\b", r"red lentil"],
    "moong dal": [r"\bmoong\b", r"mung dal", r"green gram"],
    "chana dal": [r"\bchana dal\b", r"bengal gram", r"kadalai paruppu", r"chickpea lentil"],
    "urad dal": [r"\burad\b", r"udad\b", r"ulundhu", r"black gram"],
    "besan": [r"\bbesan\b", r"gram flour"],
    "tomato": [r"tomato", r"thakkali"],
    "drumstick": [r"drumstick", r"murung", r"moringa"],
    "shallots": [r"shallot", r"small onion", r"pearl onion", r"sambar onion", r"chinna vengayam"],
    "onion": [r"\bonion", r"vengayam", r"pyaz"],
    "brinjal": [r"brinjal", r"eggplant", r"aubergine", r"baingan", r"kathirikai", r"vankaya"],
    "carrot": [r"carrot"],
    "potato": [r"potato", r"aloo", r"batata"],
    "bottle gourd": [r"bottle gourd", r"lauki", r"ghiya", r"sorakaya", r"suraikkai"],
    "cucumber": [r"cucumber", r"vellari"],
    "pumpkin": [r"pumpkin", r"mathanga", r"parangikai", r"kaddu"],
    "beans": [r"beans", r"french beans", r"cluster beans", r"gorikayi", r"avarakkai"],
    "lady finger": [r"lady[ '\-]?finger", r"okra", r"bhindi", r"vendakkai"],
    "snake gourd": [r"snake gourd", r"padwal", r"pudalangai"],
    "radish": [r"radish", r"mooli", r"mullangi", r"moolangi"],
    "raw banana": [r"raw banana", r"plantain", r"vazhakkai"],
    "yam": [r"\byam\b", r"elephant foot", r"suran", r"chena", r"senai"],
    "ash gourd": [r"ash gourd", r"white pumpkin", r"poosanikai", r"kumbalanga"],
    "peanuts": [r"peanut", r"groundnut"],
    "green peas": [r"green peas", r"matar", r"peas"],
    "oil": [r"\boil\b", r"sunflower oil", r"vegetable oil"],
    "sesame/gingelly oil": [r"sesame oil", r"gingelly", r"nallennai"],
    "coconut oil": [r"coconut oil"],
    "ghee": [r"\bghee\b", r"clarified butter"],
    "mustard seeds": [r"mustard", r"rai\b", r"kadugu", r"sasive"],
    "turmeric": [r"turmeric", r"haldi", r"manjal"],
    "hing": [r"hing", r"asafoetida", r"asafetida", r"perungayam"],
    "garlic": [r"garlic", r"poondu", r"lasun"],
    "ginger": [r"ginger", r"adrak", r"inji"],
    "coriander seeds": [r"coriander seeds", r"dhania seeds", r"dhaniya seeds"],
    "coriander powder": [r"coriander powder", r"dhania powder", r"dhaniya powder"],
    "coriander leaves": [r"coriander leaves", r"cilantro", r"fresh coriander", r"kothimbir"],
    "fenugreek": [r"fenugreek", r"methi", r"vendhayam"],
    "cumin seeds": [r"cumin", r"jeera", r"jira", r"seeragam"],
    "grated coconut": [r"grated coconut", r"fresh coconut", r"scraped coconut", r"coconut grated", r"coconut - grated", r"\bcoconut\b"],
    "gundu chillies": [r"gundu chilli", r"round chilli"],
    "guntur chillies": [r"guntur chilli", r"guntur red"],
    "byadagi chillies": [r"byadagi", r"bedgi", r"bydagi"],
    "black pepper": [r"black pepper", r"peppercorn"],
    "chilli": [r"red chilli", r"red chili", r"chilli powder", r"chili powder", r"kashmiri chilli", r"dry red chill"],
    "green chilli": [r"green chilli", r"green chili", r"hari mirch"],
    "rice flour": [r"rice flour"],
    "cinnamon": [r"cinnamon", r"dalchini"],
    "jaggery": [r"jaggery", r"gud\b", r"bella\b", r"vellam\b"],
    "sugar": [r"\bsugar\b"],
    "tamarind": [r"tamarind", r"imli", r"puli\b", r"tamerind"],
    "curry leaves": [r"curry leaves", r"kadi patta", r"karuveppilai"],
    "goda masala": [r"goda masala", r"goda spice"],
    "kokum": [r"kokum", r"amsul"],
    "sambar powder": [r"sambar powder", r"sambhar powder", r"sambar masala", r"sambhar masala"],
    "lemon/lime": [r"lemon", r"lime"],
}

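# Hedged usage sketch (illustrative, not executed): alias lists are OR-ed via
# re.search, so a single hit sets the ingredient flag, e.g.
#   any(re.search(p, "1 tsp kadugu") for p in INGREDIENT_PATTERNS["mustard seeds"])  # -> True
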
# Search queries deliberately include regional recipe names, not just state names.
STATE_QUERIES: Dict[str, Sequence[str]] = {
    "Tamil Nadu": [
        "Tamil Nadu sambar", "arachuvitta sambar", "murungakkai sambar",
        "tiffin sambar tamil", "idli sambar tamil", "Kumbakonam sambar",
        "hotel sambar tamil", "vendakkai sambar tamil"
    ],
    "Kerala": [
        "Kerala sambar", "Onam sambar", "Sadya sambar", "varutharacha sambar",
        "Kerala mixed vegetable sambar", "Kerala coconut sambar", "malabar sambar"
    ],
    "Karnataka": [
        "Udupi sambar", "Karnataka sambar", "Mysore sambar", "huli sambar",
        "byadagi sambar", "koddel sambar", "Mangalore sambar"
    ],
    "Andhra": [
        "Andhra sambar", "Pappu Charu", "Pappu chaaru", "Andhra pappu charu",
        "Telugu sambar", "Guntur sambar", "pulusu dal"
    ],
    "Maharashtra": [
        "Maharashtrian amti", "Maharashtrian aamti", "amti dal", "aamti dal",
        "katachi amti", "goda masala dal", "varan amti", "kokum amti"
    ],
}


def normalize_text(x: object) -> str:
    if x is None or (isinstance(x, float) and math.isnan(x)):
        return ""
    if isinstance(x, (list, tuple)):
        return " | ".join(map(str, x))
    s = str(x)
    # RecipeNLG stores ingredients as a stringified Python list.
    if s.strip().startswith("["):
        try:
            val = ast.literal_eval(s)
            if isinstance(val, list):
                s = " | ".join(map(str, val))
        except Exception:
            pass
    return html.unescape(s).lower()

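# Illustrative example (assumed input, not executed): RecipeNLG stores
# ingredients as a stringified Python list, which normalize_text flattens:
#   normalize_text("['2 cups Toor Dal', '1 tomato']")  ->  "2 cups toor dal | 1 tomato"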

def feature_row(ingredients_text: str) -> Dict[str, int]:
    text = normalize_text(ingredients_text)
    return {
        feature: int(any(re.search(pat, text, flags=re.I) for pat in patterns))
        for feature, patterns in INGREDIENT_PATTERNS.items()
    }

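# Hedged sketch of the resulting row (truncated; the real dict holds one 0/1
# entry per INGREDIENT_PATTERNS key):
#   feature_row("toor dal | tamarind | goda masala")
#   -> {"toor dal": 1, "tamarind": 1, "goda masala": 1, "kokum": 0, ...}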

def canonical_key(title: str, ingredients: str) -> str:
    clean = re.sub(r"[^a-z0-9]+", " ", f"{title} {ingredients}".lower()).strip()
    return " ".join(clean.split())[:500]

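# Illustrative: near-duplicates collapse to one key, e.g. both
#   canonical_key("Udupi Sambar!", "toor dal, jaggery") and
#   canonical_key("udupi  sambar", "toor dal jaggery")
# reduce to "udupi sambar toor dal jaggery".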

def infer_state(text: str, fallback: Optional[str] = None) -> Optional[str]:
    t = normalize_text(text)
    rules = [
        ("Maharashtra", [r"maharash", r"\baamti\b", r"\bamti\b", r"katachi", r"goda masala", r"kokum", r"varan"]),
        ("Karnataka", [r"karnataka", r"udupi", r"mysore", r"mangalor", r"huli\b", r"koddel", r"byadagi", r"bisi bele"]),
        ("Kerala", [r"kerala", r"onam", r"sadya", r"varutharacha", r"malabar", r"kumbalanga", r"mathanga"]),
        ("Andhra", [r"andhra", r"telugu", r"pappu charu", r"pappu chaaru", r"pulusu", r"guntur", r"gongura"]),
        ("Tamil Nadu", [r"tamil", r"arachuvitta", r"murungakkai", r"vendakkai", r"kumbakonam", r"tiffin sambar", r"saravana", r"chettinad"]),
    ]
    for state, patterns in rules:
        if any(re.search(p, t, flags=re.I) for p in patterns):
            return state
    return fallback

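# Caveat: rules are checked in order, so text matching several states takes
# the first hit; e.g. (illustrative) infer_state("kerala kokum curry") returns
# "Maharashtra" because the kokum pattern is tested before the kerala one.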

def http_get(url: str, timeout: int = 20) -> str:
    if requests is None:
        raise RuntimeError("requests is not installed")
    headers = {
        "User-Agent": "Mozilla/5.0 (compatible; sambar-recipe-research/1.0; +https://openai.com/)"
    }
    r = requests.get(url, headers=headers, timeout=timeout)
    r.raise_for_status()
    return r.text


def fetch_anupam_dataset() -> List[Dict[str, str]]:
    """Fetch Indian recipe CSV from Hugging Face if available."""
    url = "https://huggingface.co/datasets/Anupam007/indian-recipe-dataset/resolve/main/Cleaned_Indian_Food_Dataset.csv"
    rows: List[Dict[str, str]] = []
    try:
        df = pd.read_csv(url)
    except Exception:
        return rows
    for _, r in df.iterrows():
        title = str(r.get("TranslatedRecipeName", ""))
        ingredients = normalize_text(r.get("TranslatedIngredients", "")) or normalize_text(r.get("Cleaned-Ingredients", ""))
        cuisine = str(r.get("Cuisine", ""))
        source_url = str(r.get("URL", ""))
        state = infer_state(f"{title} {ingredients} {cuisine}")
        if not state:
            continue
        if not re.search(r"sambar|sambhar|pappu charu|aamti|amti|huli|koddel|pulusu", f"{title} {ingredients}", re.I):
            # Keep regionally-labelled close relatives only when the title/ingredients suggest the dish family.
            continue
        rows.append({
            "state": state,
            "title": title,
            "url": source_url,
            "source": "huggingface:Anupam007/indian-recipe-dataset",
            "query": "dataset-filter",
            "ingredients_text": ingredients,
        })
    return rows


def fetch_recipenlg_search(query: str, fallback_state: str, max_rows: int = 100) -> List[Dict[str, str]]:
    """Search RecipeNLG through the HF Dataset Viewer API."""
    rows: List[Dict[str, str]] = []
    base = "https://datasets-server.huggingface.co/search"
    if requests is None:
        return rows
    params = {
        "dataset": "Mahimas/recipenlg",
        "config": "default",
        "split": "train",
        "query": query,
        "offset": 0,
        "length": min(max_rows, 100),
    }
    try:
        r = requests.get(base, params=params, timeout=30)
        r.raise_for_status()
        data = r.json()
    except Exception:
        return rows
    for item in data.get("rows", []):
        row = item.get("row", {})
        title = str(row.get("title", ""))
        ingredients = str(row.get("ingredients", "") or row.get("NER", ""))
        link = str(row.get("link", ""))
        state = infer_state(f"{title} {ingredients} {query}", fallback=fallback_state)
        if not state:
            continue
        rows.append({
            "state": state,
            "title": title,
            "url": link,
            "source": "huggingface:Mahimas/recipenlg",
            "query": query,
            "ingredients_text": ingredients,
        })
    return rows

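# The datasets-server /search endpoint pages its results; this helper asks for
# at most 100 rows per call. An equivalent raw request (illustrative) is:
#   GET https://datasets-server.huggingface.co/search?dataset=Mahimas/recipenlg&config=default&split=train&query=udupi%20sambar&offset=0&length=100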

def parse_jsonld_recipes(soup) -> List[Tuple[str, str]]:
    """Extract (title, ingredients) pairs from JSON-LD Recipe objects."""

    def is_recipe(obj: object) -> bool:
        # schema.org allows "@type" to be a string or a list of strings.
        t = obj.get("@type") if isinstance(obj, dict) else None
        return t == "Recipe" or (isinstance(t, list) and "Recipe" in t)

    out: List[Tuple[str, str]] = []
    for script in soup.find_all("script", type="application/ld+json"):
        try:
            data = json.loads(script.string or "")
        except Exception:
            continue
        objects = data if isinstance(data, list) else [data]
        for obj in objects:
            if is_recipe(obj):
                out.append((str(obj.get("name", "")), normalize_text(obj.get("recipeIngredient", ""))))
            elif isinstance(obj, dict) and "@graph" in obj:
                for g in obj.get("@graph", []):
                    if is_recipe(g):
                        out.append((str(g.get("name", "")), normalize_text(g.get("recipeIngredient", ""))))
    return out


def fetch_recipe_page(url: str, fallback_title: str = "") -> Tuple[str, str]:
    if BeautifulSoup is None:
        return fallback_title, ""
    try:
        text = http_get(url)
    except Exception:
        return fallback_title, ""
    soup = BeautifulSoup(text, "html.parser")
    jsonld = parse_jsonld_recipes(soup)
    if jsonld:
        title, ingred = jsonld[0]
        return title or fallback_title, ingred
    # Generic fallback: gather list items under ingredient-ish headings.
    title = (soup.find("h1").get_text(" ", strip=True) if soup.find("h1") else fallback_title)
    page_text = soup.get_text("\n", strip=True)
    m = re.search(r"ingredients\s*(.*?)(steps|method|instructions|directions|preparation)", page_text, flags=re.I | re.S)
    ingred = m.group(1)[:3000] if m else ""
    return title, ingred


def fetch_cookpad_search(query: str, fallback_state: str, max_links: int = 30, delay: float = 0.7) -> List[Dict[str, str]]:
    """Scrape Cookpad search result snippets and, where possible, detail pages."""
    if BeautifulSoup is None:
        return []
    url = f"https://cookpad.com/eng/search/{quote_plus(query)}"
    try:
        text = http_get(url)
    except Exception:
        return []
    soup = BeautifulSoup(text, "html.parser")
    rows = []
    seen_links = set()
    # Cookpad markup shifts often, so match recipe links broadly; if none are found, the search page itself still exposes ingredient snippets (see the fallback below).
    links = []
    for a in soup.find_all("a", href=True):
        href = a["href"]
        label = a.get_text(" ", strip=True)
        if "/eng/recipes/" in href and label:
            full = urljoin("https://cookpad.com", href)
            if full not in seen_links:
                links.append((full, label))
                seen_links.add(full)
    if not links:
        # Fallback: make a single pseudo-row from the search page snippets.
        title = f"Cookpad search: {query}"
        ingred = soup.get_text(" | ", strip=True)[:5000]
        rows.append({"state": fallback_state, "title": title, "url": url, "source": "cookpad-search", "query": query, "ingredients_text": ingred})
        return rows
    for full, label in links[:max_links]:
        time.sleep(delay)
        title, ingred = fetch_recipe_page(full, label)
        state = infer_state(f"{title} {ingred} {query}", fallback=fallback_state)
        if not ingred:
            continue
        rows.append({"state": state, "title": title, "url": full, "source": "cookpad", "query": query, "ingredients_text": ingred})
    return rows


def collect_recipes(max_recipes: int = 300, per_state: Optional[int] = None, use_cookpad: bool = True, use_hf: bool = True) -> pd.DataFrame:
    per_state = per_state or math.ceil(max_recipes / len(STATES))
    rows: List[Dict[str, str]] = []
    if use_hf:
        rows.extend(fetch_anupam_dataset())
    counts = Counter(r["state"] for r in rows)
    # Search RecipeNLG and Cookpad by state-specific recipe name.
    for state, queries in STATE_QUERIES.items():
        for q in queries:
            if counts[state] >= per_state:
                break
            if use_hf:
                new_rows = fetch_recipenlg_search(q, state, max_rows=100)
                rows.extend(new_rows)
                counts.update(r["state"] for r in new_rows)
            if counts[state] < per_state and use_cookpad:
                new_rows = fetch_cookpad_search(q, state, max_links=20)
                rows.extend(new_rows)
                counts.update(r["state"] for r in new_rows)
    # Deduplicate and balance.
    seen = set()
    final = []
    state_counts = Counter()
    for r in rows:
        state = r.get("state")
        if state not in STATES or state_counts[state] >= per_state:
            continue
        key = canonical_key(r.get("title", ""), r.get("ingredients_text", ""))
        if not key or key in seen:
            continue
        seen.add(key)
        final.append(r)
        state_counts[state] += 1
        if len(final) >= max_recipes:
            break
    df = pd.DataFrame(final)
    return add_features(df)


def add_features(df: pd.DataFrame) -> pd.DataFrame:
    if df.empty:
        return df
    # ingredients_text cannot be reconstructed from feature columns, so fall back to empty text; pre-computed feature columns are preserved below.
    if "ingredients_text" not in df.columns:
        df["ingredients_text"] = ""
    feats = df["ingredients_text"].fillna("").map(feature_row).apply(pd.Series).fillna(0).astype(int)
    # Preserve any already-existing feature values if the input CSV is already featurized.
    for col in INGREDIENT_PATTERNS:
        if col in df.columns:
            feats[col] = df[col].fillna(0).astype(int).clip(0, 1)
    base_cols = [c for c in ["state", "title", "url", "source", "query", "ingredients_text"] if c in df.columns]
    return pd.concat([df[base_cols].reset_index(drop=True), feats.reset_index(drop=True)], axis=1)


def train_fftree(df: pd.DataFrame, max_depth: int = 3, max_leaf_nodes: int = 5):
    feature_cols = [c for c in INGREDIENT_PATTERNS if c in df.columns]
    X = df[feature_cols].fillna(0).astype(int)
    y = df["state"].astype(str)
    min_leaf = max(2, len(df) // 100)
    clf = DecisionTreeClassifier(
        criterion="entropy",
        max_depth=max_depth,
        max_leaf_nodes=max_leaf_nodes,
        min_samples_leaf=min_leaf,
        random_state=42,
    )
    clf.fit(X, y)
    return clf, feature_cols

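# max_leaf_nodes=5 over binary 0/1 features caps the tree at four yes/no
# ingredient tests -- the fast-and-frugal shape, roughly one exit per state.
# Hedged usage sketch, assuming `df` is a featurized frame with a "state" column:
#   clf, cols = train_fftree(df)
#   from sklearn.tree import export_text  # sklearn's built-in pretty-printer
#   print(export_text(clf, feature_names=cols))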

def walk_predictions(clf, feature_cols: List[str], df: pd.DataFrame) -> pd.DataFrame:
    X = df[feature_cols].fillna(0).astype(int)
    pred = clf.predict(X)
    tree = clf.tree_
    classes = list(clf.classes_)
    paths = []
    node_ids = []
    for _, row in X.iterrows():
        node = 0
        parts = []
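        # Leaves have children_left == children_right (both -1 in sklearn);
        # for 0/1 features the learned threshold is 0.5, so "no" (0) goes left.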
        while tree.children_left[node] != tree.children_right[node]:
            fname = feature_cols[tree.feature[node]]
            val = int(row[fname])
            go_left = val <= tree.threshold[node]
            parts.append(f"{fname}={'no' if go_left else 'yes'}")
            node = tree.children_left[node] if go_left else tree.children_right[node]
        paths.append(" > ".join(parts))
        node_ids.append(node)
    out = df.copy()
    out["pred_fftree"] = pred
    out["correct_fftree"] = (out["state"].astype(str) == out["pred_fftree"].astype(str)).astype(int)
    out["fftree_leaf_id"] = node_ids
    out["fftree_path"] = paths
    return out

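# Each fftree_path is the conjunction of tests taken, e.g. a hypothetical
# "goda masala=no > coconut oil=yes"; rows sharing a path share a leaf and
# therefore a prediction.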

def node_activation_counts(clf, feature_cols: List[str], df: pd.DataFrame) -> pd.DataFrame:
    X = df[feature_cols].fillna(0).astype(int).reset_index(drop=True)
    tree = clf.tree_
    rows = []
    def rec(node: int, idxs: List[int], prefix: str):
        if tree.children_left[node] == tree.children_right[node]:
            return
        fname = feature_cols[tree.feature[node]]
        left = []
        right = []
        for i in idxs:
            if X.loc[i, fname] <= tree.threshold[node]:
                left.append(i)
            else:
                right.append(i)
        rows.append({"node_id": node, "path_to_node": prefix or "root", "test": f"{fname}?", "tested": len(idxs), "yes": len(right), "no": len(left)})
        rec(tree.children_left[node], left, (prefix + " > " if prefix else "") + f"{fname}=no")
        rec(tree.children_right[node], right, (prefix + " > " if prefix else "") + f"{fname}=yes")
    rec(0, list(range(len(X))), "")
    return pd.DataFrame(rows)

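# Hedged example of one output row (column names are real, values hypothetical):
#   {"node_id": 0, "path_to_node": "root", "test": "goda masala?",
#    "tested": 300, "yes": 58, "no": 242}   # yes + no == tested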

def branch_counts(pred_df: pd.DataFrame) -> pd.DataFrame:
    rows = []
    for path, sub in pred_df.groupby("fftree_path", dropna=False):
        pred = str(sub["pred_fftree"].mode().iat[0])
        dist = sub["state"].value_counts().sort_index().to_dict()
        rows.append({
            "path": path,
            "predicted_state": pred,
            "n": len(sub),
            "correct": int((sub["state"].astype(str) == sub["pred_fftree"].astype(str)).sum()),
            "accuracy": round(float((sub["state"].astype(str) == sub["pred_fftree"].astype(str)).mean()), 4),
            "actual_distribution": "; ".join(f"{k}:{v}" for k, v in dist.items()),
        })
    return pd.DataFrame(rows).sort_values(["n", "path"], ascending=[False, True])

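# "actual_distribution" serialises the true labels seen on a path, e.g. the
# hypothetical string "Kerala:41; Tamil Nadu:3" for a coconut-heavy branch.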

def tree_text(clf, feature_cols: List[str], df: pd.DataFrame) -> str:
    X = df[feature_cols].fillna(0).astype(int).reset_index(drop=True)
    y = df["state"].astype(str).reset_index(drop=True)
    tree = clf.tree_
    classes = list(clf.classes_)
    lines: List[str] = []

    def rec(node: int, idxs: List[int], indent: str):
        counts = Counter(y.iloc[idxs])
        pred_idx = int(np.argmax(tree.value[node][0]))
        pred = classes[pred_idx]
        if tree.children_left[node] == tree.children_right[node]:
            correct = counts.get(pred, 0)
            dist = "; ".join(f"{k}:{v}" for k, v in sorted(counts.items()))
            lines.append(f"{indent}→ {pred} [n={len(idxs)}, correct={correct}, actual={dist}]")
            return
        fname = feature_cols[tree.feature[node]]
        left = [i for i in idxs if X.loc[i, fname] <= tree.threshold[node]]
        right = [i for i in idxs if X.loc[i, fname] > tree.threshold[node]]
        lines.append(f"{indent}{fname}? [tested={len(idxs)}, yes={len(right)}, no={len(left)}]")
        lines.append(f"{indent}  yes:")
        rec(tree.children_right[node], right, indent + "    ")
        lines.append(f"{indent}  no:")
        rec(tree.children_left[node], left, indent + "    ")

    rec(0, list(range(len(df))), "")
    return "\n".join(lines)


def cv_accuracy(clf, feature_cols: List[str], df: pd.DataFrame) -> Optional[float]:
    if len(df) < 30 or df["state"].nunique() < 2:
        return None
    X = df[feature_cols].fillna(0).astype(int)
    y = df["state"].astype(str)
    min_class = y.value_counts().min()
    n_splits = min(5, int(min_class))
    if n_splits < 2:
        return None
    # Re-create the same hyperparameters on a fresh, unfitted estimator.
    new_clf = DecisionTreeClassifier(
        criterion="entropy",
        max_depth=clf.max_depth,
        max_leaf_nodes=clf.max_leaf_nodes,
        min_samples_leaf=clf.min_samples_leaf,
        random_state=42,
    )
    scores = cross_val_score(new_clf, X, y, cv=StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42))
    return float(np.mean(scores))

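# StratifiedKFold needs every class in every fold, so n_splits is capped by
# the rarest state's count: e.g. (illustrative) 7 Maharashtra rows keep 5-fold
# CV, 3 rows drop it to 3 folds, and a single row skips CV entirely.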

def write_report(path: Path, df: pd.DataFrame, pred_df: pd.DataFrame, clf, feature_cols: List[str], branches: pd.DataFrame, nodes: pd.DataFrame):
    accuracy = pred_df["correct_fftree"].mean()
    state_summary = pred_df.groupby("state").agg(recipes=("state", "size"), correct=("correct_fftree", "sum")).reset_index()
    state_summary["accuracy"] = state_summary["correct"] / state_summary["recipes"]
    cv = cv_accuracy(clf, feature_cols, df)
    with path.open("w", encoding="utf-8") as f:
        f.write("# Sambar fast-and-frugal decision tree\n\n")
        f.write(f"Recipes: **{len(df)}**  \n")
        f.write(f"Training accuracy: **{pred_df['correct_fftree'].sum()}/{len(pred_df)} = {accuracy:.1%}**  \n")
        if cv is not None:
            f.write(f"Stratified CV accuracy: **{cv:.1%}**  \n")
        f.write("\n## Decision tree with branch activation counts\n\n")
        f.write("```text\n")
        f.write(tree_text(clf, feature_cols, df))
        f.write("\n```\n\n")
        f.write("## Per-state summary\n\n")
        f.write(state_summary.to_markdown(index=False, floatfmt=".3f"))
        f.write("\n\n## Branch counts\n\n")
        f.write(branches.to_markdown(index=False))
        f.write("\n\n## Node activation counts\n\n")
        f.write(nodes.to_markdown(index=False))
        f.write("\n")


def main(argv: Optional[Sequence[str]] = None) -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", help="Existing recipe CSV to replay. If omitted, the script downloads recipes.")
    ap.add_argument("--max-recipes", type=int, default=300)
    ap.add_argument("--per-state", type=int, default=None)
    ap.add_argument("--output-dir", default="sambar_out")
    ap.add_argument("--max-depth", type=int, default=3)
    ap.add_argument("--max-leaf-nodes", type=int, default=5)
    ap.add_argument("--no-cookpad", action="store_true")
    ap.add_argument("--no-hf", action="store_true")
    args = ap.parse_args(argv)

    outdir = Path(args.output_dir)
    outdir.mkdir(parents=True, exist_ok=True)

    if args.input:
        df = pd.read_csv(args.input)
        df = add_features(df)
    else:
        df = collect_recipes(
            max_recipes=args.max_recipes,
            per_state=args.per_state,
            use_cookpad=not args.no_cookpad,
            use_hf=not args.no_hf,
        )
    if df.empty:
        raise SystemExit("No recipes found. Check internet access or pass --input existing_recipe_dataset.csv")

    # Deduplicate and keep balanced if downloaded sources produced too many rows.
    if "ingredients_text" in df.columns:
        keys = df.apply(lambda r: canonical_key(r.get("title", ""), r.get("ingredients_text", "")), axis=1)
        df = df.loc[~keys.duplicated()].copy()
    if len(df) > args.max_recipes:
        df = df.groupby("state", group_keys=False).head(math.ceil(args.max_recipes / df["state"].nunique())).head(args.max_recipes).copy()

    clf, feature_cols = train_fftree(df, max_depth=args.max_depth, max_leaf_nodes=args.max_leaf_nodes)
    pred_df = walk_predictions(clf, feature_cols, df)
    branches = branch_counts(pred_df)
    nodes = node_activation_counts(clf, feature_cols, df)

    df.to_csv(outdir / "sambar_recipe_dataset.csv", index=False)
    pred_df.to_csv(outdir / "sambar_fftree_predictions.csv", index=False)
    branches.to_csv(outdir / "sambar_fftree_branch_counts.csv", index=False)
    nodes.to_csv(outdir / "sambar_fftree_node_counts.csv", index=False)
    write_report(outdir / "sambar_fftree_report.md", df, pred_df, clf, feature_cols, branches, nodes)

    print(f"Wrote {outdir}")
    print(f"Recipes: {len(df)}")
    print(f"Accuracy: {pred_df['correct_fftree'].sum()}/{len(pred_df)} = {pred_df['correct_fftree'].mean():.1%}")
    print("Tree:\n" + tree_text(clf, feature_cols, df))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
