#!/usr/bin/env sage
# -*- coding: utf-8 -*-

"""
BBBdec.py  —  Bialynicki-Birula-Brosnan decomposition of motives of projective homogeneous varieties.

Given a semi-simple algebraic group G with anisotropic kernel G0 (indexed by Sigma0),
and a parabolic subgroup P_J, computes the motivic decomposition

    M(G/P_J)  =  direct sum_{w in ^I W^J}  M(G0 / P_{J_w})(l(w))

following Brosnan's theorem (Thm. 7.4), where:
  - ^I W^J  are the bi-minimal representatives of double cosets W_I \ W / W_J,
  - J_w = { i in I : w^{-1}(alpha_i) in R_J },
  - l(w) is the Tate twist (length of w).

Core algorithm: enumerate W^J, then filter for left-minimality w.r.t. I.
This avoids the expensive iterative bi-minimal reduction on all of W.

Modes:
- Full (default):        aggregate by (J_w, twist, abs_rank)
- Tate-trace (--tatetrace):  keep only pure Tate terms (J_w == I), aggregate by twist

Options:
- --use-root-test:  use Humphreys' root criterion for left-minimality
- --words:          show a sample reduced word per class in the output
"""

from sage.all import *
import argparse
import time

def _parse_set(s: str):
    s = (s or "").strip()
    if not s:
        return set()
    return set(int(x.strip()) for x in s.split(',') if x.strip())

def _disp_list(xs):
    xs = list(xs)
    if not xs:
        return "[]"
    return "[" + ", ".join(str(int(x)) for x in xs) + "]"


def _format_cartan_type(ct):
    def fmt_one(c):
        try:
            return f"{c.type()}{int(c.rank())}"
        except Exception:
            pass
        try:
            cc = c.cartan_type()
            return f"{cc.type()}{int(cc.rank())}"
        except Exception:
            pass
        return str(c)

    try:
        comps = list(ct.component_types()) if hasattr(ct, "component_types") else [ct]
    except Exception:
        comps = [ct]

    parts = [fmt_one(c) for c in comps]
    return "x".join(parts) if parts else "1"


def levi_type_of_subset_ordered(CT, S):
    """
    Cartan-type label (e.g. "A2xB3") of the Levi subsystem spanned by the
    simple roots indexed by S, with components ordered by smallest node.
    Primary route: connected components of the induced Dynkin sub-diagram;
    fallback: CT.subtype on the whole subset if the graph route fails.
    """
    selected = sorted({int(x) for x in S})
    if not selected:
        return "1"
    try:
        diagram = CT.dynkin_diagram()
        pieces = sorted(
            (sorted(comp) for comp in diagram.subgraph(selected).connected_components()),
            key=lambda nodes: nodes[0],
        )
        labels = [_format_cartan_type(CT.subtype(nodes)) for nodes in pieces]
        return "x".join(labels) if labels else "1"
    except Exception:
        sub = CT.subtype(selected)
        try:
            comps_ct = list(sub.component_types()) if hasattr(sub, "component_types") else [sub]
            return "x".join(_format_cartan_type(c) for c in comps_ct)
        except Exception:
            return _format_cartan_type(sub)


def anisotropic_kernel_type(CT, I):
    """Cartan-type label of the anisotropic kernel G0, i.e. the Levi type of Sigma0 = I."""
    return levi_type_of_subset_ordered(CT, I)

# Split case (trivial anisotropic kernel)

def split_poincare_tate_decomposition(dynkin, J):
    """
    Tate multiplicities for the split case: coefficients of the Poincaré
    polynomial of W^J, computed as P(W) / P(W_J) in ZZ[t] using the degrees
    of the Weyl groups.  Returns the coefficient list (trailing zeros trimmed);
    index = Tate twist, value = multiplicity.
    Raises RuntimeError if the quotient is not exact (should never happen).
    """
    ct = CartanType(dynkin)
    ring = PolynomialRing(ZZ, "t")
    t = ring.gen()

    def poincare(degrees):
        # P(W) = prod over degrees d of (1 + t + ... + t^{d-1}).
        poly = ring(1)
        for d in degrees:
            poly *= sum(t**k for k in range(int(d)))
        return ring(poly)

    numerator = poincare(WeylGroup(ct, prefix="s").degrees())

    subset = sorted({int(x) for x in J})
    denominator = ring(1)
    if subset:
        sub = ct.subtype(subset)
        components = list(sub.component_types()) if hasattr(sub, "component_types") else [sub]
        for comp in components:
            denominator *= poincare(WeylGroup(comp, prefix="s").degrees())

    quotient, remainder = numerator.quo_rem(denominator)
    if remainder != 0:
        raise RuntimeError("Poincaré quotient PW/PWJ not exact in ZZ[t].")

    out = [int(c) for c in quotient.list()]
    while out and out[-1] == 0:
        out.pop()
    return out


