#!/usr/bin/env python3
# z_c_RescaleTI_Statistical_to_Spatial.py
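# Builds minimal-combination synthetic input maps (landcover, impervious cover, tree cover,
# AH flux, TI, DEM, ...) from spatial inputs. Takes three positional arguments; the AH flux
# selection flag (Flag_AH_Flux_Total_default) is hard-coded below.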
#
# Usage:
#   python z_c_RescaleTI_Statistical_to_Spatial.py <path1> <path2> <new_project_dir>
#
#   <path1> : folder with HydroPlusConfig.xml and TI CSV (z_TI_ExponentialDecay.csv or z_TI_PowerDecay.csv)
#   <path2> : folder with lc_*.asc, ic_*.asc, tc_*.asc (or landcover.asc, imperviouscover.asc, treecover.asc),
#             plus optional AH flux grids ah_flux_qtot_avg_wpm2_*.asc, ah_flux_qcr_avg_wpm2_*.asc, ah_flux_qncr_avg_wpm2_*.asc
#   <new_project_dir> : output folder (Inputs/, Outputs/ will be created)
#
# Notes:
# - If both qcr and qncr exist, they are combined as q_cr_and_ncr = qcr + qncr.
# - If both qtot and q_cr_and_ncr exist: Flag_AH_Flux_Total_default == 1 → use qtot, 0 → use combined.
#
# -----------------------------------------------------------------------------

import os
import sys
import glob
import math
import shutil
import numpy as np
import pandas as pd
from pathlib import Path
from decimal import Decimal, ROUND_HALF_UP
from lxml import etree as ET

# -------------------------------
# User-tunable constants
# -------------------------------
# Which anthropogenic-heat (AH) flux map main() uses when both a qtot grid and a
# qcr + qncr pair are found in <path2>:
#   1 -> use qtot;   0 -> use the combined qcr + qncr grid.
Flag_AH_Flux_Total_default = 0

# Synthetic grid sizing (passed to minimal_grid_from_combos_with_balanced_ti):
#   MIN_PIXELS   - lower bound on the number of filled pixels in the synthetic grid.
#   TI_PER_COMBO - pixels written per (LC, IC, TC, AH) combination before cycling.
MIN_PIXELS = 100
TI_PER_COMBO = 2

# Maximum number of IC/TC/AH values sampled per land-cover class. A single value
# is used instead when the observed span for that class is too narrow
# (see MIN_SPAN_FOR_TWO_STEPS for IC/TC).
RANGE_STEPS_IC_MAX_DEV = 3   # LC 21–24 (Developed)
RANGE_STEPS_TC_MAX_DEV = 3
RANGE_STEPS_IC_MAX_NON = 2   # all other LCs
RANGE_STEPS_TC_MAX_NON = 2
RANGE_STEPS_AH_MAX     = 2   # AH flux, all LCs
# Rounding & range detection (IC/TC only; AH values are not rounded)
ROUND_TO = 1.0                 # snap IC/TC to nearest %, set 5.0 to snap to 5%
MIN_SPAN_FOR_TWO_STEPS = 0.5   # IC/TC span after rounding; below this, collapse to 1 step

# Constant values written to every valid pixel of the corresponding output grid.
dem_m = 120                    # dem.asc (decimals chosen by fmt_from_dem_decimal)
z_FDorganizer_deg = 180        # z_FDorganizer.asc
z_AspectGround_N_0_rad = 0     # z_AspectGround_N_0_rad.asc
z_SlopeGround_rad = 0.001      # z_SlopeGround_rad.asc
# NOTE(review): the three AH_flux_*_default values below are not referenced by
# the code in this file.
AH_flux_Qtot_avg_Wpm2_default = 0
AH_flux_Qcr_avg_Wpm2_default = 0
AH_flux_Qncr_avg_Wpm2_default = 0
# blockgroup.asc: flag 1 -> number valid pixels consecutively starting at
# blockgroup_starting_value; any other flag -> all valid pixels get that value.
blockgroup_flag = 1
blockgroup_starting_value = 1

# Header keys, in the order write_ascii emits them.
ASCII_HEADER_KEYS = ["ncols", "nrows", "xllcorner", "yllcorner", "cellsize", "NODATA_value"]

# -------------------------------
# I/O helpers for ESRI ASCII grid
# -------------------------------
def read_ascii_header(filepath):
    """Read the six-line ESRI ASCII grid header of ``filepath`` into a dict.

    Each header line must be ``<key> <value>``. Values that parse as integers
    are returned as ``int``; anything else (``0.5``, ``1e3``, ``-9999.0``) is
    returned as ``float``. The standard keys are matched case-insensitively
    and stored under their canonical spelling (``ncols``, ``nrows``,
    ``xllcorner``, ``yllcorner``, ``cellsize``, ``NODATA_value``), so files
    written as ``NCOLS`` or ``nodata_value`` work with callers that index
    ``header["nrows"]`` etc. Unknown keys are kept as written.

    Raises:
        ValueError: if the file ends before six header lines were read, a
            line does not have exactly two fields, or a value is not numeric.
    """
    canonical = {
        name.lower(): name
        for name in ("ncols", "nrows", "xllcorner", "yllcorner",
                     "xllcenter", "yllcenter", "cellsize", "NODATA_value")
    }

    def _to_number(text):
        # int first so integer-valued fields (ncols, nrows, ...) stay ints;
        # float() also accepts scientific notation such as "1e3".
        try:
            return int(text)
        except ValueError:
            return float(text)

    header = {}
    with open(filepath, 'r') as f:
        for lineno in range(1, 7):
            line = f.readline()
            if not line:
                raise ValueError(f"Unexpected end of header in {filepath}")
            parts = line.split()
            if len(parts) != 2:
                raise ValueError(
                    f"Malformed header line {lineno} in {filepath}: {line.strip()!r}"
                )
            key, value = parts
            header[canonical.get(key.lower(), key)] = _to_number(value)
    return header

def read_ascii_grid(filepath):
    """Load an ESRI ASCII grid and return ``(header, data)``.

    The header is parsed by read_ascii_header; the numeric body is read with
    np.loadtxt after skipping the six header lines and is reshaped to
    ``(nrows, ncols)`` whenever loadtxt returns a different shape (for
    instance a 1-D array for a single-row or single-column grid).
    """
    hdr = read_ascii_header(filepath)
    shape = (int(hdr["nrows"]), int(hdr["ncols"]))
    values = np.loadtxt(filepath, skiprows=6)
    if values.shape == shape:
        return hdr, values
    return hdr, values.reshape(shape)

def write_ascii(filename, header, array, fmt='%.5f'):
    """Write ``array`` to ``filename`` as an ESRI ASCII grid.

    Header lines are emitted in ASCII_HEADER_KEYS order, taking each value
    from ``header`` (int values pass through int(), others are written
    as-is). The grid body follows, written row by row by np.savetxt using
    ``fmt``.
    """
    def header_line(name):
        value = header[name]
        shown = int(value) if isinstance(value, int) else value
        return f"{name} {shown}\n"

    with open(filename, 'w') as out:
        for name in ASCII_HEADER_KEYS:
            out.write(header_line(name))
        np.savetxt(out, array, fmt=fmt)

# -----------------------------------------------
# Helper Functions
# -----------------------------------------------
def _ensure_array(x, fallback=0.0):
    """Return x as a float numpy array; if None or empty, return [fallback]."""
    if x is None:
        return np.array([fallback], dtype=float)
    x = np.asarray(x, dtype=float)
    return x if x.size > 0 else np.array([fallback], dtype=float)

def _copy_comments_between_trees(src_tree, dst_tree):
    """Re-attach XML comments from ``src_tree`` to matching elements of ``dst_tree``.

    For each element of ``src_tree`` that has a parent, the runs of comment
    siblings immediately before and immediately after it are recorded as
    stripped, non-empty text. Elements of ``dst_tree`` are matched to those
    records by ``key``. Any recorded comment text that is not already adjacent
    to the destination element is inserted next to it, keeping the recorded
    order. ``dst_tree`` is modified in place; nothing is returned.
    """
    def key(elem, tree):
        # Match key: tag plus the first of the "id"/"name" attributes that is
        # present; otherwise the element's path as reported by tree.getpath().
        # NOTE(review): elements sharing a tag and id/name value collapse to
        # one key, so their comments cannot be told apart.
        for attr in ("id", "name"):
            if elem.get(attr) is not None:
                return f"{elem.tag}|{attr}={elem.get(attr)}"
        return tree.getpath(elem)

    def collect(parent, child):
        # Return (before, after): stripped texts of the consecutive comment
        # nodes directly preceding / following ``child`` among parent's
        # children. ``before`` is returned in document order. Empty comments
        # contribute nothing but do not end the run.
        kids = list(parent)
        before, after = [], []
        try:
            i = kids.index(child)
        except ValueError:
            # child is not among parent's children: nothing to collect.
            return before, after
        j = i - 1
        while j >= 0 and isinstance(kids[j], ET._Comment):
            t = (kids[j].text or "").strip()
            if t: before.append(t)
            j -= 1
        before.reverse()  # walked backwards; restore document order
        j = i + 1
        while j < len(kids) and isinstance(kids[j], ET._Comment):
            t = (kids[j].text or "").strip()
            if t: after.append(t)
            j += 1
        return before, after

    def index_src(src_tree):
        # Map key -> {"before": [...], "after": [...]} for every element of
        # src_tree that has at least one adjacent non-empty comment. Nodes whose
        # .tag is not a str (e.g. comments) and the root (no parent) are
        # skipped. A later element with the same key overwrites an earlier one.
        out = {}
        for e in src_tree.iter():
            if not isinstance(e.tag, str):
                continue
            p = e.getparent()
            if p is None:
                continue
            b, a = collect(p, e)
            if b or a:
                out[key(e, src_tree)] = {"before": b, "after": a}
        return out

    def ensure(parent, child, befores, afters):
        # Insert every text in ``befores`` that is not already in the comment
        # run before ``child`` directly ahead of ``child`` (in order). Then do
        # the same after ``child`` for ``afters``. Existing comments are
        # compared by stripped text.
        kids = list(parent)
        eb, ea = collect(parent, child)
        pos = kids.index(child)
        for t in befores:
            if t not in eb:
                parent.insert(pos, ET.Comment(t))
                pos += 1
                eb.append(t)
        kids = list(parent)  # re-read: the inserts above shifted child's index
        pos = kids.index(child) + 1
        for t in afters:
            if t not in ea:
                parent.insert(pos, ET.Comment(t))
                pos += 1
                ea.append(t)

    idx = index_src(src_tree)
    # NOTE(review): dst_tree is mutated (comments inserted) while dst_tree.iter()
    # is being consumed. Inserted nodes are comments and are skipped by the tag
    # check below, but confirm this live-iteration pattern is safe in lxml.
    for e in dst_tree.iter():
        if not isinstance(e.tag, str):
            continue
        p = e.getparent()
        if p is None:
            continue
        k = key(e, dst_tree)
        if k in idx:
            ensure(p, e, idx[k]["before"], idx[k]["after"])

def adjust_config_xml(config_path, new_output_dir, model_selection, parser=None):
    """Parse HydroPlusConfig.xml and patch the settings needed for the new project.

    Changes made to the parsed tree (the file on disk is not touched):
      * ``OutputFolder_Path``  -> ``new_output_dir`` with exactly one trailing
        path separator;
      * ``Model_Selection``    -> ``model_selection``;
      * ``flag_Recompute_TopographicIndex`` and ``Flag_DEM_RemovePitsFlats``
        -> ``"0"``.

    Args:
        config_path: path to HydroPlusConfig.xml.
        new_output_dir: folder the model should write outputs to.
        model_selection: text written into ``Model_Selection``.
        parser: optional XML parser forwarded to ``ET.parse`` (None uses the
            library default).

    Returns:
        (tree, catchment_area): the modified tree and ``CatchmentArea_m2`` as
        a float.

    Raises:
        ValueError: if ``CatchmentArea_m2`` is missing, empty, or not numeric.
    """
    tree = ET.parse(config_path, parser)
    root = tree.getroot()
    # os.path.join(..., "") appends a separator only if one is not already
    # there, avoiding "dir//" when new_output_dir ends with a separator.
    output_dir_text = os.path.join(str(new_output_dir), "")
    forced_zero_tags = {"flag_Recompute_TopographicIndex", "Flag_DEM_RemovePitsFlats"}

    catchment_area = None
    for elem in root.iter():
        tag = elem.tag
        if tag == "CatchmentArea_m2":
            text = (elem.text or "").strip()
            try:
                catchment_area = float(text)
            except ValueError:
                raise ValueError(
                    f"CatchmentArea_m2 in {config_path} is not a number: {text!r}"
                ) from None
        elif tag == "OutputFolder_Path":
            elem.text = output_dir_text
        elif tag == "Model_Selection":
            elem.text = model_selection
        elif tag in forced_zero_tags:
            elem.text = "0"
    if catchment_area is None:
        raise ValueError("CatchmentArea_m2 not found in HydroPlusConfig.xml")
    return tree, catchment_area

def get_ti_bins(csv_path, tc_ic_value_count_per_row, nlcd_class_list, tc_ic_increment_per_row):
    """Expand a TI histogram CSV into a sorted array sized for a per-row layout.

    The total pixel count is ``sum(N**2)`` over ``tc_ic_value_count_per_row``.
    The histogram is then expanded by get_ti_flat_from_histogram, so this
    function shares its rounding, validation and ascending sort.

    Args:
        csv_path: TI histogram CSV (two header rows, then area-fraction, TI).
        tc_ic_value_count_per_row: per-row value counts N; each row contributes
            N**2 pixels.
        nlcd_class_list: unused; kept for signature compatibility.
        tc_ic_increment_per_row: unused; kept for signature compatibility.

    Returns:
        1-D float array of TI values, sorted ascending, of length sum(N**2).
    """
    per_row_n = np.asarray(tc_ic_value_count_per_row, dtype=int)
    total_px = int((per_row_n ** 2).sum())
    return get_ti_flat_from_histogram(csv_path, total_px)

def get_ti_flat_from_histogram(csv_path, total_pixels):
    """Expand a TI histogram into ``total_pixels`` TI samples, sorted ascending.

    The CSV has two header rows, followed by rows of ``area_fraction, ti_value``.
    Each bin gets ``round(area_fraction / sum * total_pixels)`` copies of its TI
    value. The rounding remainder is added to the bin with the largest area
    fraction, or removed from the currently largest bins, so exactly
    ``total_pixels`` values are returned.

    Rows with a missing fraction or TI value (e.g. a trailing "," line) are
    ignored.

    Raises:
        ValueError: if ``total_pixels`` is negative, any area fraction is
            negative, or the fractions do not sum to a positive number.
    """
    total_pixels = int(total_pixels)
    if total_pixels < 0:
        raise ValueError("total_pixels must be non-negative.")

    df = pd.read_csv(csv_path, skiprows=2, header=None)
    # A NaN fraction would make the sum NaN, which slips past the `<= 0` test
    # and turns every count into garbage, so incomplete rows are dropped.
    hist = df[[0, 1]].astype(float).dropna()
    area_frac = hist[0].to_numpy()
    ti_vals = hist[1].to_numpy()

    if np.any(area_frac < 0):
        raise ValueError("TI histogram area fractions must be non-negative.")
    if area_frac.sum() <= 0:
        raise ValueError("TI histogram area fractions must sum to a positive number.")

    scale = total_pixels / area_frac.sum()
    counts = np.round(area_frac * scale).astype(int)
    deficit = total_pixels - int(counts.sum())
    if deficit > 0:
        # Rounding lost pixels: give them all to the dominant bin.
        counts[np.argmax(area_frac)] += deficit
    while counts.sum() > total_pixels:
        # Rounding over-allocated: trim the currently largest bin one at a time.
        counts[np.argmax(counts)] -= 1

    flat = np.repeat(ti_vals, counts)
    flat.sort()
    return flat

def fmt_from_dem_decimal(dem_m):
    """Choose a printf format for a DEM value rounded half-up to 0.01.

    Returns ``(fmt, decimals, value)``:
      * ``('%d', 0, int)``     when the rounded value is a whole number,
      * ``('%.1f', 1, float)`` when it has a single significant decimal,
      * ``('%.2f', 2, float)`` otherwise.
    """
    rounded = Decimal(str(dem_m)).quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)
    candidates = (
        ('%d', 0, Decimal('1'), int),
        ('%.1f', 1, Decimal('0.1'), float),
    )
    for fmt, decimals, step, cast in candidates:
        if rounded == rounded.quantize(step):
            return fmt, decimals, cast(rounded)
    return '%.2f', 2, float(rounded)

# -------------------------------
# Utility
# -------------------------------
def first_match(patterns, root):
    """Return the alphabetically first file under ``root`` matching the earliest
    pattern in ``patterns`` that has any match, or None if none match."""
    matches_per_pattern = (
        sorted(glob.glob(os.path.join(root, pattern))) for pattern in patterns
    )
    return next((hits[0] for hits in matches_per_pattern if hits), None)

def combine_qcr_qncr(qcr, qncr, nodata):
    """Add the qcr and qncr AH-flux grids cell by cell.

    A cell that is ``nodata`` in either input grid is ``nodata`` in the result.
    """
    total = qcr + qncr
    either_missing = np.logical_or(qcr == nodata, qncr == nodata)
    return np.where(either_missing, nodata, total)

def summarize_ranges_by_lc(lc, ic, tc, ah, nodata):
    """Collect the (min, max) of each supplied layer within every land-cover class.

    A pixel counts for a class only if it is valid (not ``nodata``) in every
    layer that was supplied (non-None). Classes with no such pixel are omitted.

    Returns:
        (ranges, classes): ``ranges`` maps ``int(class)`` to a dict holding
        ``'ic'``/``'tc'``/``'ah'`` -> ``(min, max)`` floats for the supplied
        layers; ``classes`` is the sorted list of those class ids.
    """
    layers = [
        (label, grid)
        for label, grid in (('ic', ic), ('tc', tc), ('ah', ah))
        if grid is not None
    ]
    ranges = {}
    for cls in np.unique(lc[lc != nodata]).tolist():
        usable = (lc == cls)
        for _, grid in layers:
            usable &= (grid != nodata)
        if not np.any(usable):
            continue
        stats = {}
        for label, grid in layers:
            picked = grid[usable].astype(float)
            stats[label] = (float(np.nanmin(picked)), float(np.nanmax(picked)))
        ranges[int(cls)] = stats
    return ranges, sorted(ranges)

def _linspace_inclusive(lo, hi, n, round_to=None):
    if lo is None or hi is None:
        return None
    lo = float(lo); hi = float(hi)
    if np.isclose(lo, hi):
        arr = np.array([lo], dtype=float)
    else:
        arr = np.linspace(lo, hi, int(max(1, n)))
    if round_to and round_to > 0:
        arr = np.round(arr / round_to) * round_to
    return np.unique(arr)

def _effective_span(lo, hi, round_to=None):
    """Span after optional rounding; returns 0 when lo≈hi."""
    if lo is None or hi is None:
        return 0.0
    lo = float(lo); hi = float(hi)
    if round_to and round_to > 0:
        lo = np.round(lo / round_to) * round_to
        hi = np.round(hi / round_to) * round_to
    return float(abs(hi - lo))

def _decide_steps(lo, hi, max_steps, round_to=None, min_span_for_two=0.5):
    """Number of sample values to draw across [lo, hi].

    Returns 1 when the (optionally rounded) span is below
    ``min_span_for_two``; otherwise ``max_steps`` (at least 1).
    """
    too_narrow = _effective_span(lo, hi, round_to=round_to) < min_span_for_two
    return 1 if too_narrow else int(max(1, max_steps))

def build_combos_from_ranges_adaptive(
    unique_lc, ranges,
    ic_steps_dev_max=5, tc_steps_dev_max=5,
    ic_steps_non_max=3, tc_steps_non_max=3,
    ah_steps_max=3,
    round_to=1.0, min_span_for_two=0.5
):
    """
    Build the adaptive IC × TC × AH combinations for each land-cover class.

    - For LC 21–24, use up to *_dev_max steps if the observed span warrants it.
    - For other classes, use up to *_non_max steps if the span warrants it.
    - If the IC/TC span after rounding to ``round_to`` is < min_span_for_two,
      collapse to 1 step. AH is not rounded and collapses only when its range
      is (essentially) a single value.
    - IC and TC values are clipped to 0–100 and de-duplicated. AH is not
      clipped.
    - A class missing from ``ranges`` (or missing a layer) uses 0.0 for that
      layer.

    Returns a flat list of (lc, ic, tc, ah) tuples, grouped by class in
    ``unique_lc`` order.
    """
    combos = []
    for lc_val in unique_lc:
        r = ranges.get(lc_val, {})
        ic_lo, ic_hi = r.get('ic') or (0.0, 0.0)
        tc_lo, tc_hi = r.get('tc') or (0.0, 0.0)
        ah_lo, ah_hi = r.get('ah') or (0.0, 0.0)

        is_dev = 21 <= lc_val <= 24
        ic_max = ic_steps_dev_max if is_dev else ic_steps_non_max
        tc_max = tc_steps_dev_max if is_dev else tc_steps_non_max

        ic_n = _decide_steps(ic_lo, ic_hi, ic_max, round_to=round_to, min_span_for_two=min_span_for_two)
        tc_n = _decide_steps(tc_lo, tc_hi, tc_max, round_to=round_to, min_span_for_two=min_span_for_two)
        ah_n = _decide_steps(ah_lo, ah_hi, ah_steps_max, round_to=None, min_span_for_two=1e-6)

        ic_vals = _ensure_array(_linspace_inclusive(ic_lo, ic_hi, ic_n, round_to=round_to), fallback=0.0)
        tc_vals = _ensure_array(_linspace_inclusive(tc_lo, tc_hi, tc_n, round_to=round_to), fallback=0.0)
        ah_vals = _ensure_array(_linspace_inclusive(ah_lo, ah_hi, ah_n, round_to=None), fallback=0.0)

        # Keep IC/TC within the valid 0–100 % range. Clipping can map several
        # sampled values onto the same bound, so de-duplicate afterwards to
        # avoid emitting identical combos (np.unique also keeps them sorted).
        ic_vals = np.unique(np.clip(ic_vals, 0.0, 100.0))
        tc_vals = np.unique(np.clip(tc_vals, 0.0, 100.0))

        # Cartesian product, IC outermost and AH innermost.
        combos.extend(
            (int(lc_val), float(ic_v), float(tc_v), float(ah_v))
            for ic_v in ic_vals
            for tc_v in tc_vals
            for ah_v in ah_vals
        )
    return combos

def _ti_priority_keys(lc_vals, tc_vals, high_set=frozenset({11, 90, 95})):
    """Return one float sort key per pixel; pixels with larger keys get larger TI.

    Rules (TC is min-max normalized within each class first):
      - LC in ``high_set``  -> 2.0 + 0.001 * tc_norm (highest, tiny TC tie-break)
      - LC in [21..43]      -> 1.0 + tc_norm (monotonic in TC)
      - other LCs           -> 0.7 / 1.0 / 1.3 for low / mid / high TC thirds
    """
    keys = np.empty(lc_vals.size, dtype=float)
    for cls in np.unique(lc_vals):
        cls_mask = (lc_vals == cls)
        tc_cls = tc_vals[cls_mask]
        rng = tc_cls.max() - tc_cls.min() if tc_cls.size > 1 else 0
        if rng > 0:
            tc_norm = (tc_cls - tc_cls.min()) / rng
        else:
            tc_norm = np.zeros_like(tc_cls)

        if cls in high_set:
            keys[cls_mask] = 2.0 + 0.001 * tc_norm
        elif 21 <= cls <= 43:
            keys[cls_mask] = 1.0 + tc_norm
        else:
            bands = np.clip(np.floor(tc_norm * 3.0).astype(int), 0, 2)  # 0,1,2
            keys[cls_mask] = 0.5 + np.choose(bands, [0.2, 0.5, 0.8])
    return keys


def minimal_grid_from_combos_with_balanced_ti(
    combos, ti_flat, ti_per_combo=3, min_pixels=200, nodata=-9999.0
):
    """
    Build a compact rectangle containing the given (LC, IC, TC, AH) combos and
    assign TI so that:
      • the TI pool spans the *full* histogram, and
      • priority ordering follows the class rules:

        - LC in {11, 90, 95} -> highest TI (with tiny tie-break on TC)
        - LC in [21..43]     -> TI increases monotonically with TC
        - other LCs          -> bucket into low/med/high bands by TC

    Layout: each combo is written ``ti_per_combo`` times in row-major order.
    The combos are then cycled from the start until ``max(min_pixels,
    len(combos) * ti_per_combo)`` pixels are filled. Any remaining cells of the
    near-square rectangle stay ``nodata``. With no combos, every cell is
    ``nodata``.

    Args:
        combos: sequence of (lc, ic, tc, ah) tuples.
        ti_flat: 1-D array of TI values, expected sorted ascending.
        ti_per_combo: pixels per combo in the first pass.
        min_pixels: minimum number of filled pixels.
        nodata: value written to unfilled cells.

    Returns: lc_arr, ic_arr, tc_arr, ah_arr, ti_arr (same 2-D shape).

    Raises:
        ValueError: if pixels were filled but ``ti_flat`` is empty.
    """
    # ----- 1) Layout the rectangle and fill LC/IC/TC/AH -----
    n_combo = max(1, len(combos))
    need = n_combo * ti_per_combo
    target = max(min_pixels, need)

    rows = int(math.ceil(math.sqrt(target)))
    cols = int(math.ceil(target / rows))
    capacity = rows * cols

    lc_arr = np.full((rows, cols), nodata, dtype=int)
    ic_arr = np.full((rows, cols), nodata, dtype=float)
    tc_arr = np.full((rows, cols), nodata, dtype=float)
    ah_arr = np.full((rows, cols), nodata, dtype=float)

    # Fill order: every combo ti_per_combo times, then cycle through combos
    # until `target` pixels are filled. Guarded so an empty `combos` list
    # leaves the grid all-nodata instead of raising IndexError.
    fill_sequence = []
    if combos:
        n_fill = min(target, capacity)
        for combo in combos:
            fill_sequence.extend([combo] * ti_per_combo)
        fill_sequence = fill_sequence[:n_fill]
        k = 0
        while len(fill_sequence) < n_fill:
            fill_sequence.append(combos[k % len(combos)])
            k += 1

    for flat_idx, (lc_c, ic_c, tc_c, ah_c) in enumerate(fill_sequence):
        r, c = divmod(flat_idx, cols)
        lc_arr[r, c] = int(lc_c)
        ic_arr[r, c] = float(ic_c)
        tc_arr[r, c] = float(tc_c)
        ah_arr[r, c] = float(ah_c)

    valid = (lc_arr != nodata)
    valid_count = int(valid.sum())
    if valid_count == 0:
        ti_arr = np.full_like(tc_arr, nodata, dtype=float)
        return lc_arr, ic_arr, tc_arr, ah_arr, ti_arr

    # ----- 2) Build a TI pool spanning the full histogram -----
    ti_flat = np.asarray(ti_flat, dtype=float)
    if ti_flat.size == 0:
        raise ValueError("ti_flat must contain at least one TI value.")
    if ti_flat.size >= valid_count:
        # Evenly spaced indices across the whole sorted array, not just its head.
        idxs = np.linspace(0, ti_flat.size - 1, valid_count).astype(int)
        ti_pool = ti_flat[idxs]
    else:
        reps = int(np.ceil(valid_count / ti_flat.size))
        ti_pool = np.tile(ti_flat, reps)[:valid_count]

    # ----- 3) Priority-based TI assignment (class rules + TC dependence) -----
    keys = _ti_priority_keys(lc_arr[valid].astype(int), tc_arr[valid].astype(float))
    order = np.argsort(keys)
    ti_assigned = np.empty_like(keys, dtype=float)
    ti_assigned[order] = np.sort(ti_pool)  # lowest key <- lowest TI

    ti_arr = np.full_like(tc_arr, nodata, dtype=float)
    ti_arr[valid] = ti_assigned
    return lc_arr, ic_arr, tc_arr, ah_arr, ti_arr

# -------------------------------
# Main
# -------------------------------
def _read_grid_with_nodata(filepath, nodata):
    """Read an ESRI ASCII grid and recode the file's own NODATA value to ``nodata``.

    Every mask in this script compares against one shared ``nodata``
    sentinel, but an input raster may declare a different NODATA_value
    (e.g. 255 or 0 for a byte landcover grid). Without recoding, those cells
    would be treated as a real class or cover value. Returns (header, data).
    """
    hdr, data = read_ascii_grid(filepath)
    file_nodata = next(
        (value for key, value in hdr.items() if key.lower() == "nodata_value"), None
    )
    if file_nodata is not None and file_nodata != nodata:
        data = np.where(data == file_nodata, nodata, data)
    return hdr, data


def _require_same_shape(reference_name, reference, named_grids):
    """Raise ValueError if any non-None grid in ``named_grids`` differs in shape from ``reference``."""
    for name, grid in named_grids:
        if grid is not None and grid.shape != reference.shape:
            raise ValueError(
                f"Grid '{name}' has shape {grid.shape}, but '{reference_name}' has "
                f"shape {reference.shape}; all input maps must share one grid."
            )


def _constant_fmt(value):
    """'%d' for integral constants, else '%.5f', so fractional values are not truncated."""
    return '%d' if float(value).is_integer() else '%.5f'


def main():
    """Command-line entry point: build a minimal-combo synthetic project.

    Reads HydroPlusConfig.xml and a TI histogram CSV from <path1>, plus the
    landcover / impervious / tree-cover grids and optional AH flux grids from
    <path2>. Writes the synthetic grids and an updated config to
    <new_project_dir>/Inputs, and creates <new_project_dir>/Outputs.
    All inputs are read and validated before any output folder is created.
    """
    if len(sys.argv) != 4:
        script = os.path.basename(sys.argv[0])
        print(f"Usage:\n  python {script} <path1> <path2> <new_project_dir>", file=sys.stderr)
        sys.exit(1)

    path1 = Path(sys.argv[1]).resolve()
    path2 = Path(sys.argv[2]).resolve()
    out_root = Path(sys.argv[3]).resolve()
    out_inputs = out_root / "Inputs"
    out_outputs = out_root / "Outputs"

    # --------- Load config, preserve comments, and update output folder ---------
    config_path = path1 / "HydroPlusConfig.xml"
    if not config_path.exists():
        raise FileNotFoundError(f"HydroPlusConfig.xml not found in {path1}")
    parser = ET.XMLParser(remove_blank_text=False, remove_comments=False)
    tree_src = ET.parse(str(config_path), parser)
    tree_mod, catchment_area = adjust_config_xml(str(config_path), str(out_outputs), "SpatialTemperatureHydro")
    _copy_comments_between_trees(tree_src, tree_mod)

    # --------- Find TI histogram CSV ---------
    ti_csv = None
    for name in ("z_TI_ExponentialDecay.csv", "z_TI_PowerDecay.csv"):
        p = path1 / name
        if p.exists():
            ti_csv = str(p)
            break
    if ti_csv is None:
        raise FileNotFoundError("Could not find z_TI_ExponentialDecay.csv or z_TI_PowerDecay.csv in <path1>")

    # --------- Discover and read required maps in <path2> ---------
    nodata = -9999.0

    # Primary → fallback patterns for each map type
    lc_file = first_match(["lc_*.asc", "landcover.asc"], str(path2))
    ic_file = first_match(["ic_*.asc", "imperviouscover.asc"], str(path2))
    tc_file = first_match(["tc_*.asc", "treecover.asc"], str(path2))

    missing = [
        label
        for found, label in (
            (lc_file, "lc_*.asc or landcover.asc"),
            (ic_file, "ic_*.asc or imperviouscover.asc"),
            (tc_file, "tc_*.asc or treecover.asc"),
        )
        if not found
    ]
    if missing:
        raise FileNotFoundError(
            "Missing required map file(s) in <path2>: " + ", ".join(missing)
        )

    lc_hdr, lc = _read_grid_with_nodata(lc_file, nodata)
    _, ic = _read_grid_with_nodata(ic_file, nodata)
    _, tc = _read_grid_with_nodata(tc_file, nodata)
    _require_same_shape("landcover", lc, [("imperviouscover", ic), ("treecover", tc)])
    header = dict(lc_hdr)

    # --------- Optional anthropogenic heat (AH) flux ---------
    qtot_file = first_match(["ah_flux_qtot_avg_wpm2_*.asc"], str(path2))
    qcr_file  = first_match(["ah_flux_qcr_avg_wpm2_*.asc"],  str(path2))
    qncr_file = first_match(["ah_flux_qncr_avg_wpm2_*.asc"], str(path2))

    # qtot wins only when it exists AND the flag asks for it; otherwise a
    # complete qcr + qncr pair is preferred, then qtot alone, then no AH map.
    prefer_qtot = bool(qtot_file) and Flag_AH_Flux_Total_default == 1
    if qcr_file and qncr_file and not prefer_qtot:
        _, qcr = _read_grid_with_nodata(qcr_file, nodata)
        _, qncr = _read_grid_with_nodata(qncr_file, nodata)
        _require_same_shape("landcover", lc, [("AH qcr", qcr), ("AH qncr", qncr)])
        ah = combine_qcr_qncr(qcr, qncr, nodata)
    elif qtot_file:
        _, ah = _read_grid_with_nodata(qtot_file, nodata)
        _require_same_shape("landcover", lc, [("AH qtot", ah)])
    else:
        ah = None

    # --------- Compute ranges per NLCD class ---------
    ranges, unique_lc = summarize_ranges_by_lc(lc, ic, tc, ah, nodata)

    # --------- TI pool BEFORE building the grid ---------
    # total_pixels sets how many samples are drawn from the histogram's area
    # fractions; minimal_grid_from_combos_with_balanced_ti then picks evenly
    # across this sorted pool to assign TI to the synthetic grid.
    ti_flat = get_ti_flat_from_histogram(ti_csv, total_pixels=100000)

    # --------- Build ranged combos per LC (denser for 21–24) ---------
    combos = build_combos_from_ranges_adaptive(
        unique_lc, ranges,
        ic_steps_dev_max=RANGE_STEPS_IC_MAX_DEV,
        tc_steps_dev_max=RANGE_STEPS_TC_MAX_DEV,
        ic_steps_non_max=RANGE_STEPS_IC_MAX_NON,
        tc_steps_non_max=RANGE_STEPS_TC_MAX_NON,
        ah_steps_max=RANGE_STEPS_AH_MAX,
        round_to=ROUND_TO,
        min_span_for_two=MIN_SPAN_FOR_TWO_STEPS
    )

    # --------- Make minimal rectangle with balanced TI coverage ---------
    lc_new, ic_new, tc_new, ah_new, ti_grid = minimal_grid_from_combos_with_balanced_ti(
        combos, ti_flat,
        ti_per_combo=TI_PER_COMBO,
        min_pixels=MIN_PIXELS,
        nodata=nodata
    )

    # Harmonize header with new grid size; derive cellsize from catchment area
    nrows, ncols = lc_new.shape
    valid_mask = (lc_new != nodata)
    valid_pixels = int(valid_mask.sum())
    if valid_pixels <= 0:
        raise ValueError("No valid pixels produced in minimal-combo grid.")
    cell_area = catchment_area / valid_pixels
    if cell_area > 0:
        # Whole metres; never let a small catchment round the cell size to 0.
        cellsize = max(1, int(round(math.sqrt(cell_area))))
    else:
        cellsize = int(header.get("cellsize", 10))

    out_header = dict(header)
    out_header.update({
        "ncols": ncols,
        "nrows": nrows,
        "cellsize": cellsize,
        "NODATA_value": nodata
    })

    # --------- Create output folders and write the updated config ---------
    out_inputs.mkdir(parents=True, exist_ok=True)
    out_outputs.mkdir(parents=True, exist_ok=True)
    tree_mod.write(str(out_inputs / "HydroPlusConfig.xml"), pretty_print=True, encoding="utf-8", xml_declaration=True)

    # --------- Write synthetic maps ---------
    fmt, ndp, dem_m_rounded = fmt_from_dem_decimal(dem_m)
    dem_arr    = np.where(valid_mask, float(dem_m_rounded), nodata)
    fd_arr     = np.where(valid_mask, float(z_FDorganizer_deg), nodata)
    aspect_arr = np.where(valid_mask, float(z_AspectGround_N_0_rad), nodata)
    slope_arr  = np.where(valid_mask, float(z_SlopeGround_rad), nodata)

    write_ascii(str(out_inputs / "dem.asc"), out_header, dem_arr, fmt=fmt)
    write_ascii(str(out_inputs / "z_FDorganizer.asc"), out_header, fd_arr,
                fmt=_constant_fmt(z_FDorganizer_deg))
    write_ascii(str(out_inputs / "z_AspectGround_N_0_rad.asc"), out_header, aspect_arr,
                fmt=_constant_fmt(z_AspectGround_N_0_rad))
    write_ascii(str(out_inputs / "z_SlopeGround_rad.asc"), out_header, slope_arr)

    write_ascii(str(out_inputs / "landcover.asc"),        out_header, lc_new, fmt='%d')
    write_ascii(str(out_inputs / "imperviouscover.asc"),  out_header, ic_new, fmt='%.2f')
    write_ascii(str(out_inputs / "treecover.asc"),        out_header, tc_new, fmt='%.2f')

    if ah_new is not None:
        write_ascii(str(out_inputs / "AH_flux_avg_Wpm2.asc"), out_header, ah_new, fmt='%.2f')
    else:
        ah_default = np.where(valid_mask, 0.0, nodata)
        write_ascii(str(out_inputs / "AH_flux_avg_Wpm2.asc"), out_header, ah_default, fmt='%d')

    # Blockgroup map: boolean-mask assignment walks cells in row-major order,
    # so IDs are numbered exactly as a row-by-row, column-by-column loop would.
    bg = np.full(lc_new.shape, nodata, dtype=float)
    if blockgroup_flag == 1:
        bg[valid_mask] = np.arange(
            blockgroup_starting_value, blockgroup_starting_value + valid_pixels
        )
    else:
        bg[valid_mask] = blockgroup_starting_value
    write_ascii(str(out_inputs / "blockgroup.asc"), out_header, bg, fmt='%d')

    # --------- TI organizer: write the balanced TI grid returned above ---------
    write_ascii(str(out_inputs / "z_TIorganizer.asc"), out_header, ti_grid)

    # --------- Copy any convenient inputs if present ---------
    for fname in ["Weather.csv", "Radiation.csv"]:
        src = path1 / fname
        if src.exists():
            shutil.copy(str(src), str(out_inputs / fname))

    # --------- Print concise summary ---------
    print("Minimal-combo project created.")
    print(f"Inputs written to: {out_inputs}")
    print("Per-NLCD ranges (ic/tc/ah):")
    for lc_val in unique_lc:
        r = ranges.get(lc_val, {})
        print(f"  LC {lc_val}: ic={r.get('ic')}  tc={r.get('tc')}  ah={r.get('ah')}")
    print("Done.")

if __name__ == "__main__":
    main()
