Source code for moderndid.didcont.container

"""Containers for panel treatment effects."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, NamedTuple

import numpy as np
import polars as pl
from scipy import stats

from moderndid.core.maketables import (
    build_coef_table_with_ci,
    build_single_coef_table,
    format_effect_value,
    make_effect_names,
    se_type_label,
    vcov_info_from_bootstrap,
)


class PTEParams(NamedTuple):
    """Container for panel treatment effect parameters.

    Attributes
    ----------
    yname : str
        Name of the outcome variable.
    gname : str
        Name of the group variable (first treatment period).
    tname : str
        Name of the time period variable.
    idname : str
        Name of the id variable.
    data : pl.DataFrame
        Panel data as a Polars DataFrame.
    g_list : np.ndarray
        Array of unique group identifiers.
    t_list : np.ndarray
        Array of unique time period identifiers.
    cband : bool
        Whether to compute a uniform confidence band.
    alp : float
        Significance level for confidence intervals.
    boot_type : str
        Method for bootstrapping.
    anticipation : int
        Number of periods of anticipation.
    base_period : str
        Base period for computing ATT(g,t).
    weightsname : str
        Name of the weights variable.
    control_group : str
        Which units to use as the control group.
    gt_type : str
        Type of group-time average treatment effect.
    ret_quantile : float
        Quantile to return for conditional distribution.
    biters : int
        Number of bootstrap iterations.
    dname : str
        Name of the continuous treatment variable.
    degree : int
        Degree of the spline for continuous treatment.
    num_knots : int
        Number of knots for the spline.
    knots : np.ndarray
        Array of knot locations for the spline.
    dvals : np.ndarray
        Values of the dose to evaluate the dose-response function.
    target_parameter : str
        The target parameter of interest.
    aggregation : str
        Type of aggregation for results.
    treatment_type : str
        Type of treatment (e.g., 'continuous').
    xformula : str
        Formula for covariates.
    dose_est_method : str, default "parametric"
        Method for estimating dose-specific effects ('parametric' or 'cck').
        This is the only field with a default value.
    """

    #: Name of the outcome variable.
    yname: str
    #: Name of the group variable (first treatment period).
    gname: str
    #: Name of the time period variable.
    tname: str
    #: Name of the id variable.
    idname: str
    #: Panel data as a Polars DataFrame.
    data: pl.DataFrame
    #: Array of unique group identifiers.
    g_list: np.ndarray
    #: Array of unique time period identifiers.
    t_list: np.ndarray
    #: Whether to compute a uniform confidence band.
    cband: bool
    #: Significance level for confidence intervals.
    alp: float
    #: Method for bootstrapping.
    boot_type: str
    #: Number of periods of anticipation.
    anticipation: int
    #: Base period for computing ATT(g,t).
    base_period: str
    #: Name of the weights variable.
    weightsname: str
    #: Which units to use as the control group.
    control_group: str
    #: Type of group-time average treatment effect.
    gt_type: str
    #: Quantile to return for conditional distribution.
    ret_quantile: float
    #: Number of bootstrap iterations.
    biters: int
    #: Name of the continuous treatment variable.
    dname: str
    #: Degree of the spline for continuous treatment.
    degree: int
    #: Number of knots for the spline.
    num_knots: int
    #: Array of knot locations for the spline.
    knots: np.ndarray
    #: Values of the dose to evaluate the dose-response function.
    dvals: np.ndarray
    #: The target parameter of interest.
    target_parameter: str
    #: Type of aggregation for results.
    aggregation: str
    #: Type of treatment.
    treatment_type: str
    #: Formula for covariates.
    xformula: str
    #: Method for estimating dose-specific effects ('parametric' or 'cck').
    dose_est_method: str = "parametric"


class AttgtResult(NamedTuple):
    """Container for a single ATT(g,t) result with influence function.

    Attributes
    ----------
    attgt : float
        Point estimate of the ATT for one (g,t) cell.
    inf_func : np.ndarray | None
        Influence function associated with ``attgt``; may be ``None``.
    extra_gt_returns : dict | None
        Additional per-cell outputs produced by the group-time routine;
        may be ``None``.
    """

    #: Point estimate for this (g,t) cell.
    attgt: float
    #: Influence function backing ``attgt`` (may be ``None``).
    inf_func: np.ndarray | None
    #: Additional outputs from the group-time routine (may be ``None``).
    extra_gt_returns: dict | None


class PTEResult(NamedTuple):
    """Container for panel treatment effects results.

    This class implements the ``maketables`` plug-in interface for
    publication-quality tables. See :ref:`publication_tables`.
    """

    #: Group-time ATT results.
    att_gt: object
    #: Overall ATT estimate.
    overall_att: object | None
    #: Event study results.
    event_study: object | None
    #: Panel treatment effect parameters.
    ptep: PTEParams

    @property
    def __maketables_coef_table__(self):
        """Delegate coefficient extraction to the most informative nested result.

        Sources are tried in this order:

        1. ``event_study``, if it exposes ``__maketables_coef_table__``;
        2. ``overall_att``, read either from ``overall_att``/``overall_se``
           attributes or, when it is a dict, from the ``overall_att``/``att``
           and ``overall_se``/``se`` keys;
        3. ``att_gt``, if it exposes ``__maketables_coef_table__``.

        Raises
        ------
        ValueError
            If none of the nested results yields a usable estimate.
        """
        if self.event_study is not None and hasattr(self.event_study, "__maketables_coef_table__"):
            return self.event_study.__maketables_coef_table__

        if self.overall_att is not None:
            att = getattr(self.overall_att, "overall_att", None)
            se = getattr(self.overall_att, "overall_se", None)
            if att is not None and se is not None:
                return build_single_coef_table("Overall ATT", float(att), float(se))

            if isinstance(self.overall_att, dict):
                # Explicit ``None`` checks rather than ``or``: an estimate or
                # standard error of exactly 0.0 is valid and must not be
                # mistaken for a missing key.
                att = self.overall_att.get("overall_att")
                if att is None:
                    att = self.overall_att.get("att")
                se = self.overall_att.get("overall_se")
                if se is None:
                    se = self.overall_att.get("se")
                if att is not None and se is not None:
                    return build_single_coef_table("Overall ATT", float(att), float(se))

        if self.att_gt is not None and hasattr(self.att_gt, "__maketables_coef_table__"):
            return self.att_gt.__maketables_coef_table__

        raise ValueError("PTEResult does not contain a maketables-compatible estimate table.")

    def __maketables_stat__(self, key: str) -> int | float | str | None:
        """Return model-level statistics for maketables.

        Delegates to ``event_study`` when it implements ``__maketables_stat__``.
        Otherwise supports ``"N"`` and ``"se_type"`` and returns ``None`` for any
        other key.
        """
        if self.event_study is not None and hasattr(self.event_study, "__maketables_stat__"):
            return self.event_study.__maketables_stat__(key)
        if key == "N":
            return _n_obs_from_pte_params(self.ptep)
        if key == "se_type":
            return se_type_label(True)
        return None

    @property
    def __maketables_depvar__(self) -> str:
        """Return dependent variable label for maketables (falls back to ``"Outcome"``)."""
        return str(getattr(self.ptep, "yname", "Outcome"))

    @property
    def __maketables_fixef_string__(self) -> str | None:
        """Continuous DiD result wrappers do not report fixed-effects formulas."""
        return None

    @property
    def __maketables_vcov_info__(self) -> dict[str, str | None]:
        """Return variance-covariance metadata (bootstrap-based)."""
        return vcov_info_from_bootstrap(is_bootstrap=True)

    @property
    def __maketables_default_stat_keys__(self) -> list[str]:
        """Default model-level stats to display in ETable.

        Uses the event study's defaults when it provides them.
        """
        if self.event_study is not None and hasattr(self.event_study, "__maketables_default_stat_keys__"):
            return self.event_study.__maketables_default_stat_keys__
        return ["N", "se_type"]


class PTEAggteResult(NamedTuple):
    """Container for aggregated panel treatment effect parameters.

    This class implements the ``maketables`` plug-in interface for
    publication-quality tables. See :ref:`publication_tables`.

    Attributes
    ----------
    overall_att : float
        The estimated overall average treatment effect on the treated.
    overall_se : float
        Standard error for overall ATT.
    aggregation_type : {'overall', 'dynamic', 'group'}
        Type of aggregation performed.
    event_times : np.ndarray, optional
        Event/group values depending on aggregation type:

        - For dynamic effects: length of exposure (event time)
        - For group effects: treatment group indicators

    att_by_event : np.ndarray, optional
        ATT estimates specific to each event time value.
    se_by_event : np.ndarray, optional
        Standard errors specific to each event time value.
    critical_value : float, optional
        Critical value for uniform confidence bands.
    influence_func : dict, optional
        Dictionary containing influence functions:

        - **overall**: Overall ATT influence function
        - **by_event**: Event-specific influence functions

    min_event_time : int, optional
        Minimum event time (for dynamic effects).
    max_event_time : int, optional
        Maximum event time (for dynamic effects).
    balance_event : int, optional
        Balanced event time threshold.
    att_gt_result : object
        Original group-time ATT result object.
    """

    #: Estimated overall average treatment effect on the treated.
    overall_att: float
    #: Standard error for overall ATT.
    overall_se: float
    #: Type of aggregation performed.
    aggregation_type: Literal["overall", "dynamic", "group"]
    #: Event/group values depending on aggregation type.
    event_times: np.ndarray | None = None
    #: ATT estimates specific to each event time value.
    att_by_event: np.ndarray | None = None
    #: Standard errors specific to each event time value.
    se_by_event: np.ndarray | None = None
    #: Critical value for uniform confidence bands.
    critical_value: float | None = None
    #: Influence functions for overall and event-specific ATTs.
    influence_func: dict | None = None
    #: Minimum event time.
    min_event_time: int | None = None
    #: Maximum event time.
    max_event_time: int | None = None
    #: Balanced event time threshold.
    balance_event: int | None = None
    #: Original group-time ATT result object.
    att_gt_result: object | None = None

    @property
    def __maketables_coef_table__(self):
        """Return canonical coefficient table for maketables.

        The first row is the overall estimate. It is labelled ``Overall ACRT``
        when the target parameter is ``"slope"`` and ``Overall ATT`` otherwise.
        Event/group rows follow only when ``event_times``, ``att_by_event`` and
        ``se_by_event`` are all available.

        Critical values:

        - If ``critical_value`` is set and ``event_times`` is set, the overall
          row uses the pointwise normal critical value and every emitted
          event/group row uses ``critical_value``.
        - If ``critical_value`` is set and ``event_times`` is not, the scalar
          ``critical_value`` is passed through unchanged.
        """
        pte_params = getattr(self.att_gt_result, "pte_params", None)
        target = getattr(pte_params, "target_parameter", "level")
        overall_label = "Overall ACRT" if target == "slope" else "Overall ATT"

        names = [overall_label]
        estimates = [self.overall_att]
        se = [self.overall_se]

        if self.event_times is not None and self.att_by_event is not None and self.se_by_event is not None:
            prefix = "Event" if self.aggregation_type == "dynamic" else "Group"
            names.extend(make_effect_names(self.event_times, prefix=prefix))
            estimates.extend(np.asarray(self.att_by_event, dtype=float).tolist())
            se.extend(np.asarray(self.se_by_event, dtype=float).tolist())

        alpha = float(getattr(pte_params, "alp", 0.05))
        crit = self.critical_value
        if self.event_times is not None and crit is not None:
            z_crit = stats.norm.ppf(1 - alpha / 2)
            # Size the event block from the rows actually emitted above, so the
            # critical values stay aligned with names/estimates even when
            # event-level estimates or SEs are missing.
            event_crit = np.full(len(names) - 1, crit)
            crit = np.concatenate([[z_crit], event_crit])
        return build_coef_table_with_ci(names, estimates, se, alpha=alpha, critical_values=crit)

    def __maketables_stat__(self, key: str) -> int | float | str | None:
        """Return model-level statistics for maketables.

        ``"N"`` is the length of the overall influence function when present,
        otherwise a count derived from the PTE parameters. Unknown keys return
        ``None``.
        """
        pte_params = getattr(self.att_gt_result, "pte_params", None)
        if key == "N":
            if isinstance(self.influence_func, dict) and self.influence_func.get("overall") is not None:
                return int(np.asarray(self.influence_func["overall"]).shape[0])
            return _n_obs_from_pte_params(pte_params)
        if key == "aggregation":
            return self.aggregation_type
        if key == "se_type":
            return se_type_label(True)
        if key == "control_group":
            return getattr(pte_params, "control_group", None)
        if key == "est_method":
            return getattr(pte_params, "gt_type", None)
        return None

    @property
    def __maketables_depvar__(self) -> str:
        """Return dependent variable label for maketables."""
        pte_params = getattr(self.att_gt_result, "pte_params", None)
        return str(getattr(pte_params, "yname", "Continuous-Treatment ATT"))

    @property
    def __maketables_fixef_string__(self) -> str | None:
        """Continuous DiD output does not report fixed-effects formulas."""
        return None

    @property
    def __maketables_vcov_info__(self) -> dict[str, str | None]:
        """Return variance-covariance metadata (bootstrap-based)."""
        return vcov_info_from_bootstrap(is_bootstrap=True)

    @property
    def __maketables_stat_labels__(self) -> dict[str, str]:
        """Return custom labels for model-level statistics."""
        return {
            "aggregation": "Aggregation",
            "control_group": "Control Group",
            "est_method": "Estimation Method",
        }

    @property
    def __maketables_default_stat_keys__(self) -> list[str]:
        """Default model-level stats to display in ETable (``N`` only when available)."""
        keys = ["aggregation", "se_type", "control_group", "est_method"]
        if self.__maketables_stat__("N") is not None:
            keys.insert(0, "N")
        return keys


@dataclass
class GroupTimeATTResult:
    """Container for group-time average treatment effect results.

    This class implements the ``maketables`` plug-in interface for
    publication-quality tables. See :ref:`publication_tables`.

    Attributes
    ----------
    groups : np.ndarray
        Which group (defined by period first treated) each group-time ATT is for.
    times : np.ndarray
        Which time period each group-time ATT is for.
    att : np.ndarray
        The group-time average treatment effects for each group-time combination.
    vcov_analytical : np.ndarray
        Analytical estimator for the asymptotic variance-covariance matrix.
    se : np.ndarray
        Standard errors for group-time ATTs. If bootstrap used, provides bootstrap-based SE.
    critical_value : float
        Critical value - simultaneous if obtaining simultaneous confidence
        bands, otherwise based on pointwise normal approximation.
    influence_func : np.ndarray
        The influence function for estimating group-time average treatment effects.
    n_units : int
        The number of unique cross-sectional units.
    wald_stat : float, optional
        The Wald statistic for pre-testing the common trends assumption.
    wald_pvalue : float, optional
        The p-value of the Wald statistic for pre-testing common trends.
    cband : bool
        Whether uniform confidence band was computed.
    alpha : float
        The significance level.
    pte_params : object
        The PTE parameters object containing estimation settings.
    extra_gt_returns : list
        List of extra returns from gt-specific calculations.
    """

    #: Which group (defined by period first treated) each group-time ATT is for.
    groups: np.ndarray
    #: Which time period each group-time ATT is for.
    times: np.ndarray
    #: Group-time average treatment effects.
    att: np.ndarray
    #: Analytical estimator for the asymptotic variance-covariance matrix.
    vcov_analytical: np.ndarray
    #: Standard errors for group-time ATTs.
    se: np.ndarray
    #: Critical value for confidence intervals.
    critical_value: float
    #: Influence function for estimating group-time average treatment effects.
    influence_func: np.ndarray
    #: Number of unique cross-sectional units.
    n_units: int
    #: Wald statistic for pre-testing common trends.
    wald_stat: float | None = None
    #: P-value of the Wald statistic for pre-testing common trends.
    wald_pvalue: float | None = None
    #: Whether uniform confidence band was computed.
    cband: bool = True
    #: Significance level.
    alpha: float = 0.05
    #: PTE parameters object containing estimation settings.
    pte_params: object | None = None
    #: Extra returns from group-time calculations.
    extra_gt_returns: list | None = None

    @property
    def att_gt(self):
        """Alias for att field to maintain compatibility with aggte."""
        return self.att

    @property
    def se_gt(self):
        """Alias for se field to maintain compatibility with aggte."""
        return self.se

    @property
    def estimation_params(self):
        """Return estimation parameters for aggte compatibility.

        ``biters`` is read from ``pte_params.biters`` when it is available, so the
        reported value matches the user's configuration. It falls back to 999
        when ``pte_params`` is missing or does not define it.
        """
        biters = getattr(self.pte_params, "biters", None)
        return {
            "bootstrap": True,
            "biters": 999 if biters is None else int(biters),
            "uniform_bands": self.cband,
            "alpha": self.alpha,
        }

    @property
    def G(self):
        """Unit-level group assignments (not tracked in continuous DiD)."""
        return None

    @property
    def __maketables_coef_table__(self):
        """Return canonical coefficient table for maketables (one row per (g,t) cell)."""
        names = [
            f"ATT(g={format_effect_value(g)}, t={format_effect_value(t)})"
            for g, t in zip(self.groups, self.times, strict=False)
        ]
        return build_coef_table_with_ci(
            names,
            self.att,
            self.se,
            alpha=float(self.alpha),
            critical_values=self.critical_value,
        )

    def __maketables_stat__(self, key: str) -> int | float | str | None:
        """Return model-level statistics for maketables (``None`` for unknown keys)."""
        pte_params = self.pte_params
        if key == "N":
            return int(self.n_units)
        if key == "se_type":
            return se_type_label(True)
        if key == "control_group":
            return getattr(pte_params, "control_group", None)
        return None

    @property
    def __maketables_depvar__(self) -> str:
        """Return dependent variable label for maketables."""
        return str(getattr(self.pte_params, "yname", "ATT(g,t)"))

    @property
    def __maketables_fixef_string__(self) -> str | None:
        """Continuous DiD group-time output does not report fixed-effects formulas."""
        return None

    @property
    def __maketables_vcov_info__(self) -> dict[str, str | None]:
        """Return variance-covariance metadata (bootstrap-based)."""
        return vcov_info_from_bootstrap(is_bootstrap=True)

    @property
    def __maketables_stat_labels__(self) -> dict[str, str]:
        """Return custom labels for model-level statistics."""
        return {"control_group": "Control Group"}

    @property
    def __maketables_default_stat_keys__(self) -> list[str]:
        """Default model-level stats to display in ETable."""
        return ["N", "se_type", "control_group"]


class PteEmpBootResult(NamedTuple):
    """Hold the outputs produced by the empirical bootstrap.

    Attributes
    ----------
    attgt_results : pl.DataFrame
        Group-time ATT estimates together with their standard errors.
    overall_results : dict
        The overall ATT estimate and its standard error.
    group_results : pl.DataFrame | None
        Per-group ATT estimates and standard errors; ``None`` by default.
    dyn_results : pl.DataFrame | None
        Event-time (dynamic) ATT estimates and standard errors; ``None`` by default.
    extra_gt_returns : list | None
        Additional outputs of the group-time calculations; ``None`` by default.
    """

    #: Group-time ATT estimates with their standard errors.
    attgt_results: pl.DataFrame
    #: Overall ATT estimate and its standard error.
    overall_results: dict
    #: Per-group ATT estimates and standard errors.
    group_results: pl.DataFrame | None = None
    #: Event-time (dynamic) ATT estimates and standard errors.
    dyn_results: pl.DataFrame | None = None
    #: Additional outputs of the group-time calculations.
    extra_gt_returns: list | None = None


class DoseResult(NamedTuple):
    """Hold continuous-treatment dose-response estimates.

    This class implements the ``maketables`` plug-in interface for
    publication-quality tables. See :ref:`publication_tables`.

    Attributes
    ----------
    dose : np.ndarray
        Dose values at which the curves were estimated.
    overall_att : float
        Overall ATT, the mean of ATT(D) given D > 0.
    overall_att_se : float
        Standard error of ``overall_att``.
    overall_att_inf_func : np.ndarray
        Influence function used to estimate ``overall_att``.
    overall_acrt : float
        Overall ACRT, the mean of ACRT(D|D) given D > 0.
    overall_acrt_se : float
        Standard error of ``overall_acrt``.
    overall_acrt_inf_func : np.ndarray
        Influence function used to estimate ``overall_acrt``.
    att_d : np.ndarray
        ATT(d) estimate at each dose value.
    att_d_se : np.ndarray
        Standard error of ATT(d) at each dose value.
    att_d_crit_val : float
        Critical value for pointwise or uniform ATT(d) confidence intervals.
    att_d_inf_func : np.ndarray
        Influence-function matrix from estimating ATT(d).
    acrt_d : np.ndarray
        ACRT(d) estimate at each dose value.
    acrt_d_se : np.ndarray
        Standard error of ACRT(d) at each dose value.
    acrt_d_crit_val : float
        Critical value for pointwise or uniform ACRT(d) confidence intervals.
    acrt_d_inf_func : np.ndarray
        Influence-function matrix from estimating ACRT(d).
    pte_params : object
        PTEParams object with the remaining estimation settings.
    """

    #: Dose values used in estimation.
    dose: np.ndarray
    #: Overall ATT estimate.
    overall_att: float | None = None
    #: Standard error of the overall ATT.
    overall_att_se: float | None = None
    #: Influence function of the overall ATT.
    overall_att_inf_func: np.ndarray | None = None
    #: Overall ACRT estimate.
    overall_acrt: float | None = None
    #: Standard error of the overall ACRT.
    overall_acrt_se: float | None = None
    #: Influence function of the overall ACRT.
    overall_acrt_inf_func: np.ndarray | None = None
    #: ATT(d) at each dose value.
    att_d: np.ndarray | None = None
    #: Standard errors of ATT(d).
    att_d_se: np.ndarray | None = None
    #: Critical value for ATT(d) confidence intervals.
    att_d_crit_val: float | None = None
    #: Influence-function matrix for ATT(d).
    att_d_inf_func: np.ndarray | None = None
    #: ACRT(d) at each dose value.
    acrt_d: np.ndarray | None = None
    #: Standard errors of ACRT(d).
    acrt_d_se: np.ndarray | None = None
    #: Critical value for ACRT(d) confidence intervals.
    acrt_d_crit_val: float | None = None
    #: Influence-function matrix for ACRT(d).
    acrt_d_inf_func: np.ndarray | None = None
    #: PTEParams object holding estimation settings.
    pte_params: object | None = None

    @property
    def __maketables_coef_table__(self):
        """Assemble the coefficient table consumed by maketables.

        Row order:

        1. Overall ATT.
        2. Overall ACRT.
        3. One ATT(d) row per dose value.
        4. One ACRT(d) row per dose value.

        A summary row (1 or 2) appears only when both its estimate and its
        standard error are set. A curve (3 or 4) appears only when its
        estimates, its standard errors and ``dose`` are all set.

        Summary rows use the pointwise normal critical value. Curve rows use
        their stored critical value, falling back to the normal one.
        """
        alpha = float(getattr(self.pte_params, "alp", 0.05))
        normal_cv = stats.norm.ppf(1 - alpha / 2)

        rows: list[tuple[str, float, float, float]] = []

        summaries = (
            ("Overall ATT", self.overall_att, self.overall_att_se),
            ("Overall ACRT", self.overall_acrt, self.overall_acrt_se),
        )
        for label, point, std_err in summaries:
            if point is None or std_err is None:
                continue
            rows.append((label, float(point), float(std_err), normal_cv))

        curves = (
            ("ATT", self.att_d, self.att_d_se, self.att_d_crit_val),
            ("ACRT", self.acrt_d, self.acrt_d_se, self.acrt_d_crit_val),
        )
        for prefix, effects, std_errs, stored_cv in curves:
            if effects is None or std_errs is None or self.dose is None:
                continue
            curve_cv = normal_cv if stored_cv is None else stored_cv
            for dose_value, effect, std_err in zip(self.dose, effects, std_errs, strict=False):
                label = f"{prefix}(d={format_effect_value(dose_value)})"
                rows.append((label, float(effect), float(std_err), curve_cv))

        names = [row[0] for row in rows]
        estimates = [row[1] for row in rows]
        se = [row[2] for row in rows]
        crit_vals = [row[3] for row in rows]
        return build_coef_table_with_ci(names, estimates, se, alpha=alpha, critical_values=crit_vals)

    def __maketables_stat__(self, key: str) -> int | float | str | None:
        """Look up a model-level statistic for maketables; unknown keys give ``None``."""
        if key == "N":
            return _n_obs_from_pte_params(self.pte_params)
        if key == "se_type":
            return se_type_label(True)
        # These keys map one-to-one onto attributes of ``pte_params``.
        if key in ("control_group", "dose_est_method"):
            return getattr(self.pte_params, key, None)
        return None

    @property
    def __maketables_depvar__(self) -> str:
        """Label for the dependent variable, defaulting to ``"Dose Response"``."""
        label = getattr(self.pte_params, "yname", "Dose Response")
        return str(label)

    @property
    def __maketables_fixef_string__(self) -> str | None:
        """Dose-response output carries no fixed-effects formula."""
        return None

    @property
    def __maketables_vcov_info__(self) -> dict[str, str | None]:
        """Variance-covariance metadata (always bootstrap-based here)."""
        return vcov_info_from_bootstrap(is_bootstrap=True)

    @property
    def __maketables_stat_labels__(self) -> dict[str, str]:
        """Display labels for the custom model-level statistics."""
        labels = {"control_group": "Control Group"}
        labels["dose_est_method"] = "Dose Estimation"
        return labels

    @property
    def __maketables_default_stat_keys__(self) -> list[str]:
        """Model-level stats shown by default; ``N`` leads only when it is known."""
        base_keys = ["se_type", "control_group", "dose_est_method"]
        n_obs = self.__maketables_stat__("N")
        return base_keys if n_obs is None else ["N", *base_keys]


def _n_obs_from_pte_params(params: PTEParams | None) -> int | None:
    """Derive an observation count from PTE parameters, if one can be found.

    Returns
    -------
    int or None
        The number of unique ids when ``params.data`` is a Polars DataFrame
        containing the ``params.idname`` column. Otherwise ``len(params.data)``.
        ``None`` when ``params`` or its data is missing, or the data has no
        length.
    """
    if params is None:
        return None

    data = getattr(params, "data", None)
    idname = getattr(params, "idname", None)
    if data is None:
        return None

    # Short-circuit order matters: the Polars check only runs when an id column name exists.
    has_id_column = idname is not None and isinstance(data, pl.DataFrame) and idname in data.columns
    if has_id_column:
        return int(data[idname].n_unique())

    try:
        return len(data)
    except TypeError:
        return None