def _w_key(w):
    return tuple(int(i) for i in w.reduced_word())


def format_w_word(word):
    """Render a reduced word as "s1*s2*..."; the empty word is the identity "1"."""
    letters = [f"s{int(i)}" for i in word]
    return "*".join(letters) if letters else "1"

# Root combinatorics

def _coords_in_simple_basis(beta):
    """Integer coordinates of the root beta in the simple-root basis."""
    return [ZZ(c) for c in beta.to_vector()]


def in_RJ(beta, J):
    """True iff beta lies in R_J: every nonzero coordinate sits at an index in J (1-based)."""
    coords = _coords_in_simple_basis(beta)
    return all(c == 0 or idx in J for idx, c in enumerate(coords, start=1))


def compute_Jw_and_twist(w, I, J, simple_roots):
    """Return (J_w, l(w)) where J_w = {i in I : w^{-1}(alpha_i) in R_J} and l(w) is the Tate twist."""
    winv = ~w
    Jw = frozenset(i for i in I if in_RJ(winv.action(simple_roots[i]), J))
    return Jw, int(w.length())


def is_tate_trace_direct(w, I, J, simple_roots):
    """True iff w contributes a pure Tate summand: w^{-1}(alpha_i) in R_J for every i in I."""
    winv = ~w
    return all(in_RJ(winv.action(simple_roots[i]), J) for i in I)

# Shared memo cache: weyl_order_from_cartan_type stores under str(ct) keys,
# weyl_order_parabolic under (str(CT), tuple) keys — the key types never clash.
_parabolic_order_cache = {}


def weyl_order_from_cartan_type(ct):
    """Order of the Weyl group of Cartan type ct, memoized under str(ct)."""
    key = str(ct)
    if key not in _parabolic_order_cache:
        lattice = RootSystem(ct).root_lattice()
        _parabolic_order_cache[key] = int(lattice.weyl_group(prefix="s").order())
    return _parabolic_order_cache[key]


def weyl_order_parabolic(CT, S):
    """Order of the parabolic subgroup W_S of W(CT); |W_{}| = 1 for empty S.  Memoized."""
    subset = tuple(sorted({int(x) for x in S}))
    key = (str(CT), subset)
    if key not in _parabolic_order_cache:
        if subset:
            order = int(weyl_order_from_cartan_type(CT.subtype(list(subset))))
        else:
            order = 1
        _parabolic_order_cache[key] = order
    return _parabolic_order_cache[key]

# Minimal coset reps

def reduce_right_to_WJ(w, J_reflections):
    """
    Right-reduce w to the minimal-length representative of the coset w * W_J.
    J_reflections: dict {j: s_j} of precomputed simple reflections for j in J.
    Repeats full passes over the reflections until none shortens w.
    """
    shortened = True
    while shortened:
        shortened = False
        for refl in J_reflections.values():
            candidate = w * refl
            if candidate.length() < w.length():
                w = candidate
                shortened = True
    return w


def enumerate_WJ_elements(W, rank, J, progress=False, every=5000, max_reps=None, expected_size=None):
    """
    Enumerate the minimal right coset representatives W^J by graph search.

    Starting from the (reduced) identity, repeatedly left-multiply by every
    simple reflection and right-reduce modulo W_J; representatives are
    deduplicated by their reduced word.

    W:              the Weyl group.
    rank:           number of simple reflections (indices 1..rank).
    J:              subset of {1..rank} defining the parabolic W_J.
    progress/every: optional progress printing every `every` new elements;
                    expected_size additionally enables percentage and ETA.
    max_reps:       raise RuntimeError if |W^J| grows past this bound.

    Returns the list of representatives (insertion order, not significant).
    """
    gens = [W.simple_reflection(i) for i in range(1, rank + 1)]
    # Precompute J-reflections once for all calls to reduce_right_to_WJ
    J_reflections = {j: W.simple_reflection(j) for j in J}
    start = reduce_right_to_WJ(W.one(), J_reflections)
    q = [start]
    # Map: reduced word (tuple of ints) -> representative element.
    seen = {_w_key(start): start}

    t0 = time.time()
    while q:
        w = q.pop()  # LIFO pop: depth-first traversal of the coset graph
        for s in gens:
            w2 = reduce_right_to_WJ(s * w, J_reflections)
            k = _w_key(w2)
            if k not in seen:
                seen[k] = w2
                q.append(w2)
                if max_reps is not None and len(seen) > max_reps:
                    raise RuntimeError(f"W^J exceeded max_reps={max_reps} (currently {len(seen)})." )
                if progress and len(seen) % every == 0:
                    dt = time.time() - t0
                    current = len(seen)
                    if expected_size:
                        pct = 100.0 * current / expected_size
                        speed = current / dt if dt > 0 else 0
                        eta = (expected_size - current) / speed if speed > 0 else 0
                        print(f"  [W^J] size={current:10d}/{expected_size:10d} ({pct:6.2f}%) | t={dt:7.2f}s | eta~{eta:7.1f}s")
                    else:
                        print(f"  [W^J] size={current:10d} | t={dt:7.2f}s")

    return list(seen.values())

