Source code for moderndid.core.converters

"""Converters for transforming DiD result objects to polars DataFrames."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
import polars as pl

if TYPE_CHECKING:
    from moderndid.did.container import AGGTEResult, MPResult
    from moderndid.didcont.container import DoseResult, PTEResult
    from moderndid.didhonest.honest_did import HonestDiDResult
    from moderndid.didinter.container import DIDInterResult, HeterogeneityResult
    from moderndid.didtriple.container import DDDAggResult, DDDMultiPeriodRCResult, DDDMultiPeriodResult


[docs] def mpresult_to_polars(result: MPResult) -> pl.DataFrame: """Convert MPResult to polars DataFrame for plotting. Parameters ---------- result : MPResult Multi-period DID result containing group-time ATT estimates. Returns ------- pl.DataFrame DataFrame with columns: - group: treatment cohort - time: time period - att: group-time ATT estimate - se: standard error - ci_lower: lower confidence interval - ci_upper: upper confidence interval - treatment_status: "Pre" or "Post" treatment """ groups = result.groups times = result.times att = result.att_gt se = result.se_gt crit_val = result.critical_value ci_lower = att - crit_val * se ci_upper = att + crit_val * se treatment_status = np.array(["Pre" if t < g else "Post" for g, t in zip(groups, times, strict=False)]) return pl.DataFrame( { "group": groups, "time": times, "att": att, "se": se, "ci_lower": ci_lower, "ci_upper": ci_upper, "treatment_status": treatment_status, } )
[docs] def aggteresult_to_polars(result: AGGTEResult) -> pl.DataFrame: """Convert AGGTEResult to polars DataFrame for plotting. Parameters ---------- result : AGGTEResult Aggregated treatment effect result (dynamic, group, or calendar). Returns ------- pl.DataFrame DataFrame with columns: - event_time: event time (for dynamic), group (for group), or time (for calendar) - att: ATT estimate - se: standard error - ci_lower: lower confidence interval - ci_upper: upper confidence interval - treatment_status: "Pre" or "Post" (for dynamic aggregation) Raises ------ ValueError If result is simple aggregation or missing required data. """ if result.aggregation_type == "simple": raise ValueError("Simple aggregation does not produce event-level data for plotting.") if result.event_times is None or result.att_by_event is None or result.se_by_event is None: raise ValueError( f"AGGTEResult with aggregation_type='{result.aggregation_type}' " "must have event_times, att_by_event, and se_by_event" ) event_times = result.event_times att = result.att_by_event se = result.se_by_event crit_vals = result.critical_values if result.critical_values is not None else np.full_like(se, 1.96) ci_lower = att - crit_vals * se ci_upper = att + crit_vals * se data = { "event_time": event_times, "att": att, "se": se, "ci_lower": ci_lower, "ci_upper": ci_upper, } if result.aggregation_type == "dynamic": data["treatment_status"] = np.array(["Pre" if e < 0 else "Post" for e in event_times]) df = pl.DataFrame(data) return df.filter(~pl.col("se").is_nan())
[docs] def doseresult_to_polars(result: DoseResult, effect_type: str = "att") -> pl.DataFrame: """Convert DoseResult to polars DataFrame for plotting. Parameters ---------- result : DoseResult Continuous treatment dose-response result. effect_type : {'att', 'acrt'}, default='att' Type of effect to extract: - 'att': Average Treatment Effect on Treated - 'acrt': Average Causal Response on Treated Returns ------- pl.DataFrame DataFrame with columns: - dose: dose level - effect: effect estimate (ATT or ACRT) - se: standard error - ci_lower: lower confidence interval - ci_upper: upper confidence interval Raises ------ ValueError If effect_type is invalid or required data is missing. """ dose = result.dose if effect_type == "att": effect = result.att_d se = result.att_d_se crit_val = result.att_d_crit_val elif effect_type == "acrt": effect = result.acrt_d se = result.acrt_d_se crit_val = result.acrt_d_crit_val else: raise ValueError(f"effect_type must be 'att' or 'acrt', got '{effect_type}'") if effect is None or se is None: raise ValueError(f"DoseResult missing {effect_type.upper()} data") if crit_val is None or np.isnan(crit_val): crit_val = 1.96 ci_lower = effect - crit_val * se ci_upper = effect + crit_val * se return pl.DataFrame( { "dose": dose, "effect": effect, "se": se, "ci_lower": ci_lower, "ci_upper": ci_upper, } )
[docs] def pteresult_to_polars(result: PTEResult) -> pl.DataFrame: """Convert PTEResult event study to polars DataFrame for plotting. Parameters ---------- result : PTEResult Panel treatment effects result with event_study. Returns ------- pl.DataFrame DataFrame with columns: - event_time: event time relative to treatment - att: ATT estimate - se: standard error - ci_lower: lower confidence interval - ci_upper: upper confidence interval - treatment_status: "Pre" or "Post" treatment Raises ------ ValueError If result does not contain event study. """ if result.event_study is None: raise ValueError("PTEResult does not contain event study results") event_study = result.event_study event_times = event_study.event_times att = event_study.att_by_event se = event_study.se_by_event if hasattr(event_study, "critical_value") and event_study.critical_value is not None: crit_val = event_study.critical_value else: crit_val = 1.96 ci_lower = att - crit_val * se ci_upper = att + crit_val * se treatment_status = np.array(["Pre" if e < 0 else "Post" for e in event_times]) df = pl.DataFrame( { "event_time": event_times, "att": att, "se": se, "ci_lower": ci_lower, "ci_upper": ci_upper, "treatment_status": treatment_status, } ) return df.filter(~pl.col("se").is_nan())
[docs] def honestdid_to_polars(result: HonestDiDResult) -> pl.DataFrame: """Convert HonestDiDResult to polars DataFrame for plotting. Parameters ---------- result : HonestDiDResult Honest DiD sensitivity analysis result. Returns ------- pl.DataFrame DataFrame with columns: - param_value: M or Mbar parameter value - method: CI method name - lb: lower bound of confidence interval - ub: upper bound of confidence interval - midpoint: (lb + ub) / 2 Combined with original CI at param_value before the minimum robust value. Raises ------ ValueError If result has empty robust_ci DataFrame. """ robust_df = result.robust_ci original = result.original_ci if robust_df.is_empty(): raise ValueError("HonestDiDResult has empty robust_ci DataFrame") if "M" in robust_df.columns: param_col = "M" elif "m" in robust_df.columns: param_col = "m" elif "Mbar" in robust_df.columns: param_col = "Mbar" else: raise ValueError("robust_ci must have 'M', 'm', or 'Mbar' column") m_values = robust_df[param_col].unique().sort().to_numpy() m_gap = np.min(np.diff(m_values)) if len(m_values) > 1 else m_values[0] if len(m_values) > 0 else 1.0 original_m = m_values[0] - m_gap original_row = pl.DataFrame( { param_col: [original_m], "lb": [original.lb], "ub": [original.ub], "method": [getattr(original, "method", "Original")], } ) combined = pl.concat([original_row, robust_df.select([param_col, "lb", "ub", "method"])]) combined = combined.with_columns( [ ((pl.col("lb") + pl.col("ub")) / 2).alias("midpoint"), ] ) combined = combined.rename({param_col: "param_value"}) return combined.sort(["method", "param_value"])
[docs] def dddmpresult_to_polars(result: DDDMultiPeriodResult | DDDMultiPeriodRCResult) -> pl.DataFrame: """Convert DDDMultiPeriodResult or DDDMultiPeriodRCResult to polars DataFrame for plotting. Parameters ---------- result : DDDMultiPeriodResult or DDDMultiPeriodRCResult Multi-period DDD result containing group-time ATT estimates. Returns ------- pl.DataFrame DataFrame with columns: - group: treatment cohort - time: time period - att: group-time ATT estimate - se: standard error - ci_lower: lower confidence interval - ci_upper: upper confidence interval - treatment_status: "Pre" or "Post" treatment """ groups = result.groups times = result.times att = result.att se = result.se ci_lower = result.lci ci_upper = result.uci treatment_status = np.array(["Pre" if t < g else "Post" for g, t in zip(groups, times, strict=False)]) df = pl.DataFrame( { "group": groups, "time": times, "att": att, "se": se, "ci_lower": ci_lower, "ci_upper": ci_upper, "treatment_status": treatment_status, } ) return df.filter(~pl.col("se").is_nan())
[docs] def dddaggresult_to_polars(result: DDDAggResult) -> pl.DataFrame: """Convert DDDAggResult to polars DataFrame for plotting. Parameters ---------- result : DDDAggResult Aggregated DDD treatment effect result (eventstudy, group, or calendar). Returns ------- pl.DataFrame DataFrame with columns: - event_time: event time (for eventstudy), group (for group), or time (for calendar) - att: ATT estimate - se: standard error - ci_lower: lower confidence interval - ci_upper: upper confidence interval - treatment_status: "Pre" or "Post" (for eventstudy aggregation) Raises ------ ValueError If result is simple aggregation or missing required data. """ if result.aggregation_type == "simple": raise ValueError("Simple aggregation does not produce event-level data for plotting.") if result.egt is None or result.att_egt is None or result.se_egt is None: raise ValueError( f"DDDAggResult with aggregation_type='{result.aggregation_type}' must have egt, att_egt, and se_egt" ) event_times = result.egt att = result.att_egt se = result.se_egt crit_val = result.crit_val if result.crit_val is not None else 1.96 ci_lower = att - crit_val * se ci_upper = att + crit_val * se data = { "event_time": event_times, "att": att, "se": se, "ci_lower": ci_lower, "ci_upper": ci_upper, } if result.aggregation_type == "eventstudy": data["treatment_status"] = np.array(["Pre" if e < 0 else "Post" for e in event_times]) df = pl.DataFrame(data) return df.filter(~pl.col("se").is_nan())
[docs] def didinterresult_to_polars(result: DIDInterResult) -> pl.DataFrame: """Convert DIDInterResult to polars DataFrame for plotting. Parameters ---------- result : DIDInterResult Intertemporal treatment effects result from did_multiplegt(). Returns ------- pl.DataFrame DataFrame with columns: - horizon: event horizon (negative for placebos, positive for effects) - att: effect estimate - se: standard error - ci_lower: lower confidence interval - ci_upper: upper confidence interval - treatment_status: "Pre" or "Post" treatment """ effects = result.effects placebos = result.placebos horizons = [] att = [] se = [] ci_lower = [] ci_upper = [] treatment_status = [] if placebos is not None: sorted_indices = np.argsort(placebos.horizons) horizons.extend(placebos.horizons[sorted_indices]) att.extend(placebos.estimates[sorted_indices]) se.extend(placebos.std_errors[sorted_indices]) ci_lower.extend(placebos.ci_lower[sorted_indices]) ci_upper.extend(placebos.ci_upper[sorted_indices]) treatment_status.extend(["Pre"] * len(placebos.horizons)) sorted_indices = np.argsort(effects.horizons) horizons.extend(effects.horizons[sorted_indices]) att.extend(effects.estimates[sorted_indices]) se.extend(effects.std_errors[sorted_indices]) ci_lower.extend(effects.ci_lower[sorted_indices]) ci_upper.extend(effects.ci_upper[sorted_indices]) treatment_status.extend(["Post"] * len(effects.horizons)) return pl.DataFrame( { "horizon": horizons, "att": att, "se": se, "ci_lower": ci_lower, "ci_upper": ci_upper, "treatment_status": treatment_status, } )
def heterogeneityresult_to_polars(result: HeterogeneityResult) -> pl.DataFrame: """Convert HeterogeneityResult to polars DataFrame. Parameters ---------- result : HeterogeneityResult Heterogeneous effects analysis result for a single horizon. Returns ------- pl.DataFrame DataFrame with columns: - Horizon: effect horizon analyzed - Covariate: covariate name - Estimate: coefficient estimate - Std. Error: standard error - t-stat: t-statistic - CI Lower: lower confidence interval bound - CI Upper: upper confidence interval bound - N: number of observations - F p-value: joint F-test p-value """ return pl.DataFrame( { "Horizon": [result.horizon] * len(result.covariates), "Covariate": result.covariates, "Estimate": result.estimates, "Std. Error": result.std_errors, "t-stat": result.t_stats, "CI Lower": result.ci_lower, "CI Upper": result.ci_upper, "N": [result.n_obs] * len(result.covariates), "F p-value": [result.f_pvalue] * len(result.covariates), } ) _DISPATCH: dict[str, Any] = { "MPResult": mpresult_to_polars, "AGGTEResult": aggteresult_to_polars, "DoseResult": doseresult_to_polars, "PTEResult": pteresult_to_polars, "HonestDiDResult": honestdid_to_polars, "DDDMultiPeriodResult": dddmpresult_to_polars, "DDDMultiPeriodRCResult": dddmpresult_to_polars, "DDDAggResult": dddaggresult_to_polars, "DIDInterResult": didinterresult_to_polars, "HeterogeneityResult": heterogeneityresult_to_polars, }
[docs] def to_df(result: Any, **kwargs: Any) -> pl.DataFrame: """Convert any ModernDiD result object to a polars DataFrame. Parameters ---------- result : Any A ModernDiD result object. Supported types: - :class:`~moderndid.did.container.AGGTEResult` - :class:`~moderndid.did.container.MPResult` - :class:`~moderndid.didcont.container.DoseResult` - :class:`~moderndid.didcont.container.PTEResult` - :class:`~moderndid.didhonest.honest_did.HonestDiDResult` - :class:`~moderndid.didtriple.container.DDDAggResult` - :class:`~moderndid.didtriple.container.DDDMultiPeriodResult` - :class:`~moderndid.didtriple.container.DDDMultiPeriodRCResult` - :class:`~moderndid.didinter.container.DIDInterResult` - :class:`~moderndid.didinter.container.HeterogeneityResult` **kwargs Additional arguments passed to the underlying converter. For example, ``effect_type="acrt"`` for DoseResult. Returns ------- pl.DataFrame DataFrame with columns appropriate to the result type. Examples -------- Convert group-time ATT results from :func:`~moderndid.att_gt` into a tidy DataFrame with one row per (group, time) cell: .. ipython:: :okwarning: In [1]: from moderndid import att_gt, aggte, load_mpdta, to_df ...: ...: df = load_mpdta() ...: result = att_gt( ...: data=df, ...: yname="lemp", ...: tname="year", ...: gname="first.treat", ...: idname="countyreal", ...: est_method="dr", ...: boot=False, ...: ) ...: print(to_df(result).head()) Aggregated event-study results work the same way: .. ipython:: :okwarning: In [2]: agg = aggte(result, type="dynamic") ...: print(to_df(agg)) """ type_name = type(result).__name__ converter = _DISPATCH.get(type_name) if converter is None: raise TypeError(f"No converter for {type_name!r}. Supported types: {', '.join(sorted(_DISPATCH))}") return converter(result, **kwargs)