# Source code for moderndid.did.container

"""Result containers for multi-period DiD estimators."""

from typing import Literal, NamedTuple

import numpy as np
from scipy import stats

from moderndid.core.maketables import (
    build_coef_table_with_ci,
    control_group_label,
    est_method_label,
    make_effect_names,
    make_group_time_names,
    se_type_label,
)
from moderndid.core.result import extract_n_obs, extract_vcov_info


[docs] class AGGTEResult(NamedTuple): """Container for aggregated treatment effect parameters. This class implements the ``maketables`` plug-in interface for publication-quality tables. See :ref:`publication_tables`. Attributes ---------- overall_att : float The estimated overall average treatment effect on the treated. overall_se : float Standard error for overall ATT. aggregation_type : {'simple', 'dynamic', 'group', 'calendar'} Type of aggregation performed. event_times : np.ndarray, optional Event/group/time values depending on aggregation type: - For dynamic effects: length of exposure - For group effects: treatment group indicators - For calendar effects: time periods att_by_event : ndarray, optional ATT estimates specific to each event time value. se_by_event : ndarray, optional Standard errors specific to each event time value. critical_values : ndarray, optional Critical values for uniform confidence bands. influence_func : ndarray, optional Influence function of the aggregated parameters. - For overall ATT: 1D array of length n_units - For dynamic/group/calendar: 2D array of shape (n_units, n_events) containing influence functions for each event-specific ATT influence_func_overall : ndarray, optional Influence function for the overall ATT (1D array of length n_units). This is stored separately for compatibility with both aggregation types. min_event_time : int, optional Minimum event time (for dynamic effects). max_event_time : int, optional Maximum event time (for dynamic effects). balanced_event_threshold : int, optional Balanced event time threshold. 
estimation_params : dict Dictionary containing DID estimation parameters including: - alpha: significance level - bootstrap: whether bootstrap was used - uniform_bands: whether uniform confidence bands were computed - control_group: 'nevertreated' or 'notyettreated' - anticipation_periods: number of anticipation periods - estimation_method: estimation method used call_info : dict Information about the function call that created this object. """ #: Estimated overall average treatment effect on the treated. overall_att: float #: Standard error for overall ATT. overall_se: float #: Type of aggregation performed. aggregation_type: Literal["simple", "dynamic", "group", "calendar"] #: Event/group/time values depending on aggregation type. event_times: np.ndarray | None = None #: ATT estimates specific to each event time value. att_by_event: np.ndarray | None = None #: Standard errors specific to each event time value. se_by_event: np.ndarray | None = None #: Critical values for uniform confidence bands. critical_values: np.ndarray | None = None #: Influence function of the aggregated parameters. influence_func: np.ndarray | None = None #: Influence function for the overall ATT. influence_func_overall: np.ndarray | None = None #: Minimum event time. min_event_time: int | None = None #: Maximum event time. max_event_time: int | None = None #: Balanced event time threshold. balanced_event_threshold: int | None = None #: DID estimation parameters. estimation_params: dict = {} #: Information about the function call that created this object. 
call_info: dict = {} @property def __maketables_coef_table__(self): """Return canonical coefficient table for maketables.""" alpha = float(self.estimation_params.get("alpha", 0.05)) names = ["Overall ATT"] estimates = [self.overall_att] se = [self.overall_se] crit = None if self.event_times is not None and self.att_by_event is not None and self.se_by_event is not None: z_crit = stats.norm.ppf(1 - alpha / 2) prefix = {"dynamic": "Event", "group": "Group", "calendar": "Time"}.get(self.aggregation_type, "Effect") names.extend(make_effect_names(self.event_times, prefix=prefix)) estimates.extend(np.asarray(self.att_by_event, dtype=float).tolist()) se.extend(np.asarray(self.se_by_event, dtype=float).tolist()) if self.critical_values is not None: event_crit = np.asarray(self.critical_values, dtype=float) else: event_crit = np.full(len(self.event_times), z_crit) crit = np.concatenate([[z_crit], event_crit]) return build_coef_table_with_ci(names, estimates, se, alpha=alpha, critical_values=crit) def __maketables_stat__(self, key: str) -> int | float | str | None: """Return model-level statistics for maketables.""" if key == "N": return extract_n_obs( self.influence_func_overall, self.influence_func, params=self.estimation_params, ) if key == "n_units": return extract_n_obs( self.influence_func_overall, self.influence_func, params=self.estimation_params, keys=("n_units",), ) if key == "aggregation": return self.aggregation_type if key == "se_type": return se_type_label(bool(self.estimation_params.get("bootstrap", False))) if key == "control_group": return control_group_label(self.estimation_params.get("control_group")) if key == "estimation_method": return est_method_label(self.estimation_params.get("estimation_method")) return None @property def __maketables_depvar__(self) -> str: """Return dependent variable label for maketables.""" return str(self.estimation_params.get("yname", "Aggregated ATT")) @property def __maketables_fixef_string__(self) -> str | None: """AGGTE 
output does not report fixed-effects formulas.""" return None @property def __maketables_vcov_info__(self) -> dict[str, str | None]: """Return variance-covariance metadata.""" return extract_vcov_info(self.estimation_params) @property def __maketables_stat_labels__(self) -> dict[str, str]: """Return custom labels for model-level statistics.""" return { "n_units": "Units", "aggregation": "Aggregation", "control_group": "Control Group", "estimation_method": "Estimation Method", } @property def __maketables_default_stat_keys__(self) -> list[str]: """Default model-level stats to display in ETable.""" keys = ["aggregation", "se_type", "control_group"] if self.__maketables_stat__("N") is not None: keys.insert(0, "N") if self.__maketables_stat__("n_units") is not None: idx = keys.index("N") + 1 if "N" in keys else 0 keys.insert(idx, "n_units") if self.estimation_params.get("estimation_method") is not None: keys.append("estimation_method") return keys
[docs] class MPResult(NamedTuple): """Container for group-time average treatment effect results. This class implements the ``maketables`` plug-in interface for publication-quality tables. See :ref:`publication_tables`. Attributes ---------- groups : ndarray Which group (defined by period first treated) each group-time ATT is for. times : ndarray Which time period each group-time ATT is for. att_gt : ndarray The group-time average treatment effects for each group-time combination. vcov_analytical : ndarray Analytical estimator for the asymptotic variance-covariance matrix. se_gt : ndarray Standard errors for group-time ATTs. If bootstrap used, provides bootstrap-based SE. critical_value : float Critical value - simultaneous if obtaining simultaneous confidence bands, otherwise based on pointwise normal approximation. influence_func : ndarray The influence function for estimating group-time average treatment effects. n_units : int, optional The number of unique cross-sectional units. wald_stat : float, optional The Wald statistic for pre-testing the common trends assumption. wald_pvalue : float, optional The p-value of the Wald statistic for pre-testing common trends. aggregate_effects : object, optional An aggregate treatment effects object. alpha : float The significance level (default 0.05). estimation_params : dict Dictionary containing DID estimation parameters including: - call_info: original function call information - control_group: 'nevertreated' or 'notyettreated' - anticipation_periods: number of anticipation periods - estimation_method: estimation method used - bootstrap: whether bootstrap was used - uniform_bands: whether simultaneous confidence bands were computed - G: unit-level group assignments - weights_ind: unit-level sampling weights """ #: Which group (defined by period first treated) each group-time ATT is for. groups: np.ndarray #: Which time period each group-time ATT is for. times: np.ndarray #: Group-time average treatment effects. 
att_gt: np.ndarray #: Analytical estimator for the asymptotic variance-covariance matrix. vcov_analytical: np.ndarray #: Standard errors for group-time ATTs. se_gt: np.ndarray #: Critical value for confidence intervals. critical_value: float #: Influence function for estimating group-time average treatment effects. influence_func: np.ndarray #: Number of unique cross-sectional units. n_units: int | None = None #: Wald statistic for pre-testing common trends. wald_stat: float | None = None #: P-value of the Wald statistic for pre-testing common trends. wald_pvalue: float | None = None #: Aggregate treatment effects object. aggregate_effects: object | None = None #: Significance level. alpha: float = 0.05 #: DID estimation parameters. estimation_params: dict = {} #: Unit-level group assignments. G: np.ndarray | None = None #: Unit-level sampling weights. weights_ind: np.ndarray | None = None @property def __maketables_coef_table__(self): """Return canonical coefficient table for maketables.""" names = make_group_time_names(self.groups, self.times, prefix="ATT") crit = self.critical_value if self.critical_value is not None else None return build_coef_table_with_ci(names, self.att_gt, self.se_gt, alpha=float(self.alpha), critical_values=crit) def __maketables_stat__(self, key: str) -> int | float | str | None: """Return model-level statistics for maketables.""" if key == "N": return int(self.n_units) if self.n_units is not None else None if key == "wald_pvalue": return self.wald_pvalue if key == "se_type": return se_type_label(bool(self.estimation_params.get("bootstrap", False))) if key == "control_group": return control_group_label(self.estimation_params.get("control_group")) return None @property def __maketables_depvar__(self) -> str: """Return dependent variable label for maketables.""" return str(self.estimation_params.get("yname", "ATT(g,t)")) @property def __maketables_fixef_string__(self) -> str | None: """Group-time ATT results do not report fixed-effects 
formulas.""" return None @property def __maketables_vcov_info__(self) -> dict[str, str | None]: """Return variance-covariance metadata.""" return extract_vcov_info(self.estimation_params) @property def __maketables_stat_labels__(self) -> dict[str, str]: """Return custom labels for model-level statistics.""" return {"wald_pvalue": "Pre-trends p-value", "control_group": "Control Group"} @property def __maketables_default_stat_keys__(self) -> list[str]: """Default model-level stats to display in ETable.""" keys = ["N", "se_type", "control_group"] if self.wald_pvalue is not None: keys.insert(1, "wald_pvalue") return keys
[docs] class MPPretestResult(NamedTuple): """Container for pre-test results of conditional parallel trends assumption. Attributes ---------- cvm_stat : float Cramer von Mises test statistic. cvm_boots : ndarray, optional Vector of bootstrapped Cramer von Mises test statistics. cvm_critval : float Cramer von Mises critical value. cvm_pval : float P-value for Cramer von Mises test. ks_stat : float Kolmogorov-Smirnov test statistic. ks_boots : ndarray, optional Vector of bootstrapped Kolmogorov-Smirnov test statistics. ks_critval : float Kolmogorov-Smirnov critical value. ks_pval : float P-value for Kolmogorov-Smirnov test. cluster_vars : list[str], optional Variables that were clustered on for the test. x_formula : str, optional Formula for the X variables used in the test. """ #: Cramer von Mises test statistic. cvm_stat: float #: Bootstrapped Cramer von Mises test statistics. cvm_boots: np.ndarray | None #: Cramer von Mises critical value. cvm_critval: float #: P-value for Cramer von Mises test. cvm_pval: float #: Kolmogorov-Smirnov test statistic. ks_stat: float #: Bootstrapped Kolmogorov-Smirnov test statistics. ks_boots: np.ndarray | None #: Kolmogorov-Smirnov critical value. ks_critval: float #: P-value for Kolmogorov-Smirnov test. ks_pval: float #: Variables that were clustered on for the test. cluster_vars: list[str] | None = None #: Formula for the X variables used in the test. x_formula: str | None = None
def aggte(
    overall_att,
    overall_se,
    aggregation_type="simple",
    event_times=None,
    att_by_event=None,
    se_by_event=None,
    critical_values=None,
    influence_func=None,
    influence_func_overall=None,
    min_event_time=None,
    max_event_time=None,
    balanced_event_threshold=None,
    estimation_params=None,
    call_info=None,
):
    """Create an aggregate treatment effect result object.

    Parameters
    ----------
    overall_att : float
        The estimated overall ATT.
    overall_se : float
        Standard error for overall ATT.
    aggregation_type : {'simple', 'dynamic', 'group', 'calendar'}, default='simple'
        Type of aggregation performed.
    event_times : array-like, optional
        Event/group/time values for disaggregated effects.
    att_by_event : array-like, optional
        ATT estimates for each event time value.
    se_by_event : array-like, optional
        Standard errors for each event time value.
    critical_values : array-like, optional
        Critical values for confidence bands, one per event time value.
    influence_func : ndarray, optional
        Influence function of aggregated parameters.

        - For dynamic/group/calendar: 2D array of shape (n_units, n_events)
        - For simple: 1D array of length n_units
    influence_func_overall : ndarray, optional
        Influence function for the overall ATT (1D array).
    min_event_time : int, optional
        Minimum event time.
    max_event_time : int, optional
        Maximum event time.
    balanced_event_threshold : int, optional
        Balanced event time threshold.
    estimation_params : dict, optional
        DID estimation parameters.
    call_info : dict, optional
        Information about the function call.

    Returns
    -------
    AGGTEResult
        NamedTuple containing aggregated treatment effect parameters. The
        event-level inputs are stored as numpy arrays.

    Raises
    ------
    ValueError
        If ``aggregation_type`` is not recognised, or if an event-level input has a
        different length than ``event_times``.
    """
    if aggregation_type not in ("simple", "dynamic", "group", "calendar"):
        raise ValueError(
            f"Invalid aggregation_type: {aggregation_type}. Must be one of 'simple', 'dynamic', 'group', 'calendar'."
        )

    # Normalise array-likes so stored values match the ``np.ndarray`` field annotations,
    # as ``mp`` does for its group-time arrays.
    event_times, att_by_event, se_by_event, critical_values = (
        None if arr is None else np.asarray(arr)
        for arr in (event_times, att_by_event, se_by_event, critical_values)
    )

    if event_times is not None:
        n_events = len(event_times)
        if att_by_event is not None and len(att_by_event) != n_events:
            raise ValueError("att_by_event must have same length as event_times.")
        if se_by_event is not None and len(se_by_event) != n_events:
            raise ValueError("se_by_event must have same length as event_times.")
        if critical_values is not None and len(critical_values) != n_events:
            raise ValueError("critical_values must have same length as event_times.")

    # Fresh dicts per call so results never share the NamedTuple's default dict.
    if estimation_params is None:
        estimation_params = {}
    if call_info is None:
        call_info = {}

    return AGGTEResult(
        overall_att=overall_att,
        overall_se=overall_se,
        aggregation_type=aggregation_type,
        event_times=event_times,
        att_by_event=att_by_event,
        se_by_event=se_by_event,
        critical_values=critical_values,
        influence_func=influence_func,
        influence_func_overall=influence_func_overall,
        min_event_time=min_event_time,
        max_event_time=max_event_time,
        balanced_event_threshold=balanced_event_threshold,
        estimation_params=estimation_params,
        call_info=call_info,
    )


def mp(
    groups,
    times,
    att_gt,
    vcov_analytical,
    se_gt,
    critical_value,
    influence_func,
    n_units=None,
    wald_stat=None,
    wald_pvalue=None,
    aggregate_effects=None,
    alpha=0.05,
    estimation_params=None,
    G=None,
    weights_ind=None,
):
    """Create a multi-period result object for group-time ATTs.

    Parameters
    ----------
    groups : array-like
        Group indicators (defined by period first treated).
    times : array-like
        Time period indicators.
    att_gt : array-like
        Group-time average treatment effects.
    vcov_analytical : ndarray
        Analytical variance-covariance matrix estimator.
    se_gt : array-like
        Standard errors for group-time ATTs.
    critical_value : float
        Critical value for confidence intervals.
    influence_func : ndarray
        Influence function for group-time ATTs.
    n_units : int, optional
        Number of unique cross-sectional units.
    wald_stat : float, optional
        Wald statistic for common trends test.
    wald_pvalue : float, optional
        P-value for common trends test.
    aggregate_effects : object, optional
        Aggregate treatment effects object.
    alpha : float, default=0.05
        Significance level.
    estimation_params : dict, optional
        DID estimation parameters.
    G : ndarray, optional
        Unit-level group assignments (length n, where n is number of units).
    weights_ind : ndarray, optional
        Unit-level sampling weights (length n, where n is number of units).

    Returns
    -------
    MPResult
        NamedTuple containing multi-period results.

    Raises
    ------
    ValueError
        If ``groups``, ``times``, ``att_gt`` and ``se_gt`` do not all have the same length.
    """
    groups = np.asarray(groups)
    times = np.asarray(times)
    att_gt = np.asarray(att_gt)
    se_gt = np.asarray(se_gt)

    n_gt = len(groups)
    if len(times) != n_gt:
        raise ValueError("groups and times must have the same length.")
    if len(att_gt) != n_gt:
        raise ValueError("att_gt must have same length as groups and times.")
    if len(se_gt) != n_gt:
        raise ValueError("se_gt must have same length as groups and times.")

    if estimation_params is None:
        estimation_params = {}

    return MPResult(
        groups=groups,
        times=times,
        att_gt=att_gt,
        vcov_analytical=vcov_analytical,
        se_gt=se_gt,
        critical_value=critical_value,
        influence_func=influence_func,
        n_units=n_units,
        wald_stat=wald_stat,
        wald_pvalue=wald_pvalue,
        aggregate_effects=aggregate_effects,
        alpha=alpha,
        estimation_params=estimation_params,
        G=G,
        weights_ind=weights_ind,
    )


def mp_pretest(
    cvm_stat,
    cvm_critval,
    cvm_pval,
    ks_stat,
    ks_critval,
    ks_pval,
    cvm_boots=None,
    ks_boots=None,
    cluster_vars=None,
    x_formula=None,
):
    """Create a pre-test result object for conditional parallel trends assumption.

    Parameters
    ----------
    cvm_stat : float
        Cramer von Mises test statistic.
    cvm_critval : float
        Cramer von Mises critical value.
    cvm_pval : float
        P-value for Cramer von Mises test.
    ks_stat : float
        Kolmogorov-Smirnov test statistic.
    ks_critval : float
        Kolmogorov-Smirnov critical value.
    ks_pval : float
        P-value for Kolmogorov-Smirnov test.
    cvm_boots : array-like, optional
        Vector of bootstrapped Cramer von Mises test statistics.
    ks_boots : array-like, optional
        Vector of bootstrapped Kolmogorov-Smirnov test statistics.
    cluster_vars : list[str], optional
        Variables that were clustered on for the test.
    x_formula : str, optional
        Formula for the X variables used in the test.

    Returns
    -------
    MPPretestResult
        NamedTuple containing pre-test results.
    """
    if cvm_boots is not None:
        cvm_boots = np.asarray(cvm_boots)
    if ks_boots is not None:
        ks_boots = np.asarray(ks_boots)

    return MPPretestResult(
        cvm_stat=cvm_stat,
        cvm_boots=cvm_boots,
        cvm_critval=cvm_critval,
        cvm_pval=cvm_pval,
        ks_stat=ks_stat,
        ks_boots=ks_boots,
        ks_critval=ks_critval,
        ks_pval=ks_pval,
        cluster_vars=cluster_vars,
        x_formula=x_formula,
    )


def summary_mp(result):
    """Return a string summary of a multi-period result.

    Nothing is printed; the summary is simply ``str(result)``.

    Parameters
    ----------
    result : MPResult
        The multi-period result to summarize.

    Returns
    -------
    str
        Formatted summary string.
    """
    return str(result)


def summary_mp_pretest(result):
    """Return a string summary of a pre-test result.

    Nothing is printed; the summary is simply ``str(result)``.

    Parameters
    ----------
    result : MPPretestResult
        The pre-test result to summarize.

    Returns
    -------
    str
        Formatted summary string.
    """
    return str(result)