def is_left_minimal_by_lengths(w, I_reflections):
    """
    Length criterion for left-minimality w.r.t. I: w has no left descent in I,
    i.e. l(s_i * w) > l(w) for every i in I.
    I_reflections: dict {i: s_i} of precomputed simple reflections for i in I.
    """
    base = w.length()
    return all((si * w).length() >= base for si in I_reflections.values())


def _is_negative_root(beta):
    """True iff beta is a negative root: every coordinate <= 0 and at least one < 0."""
    coords = [ZZ(c) for c in beta.to_vector()]
    if any(c > 0 for c in coords):
        return False
    return any(c < 0 for c in coords)


def is_left_minimal_by_roots(w, I, simple_roots):
    """
    Humphreys' root criterion for left-minimality w.r.t. I:
    w^{-1}(alpha_i) must be a positive root for every i in I.
    """
    winv = ~w
    return not any(_is_negative_root(winv.action(simple_roots[i])) for i in I)


# Main

def BBBdec(
    dynkin, Sigma0, J,
    progress=False, every=2000, max_reps=None,
    tate_trace=False,
    use_root_test=False,
    show_words=False,
):
    """
    Motivic decomposition of M(G/P_J) following Brosnan (Thm. 7.4).

    dynkin:        Dynkin type of G (anything accepted by CartanType).
    Sigma0:        simple-root indices I of the anisotropic kernel G0.
    J:             simple-root indices defining the parabolic P_J.
    progress:      print progress during enumeration and filtering.
    every:         progress-report interval (elements).
    max_reps:      abort with RuntimeError if |W^J| exceeds this bound.
    tate_trace:    keep only pure Tate terms (J_w == I), aggregated by twist.
    use_root_test: use the root-sign criterion for left-minimality instead of
                   length comparisons (same result, different computation).
    show_words:    attach a sample reduced word to each output row.

    Returns (rows, stats): a list of aggregated term dicts (keys depend on the
    mode: split / tate-trace / full) and a dict of run statistics.
    """
    Sigma0 = set(int(x) for x in Sigma0)
    J = set(int(x) for x in J)

    CT = CartanType(dynkin)
    r = int(CT.rank())
    I = Sigma0

    R = RootSystem(dynkin).root_lattice()
    W = R.weyl_group(prefix="s")
    simple = R.simple_roots()

    total_W = int(W.order())
    WJ_order = int(weyl_order_parabolic(CT, sorted(J)))
    WI_order = int(weyl_order_parabolic(CT, sorted(I)))
    # |W^J| = |W| / |W_J|: total Tate rank of M(G/P_J), used as a sanity target.
    target_tate = total_W // WJ_order

    t0 = time.time()

    # Split case (I empty): the decomposition is pure Tate, so read the
    # multiplicities off the Poincaré-polynomial quotient instead of
    # enumerating coset representatives.

    if not I:
        coeffs = split_poincare_tate_decomposition(dynkin, J)
        dim = len(coeffs) - 1 if coeffs else 0

        rows = []
        for tw in range(0, dim + 1):
            mlt = int(coeffs[tw]) if tw < len(coeffs) else 0
            if mlt == 0:
                continue
            rows.append({
                "mult": mlt,
                "twist": int(tw),
                "abs_rank": 1,
                "contrib": mlt,
                "sample_w": "poincare_quotient",
            })

        elapsed = time.time() - t0
        stats = {
            "dynkin": list(CT),
            "Sigma0": sorted(I),
            "J": sorted(J),
            "total_W": total_W,
            "WI_order": WI_order,
            "WJ_order": WJ_order,
            "WJ_reps": None,
            "bimin_reps": None,
            "tate_twists": len(rows),
            "tate_mult_total": int(sum(coeffs)),
            "covered_tate": int(sum(coeffs)),
            "target_tate": target_tate,
            "elapsed_sec": elapsed,
            "tate_trace": False,
            "use_root_test": bool(use_root_test),
            "dimension": int(dim),
            "split_method": "poincare_quotient",
        }
        return rows, stats

    # General case: enumerate W^J, then keep the left-I-minimal (bi-minimal) reps.
    reps_WJ = enumerate_WJ_elements(W, r, J, progress=progress, every=every, max_reps=max_reps, expected_size=target_tate)

    I_reflections = {i: W.simple_reflection(i) for i in I}

    bimin = []
    for idx, w in enumerate(reps_WJ, start=1):
        ok = is_left_minimal_by_roots(w, I, simple) if use_root_test else is_left_minimal_by_lengths(w, I_reflections)
        if ok:
            bimin.append(w)
        if progress and idx % every == 0:
            dt = time.time() - t0
            print(f"  [filter] processed={idx:10d}/{len(reps_WJ)} | bimin={len(bimin):8d} | t={dt:7.2f}s")

    # Tate-trace direct mode: count only the pure Tate summands, by twist.

    if tate_trace:
        tate_by_twist = {}
        tate_samples = {}
        tate_mult_total = 0

        for w in bimin:
            if not is_tate_trace_direct(w, I, J, simple):
                continue
            tw = int(w.length())
            tate_by_twist[tw] = tate_by_twist.get(tw, 0) + 1
            tate_mult_total += 1
            if show_words and tw not in tate_samples:
                tate_samples[tw] = w.reduced_word()

        rows = []
        for tw in sorted(tate_by_twist.keys()):
            mult = int(tate_by_twist[tw])
            row = {
                "mult": mult,
                "twist": int(tw),
                "abs_rank": 1,
                "contrib": mult,
            }
            if show_words:
                row["sample_w"] = format_w_word(tate_samples[tw])
            rows.append(row)

        elapsed = time.time() - t0
        stats = {
            "dynkin": list(CT),
            "Sigma0": sorted(I),
            "J": sorted(J),
            "total_W": total_W,
            "WI_order": WI_order,
            "WJ_order": WJ_order,
            "WJ_reps": len(reps_WJ),
            "bimin_reps": len(bimin),
            "tate_twists": len(rows),
            "tate_mult_total": int(tate_mult_total),
            "covered_tate": sum(r["contrib"] for r in rows),
            "target_tate": target_tate,
            "elapsed_sec": elapsed,
            "tate_trace": True,
            "use_root_test": bool(use_root_test),
        }
        return rows, stats

    else:
        # Full mode: aggregate summands by (J_w, twist, absolute rank).
        agg = {}
        covered = 0
        # Compute once: does not depend on w
        levi_type = anisotropic_kernel_type(CT, I)
        for b in bimin:
            Jw, tw = compute_Jw_and_twist(b, I, J, simple)
            # abs_rank = |W_I| / |W_{J_w}|: Tate rank of the summand M(G0/P_{J_w}).
            abs_rank = WI_order // int(weyl_order_parabolic(CT, sorted(Jw)))
            k = (Jw, int(tw), int(abs_rank))
            if k not in agg:
                agg[k] = {"mult": 0}
                if show_words:
                    agg[k]["sample_word"] = b.reduced_word()
            agg[k]["mult"] += 1
            covered += abs_rank

        rows = []
        for (Jw, tw, abs_rank), info in agg.items():
            mult = int(info["mult"])
            row = {
                "mult": mult,
                "Levi_type": levi_type,
                "Jwcomp": sorted(set(I) - set(Jw)),
                "twist": int(tw),
                "abs_rank": int(abs_rank),
                "contrib": int(mult * abs_rank),
            }
            if show_words:
                row["sample_w"] = format_w_word(info["sample_word"])
            rows.append(row)
        rows.sort(key=lambda rr: (rr["twist"], rr.get("Jwcomp", []), rr["abs_rank"]))

        elapsed = time.time() - t0
        stats = {
            "dynkin": list(CT),
            "Sigma0": sorted(I),
            "J": sorted(J),
            "total_W": total_W,
            "WI_order": WI_order,
            "WJ_order": WJ_order,
            "WJ_reps": len(reps_WJ),
            "bimin_reps": len(bimin),
            "aggregated_classes": len(rows),
            "covered_tate": covered,
            "target_tate": target_tate,
            "elapsed_sec": elapsed,
            "tate_trace": False,
            "use_root_test": bool(use_root_test),
        }
        return rows, stats

def main():
    """Parse CLI options, run BBBdec, and pretty-print rows + stats per mode."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--dynkin", required=True)
    ap.add_argument("--Sigma0", default="")
    ap.add_argument("--J", default="")
    # The *comp options give the COMPLEMENT of the set within {1..rank};
    # when non-blank they take precedence over --Sigma0 / --J below.
    ap.add_argument("--Sigma0comp", default="")
    ap.add_argument("--Jcomp", default="")
    ap.add_argument("--progress", action="store_true")
    ap.add_argument("--every", type=int, default=2000)
    ap.add_argument("--max_reps", type=int, default=None)

    ap.add_argument("--tatetrace", action="store_true",
                    help="Keep only pure Tate terms (J_w == I), aggregate by twist.")

    ap.add_argument("--use-root-test", action="store_true",
                    help="Use root-sign criterion for left-minimality (instead of length checks)." )
    ap.add_argument("--words", action="store_true",
                    help="Show a sample reduced word for each class in the output (off by default).")

    args = ap.parse_args()

    CT = CartanType(args.dynkin.strip())
    n = int(CT.rank())
    full = set(range(1, n + 1))

    Sigma0 = full - _parse_set(args.Sigma0comp) if args.Sigma0comp.strip() else _parse_set(args.Sigma0)
    J = full - _parse_set(args.Jcomp) if args.Jcomp.strip() else _parse_set(args.J)

    print("\n=== BBBdec — Bialynicki-Birula-Brosnan decomposition (Brosnan Thm. 7.4) ===")
    print(f"dynkin={list(CT)}, Sigma0={sorted(Sigma0)}, J={sorted(J)}")
    print(f"tatetrace={bool(args.tatetrace)} | use_root_test={bool(args.use_root_test)} | words={bool(args.words)}")

    rows, stats = BBBdec(
        dynkin=args.dynkin.strip(),
        Sigma0=Sigma0,
        J=J,
        progress=args.progress,
        every=args.every,
        max_reps=args.max_reps,
        tate_trace=bool(args.tatetrace),
        use_root_test=bool(args.use_root_test),
        show_words=bool(args.words),
    )

    # Three output formats, mirroring BBBdec's three computation modes.
    if stats.get("split_method", "") == "poincare_quotient":
        print("\n=== DECOMPOSITION (SPLIT: POINCARE QUOTIENT) ===")
        dim = int(stats.get("dimension", 0))
        mult_by_twist = {int(r["twist"]): int(r["mult"]) for r in rows}
        for i in range(0, dim + 1):
            n_i = int(mult_by_twist.get(i, 0))
            # NOTE(review): "Z{i} + n" reads oddly for a multiplicity line —
            # confirm the intended notation (e.g. "n x Z{i}") before changing it.
            print(f"  Z{{{i}}} + {n_i}")
        print("\n=== STATS ===")
        for k in [
            "dynkin","Sigma0","J","total_W","WJ_order",
            "tate_mult_total","target_tate","dimension","elapsed_sec","split_method"
        ]:
            if k in stats:
                print(f"{k}: {stats[k]}")

    elif stats.get("tate_trace", False):
        print("\n=== TATE TRACE ===")
        for r in rows:
            line = f"  mult={r['mult']} | twist={r['twist']} | contrib={r['contrib']}"
            if "sample_w" in r:
                line += f" | sample_w={r['sample_w']}"
            print(line)
        print("\n=== STATS ===")
        for k in [
            "dynkin","Sigma0","J","total_W","WI_order","WJ_order","WJ_reps","bimin_reps",
            "tate_twists","tate_mult_total","covered_tate","target_tate","elapsed_sec",
            "tate_trace","use_root_test"
        ]:
            if k in stats:
                print(f"{k}: {stats[k]}")

    else:
        print("\n=== AGGREGATED TERMS ===")
        for r in rows:
            line = (
                f"  mult={r['mult']} | Levi_type={r['Levi_type']} | Jwcomp={_disp_list(r['Jwcomp'])} | twist={r['twist']} | "
                f"abs_rank={r['abs_rank']} | contrib={r['contrib']}"
            )
            if "sample_w" in r:
                line += f" | sample_w={r['sample_w']}"
            print(line)
        print("\n=== STATS ===")
        for k in [
            "dynkin","Sigma0","J","total_W","WI_order","WJ_order","WJ_reps","bimin_reps",
            "aggregated_classes","covered_tate","target_tate","elapsed_sec",
            "tate_trace","use_root_test"
        ]:
            if k in stats:
                print(f"{k}: {stats[k]}")
        # covered_tate is an internal consistency check against |W|/|W_J|.
        if stats["covered_tate"] != stats["target_tate"]:
            print("\nWARNING: covered_tate != target_tate. This indicates either a convention mismatch in target_tate or a bug.")


# Script entry point (intended to run under Sage, per the shebang).
if __name__ == "__main__":
    main()