Source code for moderndid.didtriple.ddd

"""Main wrapper for Triple Difference-in-Differences estimation."""

import numpy as np
import polars as pl

from moderndid.core.dataframe import to_polars
from moderndid.core.preprocessing import preprocess_ddd_2periods

from .estimators.ddd_mp import ddd_mp
from .estimators.ddd_mp_rc import ddd_mp_rc
from .estimators.ddd_panel import ddd_panel
from .estimators.ddd_rc import _ddd_rc_2period
from .utils import add_intercept, detect_multiple_periods, detect_rcs_mode, get_covariate_names


[docs] def ddd( data, yname, tname, idname=None, gname=None, pname=None, xformla=None, control_group="nevertreated", base_period="universal", est_method="dr", weightsname=None, boot=False, boot_type="multiplier", biters=1000, cluster=None, alpha=0.05, trim_level=0.995, panel=True, allow_unbalanced_panel=False, random_state=None, n_jobs=1, n_partitions=None, max_cohorts=None, backend=None, ): r"""Compute the doubly robust Triple Difference-in-Differences estimator for the ATT. Implements triple difference-in-differences (DDD) estimation following [1]_. DDD extends standard DiD by incorporating a partition variable :math:`Q` that identifies eligible units within treatment-enabling groups :math:`S`, allowing for violations of traditional DiD parallel trends as long as these violations are stable across groups. Let :math:`S_i` denote the period when treatment is enabled for unit :math:`i`'s group, and :math:`Q_i \in \{0,1\}` indicate eligibility within that group. The group-time average treatment effect measures the effect among eligible units in group :math:`g` at time :math:`t` .. math:: ATT(g,t) = \mathbb{E}[Y_{i,t}(g) - Y_{i,t}(\infty) \mid S_i = g, Q_i = 1]. Identification relies on a DDD conditional parallel trends assumption that allows for differential trends between eligible and ineligible units, provided these differentials are stable across treatment-enabling groups. For groups :math:`g` and :math:`g'` where :math:`g' > \max\{g,t\}` .. math:: &\mathbb{E}[\Delta Y(\infty) \mid S=g, Q=1, X] - \mathbb{E}[\Delta Y(\infty) \mid S=g, Q=0, X] \\ &= \mathbb{E}[\Delta Y(\infty) \mid S=g', Q=1, X] - \mathbb{E}[\Delta Y(\infty) \mid S=g', Q=0, X], where :math:`\Delta Y(\infty) = Y_t(\infty) - Y_{t-1}(\infty)` denotes the change in untreated potential outcomes. This assumption does not impose standard DiD parallel trends within or across groups, making DDD appealing when such assumptions are implausible. Parameters ---------- data : DataFrame Data in long format. Accepts any object implementing the Arrow PyCapsule Interface (``__arrow_c_stream__``), including polars, pandas, pyarrow Table, and cudf DataFrames. yname : str Name of outcome variable column. tname : str Name of time period column. idname : str, optional Name of unit identifier column. Required for panel data. For repeated cross-section data (panel=False), this can be omitted and a row index will be used automatically. gname : str Name of treatment group column. For 2-period data, this should be 0 for never-treated and a positive value for treated units. For multi-period data, this is the first period when treatment is enabled for the unit's group (use 0 or np.inf for never-treated units). pname : str Name of partition/eligibility column (1=eligible, 0=ineligible). This identifies which units within a treatment group are actually eligible to receive the treatment effect. xformla : str, optional Formula for covariates in the form "~ x1 + x2 + x3". If None, only an intercept is used. control_group : {"nevertreated", "notyettreated"}, default="nevertreated" Which units to use as controls in multi-period settings. This parameter is ignored for 2-period data. base_period : {"universal", "varying"}, default="universal" Base period selection for multi-period settings. This parameter is ignored for 2-period data. est_method : {"dr", "reg", "ipw"}, default="dr" Estimation method: doubly robust, regression, or IPW. weightsname : str, optional Name of the column containing observation weights. boot : bool, default=False Whether to use bootstrap for inference. boot_type : {"multiplier", "weighted"}, default="multiplier" Type of bootstrap for 2-period data (only used if boot=True). Multi-period data always uses multiplier bootstrap. biters : int, default=1000 Number of bootstrap repetitions (only used if boot=True). cluster : str, optional Name of the clustering variable for clustered standard errors. Currently only supported for 2-period data with bootstrap. alpha : float, default=0.05 Significance level for confidence intervals. trim_level : float, default=0.995 Trimming level for propensity scores. Only used for repeated cross-section data (panel=False). panel : bool, default=True Whether the data is panel data (True) or repeated cross-section data (False). Panel data has the same units observed across time periods. Repeated cross-section data has different samples in each period. allow_unbalanced_panel : bool, default=False If True and panel=True, allows unbalanced panel data. For multi-period settings, estimation stays in panel mode (preserving panel efficiency) while handling units that appear in different subsets of periods. For 2-period settings, unbalanced data falls back to repeated cross-section mode. If the panel is unbalanced and this is False, an error will be raised. random_state : int, Generator, optional Random seed for reproducibility of bootstrap. n_jobs : int, default=1 Number of parallel jobs for group-time estimation in multi-period settings. 1 = sequential (default), -1 = all cores, >1 = that many workers. Ignored for 2-period data. n_partitions : int or None, default=None Number of Dask partitions per cell. Only used when ``data`` is a Dask DataFrame; ignored for non-Dask inputs. max_cohorts : int or None, default=None Maximum number of treatment cohorts to process in parallel when using the Dask distributed backend. Each cohort's group-time cells are computed concurrently within a thread, so this controls how many cohorts share the cluster simultaneously. Higher values increase throughput but require more memory on workers to hold the per-cohort wide-pivoted DataFrames. When ``None``, defaults to the number of Dask workers. Ignored for non-Dask inputs. For best performance, set this equal to the total number of treatment cohorts so that all cohorts run concurrently. Reduce the value if the cluster runs out of memory. backend : {"numpy", "cupy"} or None, default=None Array backend to use for this call only. When set, the backend is activated for the duration of this call and reverted automatically when the call returns. ``None`` (the default) uses whatever backend is currently active (see :func:`~moderndid.set_backend`). Ignored when ``data`` is a Dask DataFrame. Returns ------- DDDPanelResult, DDDRCResult, DDDMultiPeriodResult, or DDDMultiPeriodRCResult For 2-period panel data (panel=True), returns DDDPanelResult containing: - **att**: The DDD point estimate - **se**: Standard error - **uci**, **lci**: Confidence interval bounds - **boots**: Bootstrap draws (if requested) - **att_inf_func**: Influence function - **did_atts**: Individual DiD ATT estimates - **subgroup_counts**: Number of units per subgroup - **args**: Estimation arguments For 2-period repeated cross-section data (panel=False), returns DDDRCResult with the same structure. For multi-period panel data, returns DDDMultiPeriodResult containing: - **att**: Array of ATT(g,t) point estimates - **se**: Standard errors for each ATT(g,t) - **uci**, **lci**: Confidence interval bounds - **groups**, **times**: Treatment cohort and time for each estimate - **glist**, **tlist**: Unique cohorts and periods - **inf_func_mat**: Influence function matrix - **n**: Number of units - **args**: Estimation arguments For multi-period repeated cross-section data, returns DDDMultiPeriodRCResult with the same structure. Examples -------- We can generate synthetic data for a 2-period DDD setup using the ``gen_ddd_2periods`` function. The data contains treatment status (``state``), eligibility within treatment groups (``partition``), and covariates. .. ipython:: In [1]: import numpy as np ...: from moderndid import ddd, gen_ddd_2periods ...: ...: dgp = gen_ddd_2periods(n=1000, dgp_type=1, random_state=42) ...: df = dgp["data"] ...: df.head() Now we can compute the DDD estimate using the doubly robust estimator. The ``pname`` parameter identifies which units within a treatment group are eligible to receive treatment, which is the key distinction from standard DiD. .. ipython:: :okwarning: In [2]: result = ddd( ...: data=df, ...: yname="y", ...: tname="time", ...: idname="id", ...: gname="state", ...: pname="partition", ...: xformla="~ cov1 + cov2 + cov3 + cov4", ...: est_method="dr", ...: ) ...: result The function automatically detects multi-period data with staggered treatment adoption. When there are more than two time periods or treatment cohorts, it returns group-time ATT estimates that can be aggregated using ``agg_ddd``. .. ipython:: :okwarning: In [3]: from moderndid import gen_ddd_mult_periods ...: ...: dgp_mp = gen_ddd_mult_periods(n=500, dgp_type=1, random_state=42) ...: result_mp = ddd( ...: data=dgp_mp["data"], ...: yname="y", ...: tname="time", ...: idname="id", ...: gname="group", ...: pname="partition", ...: control_group="nevertreated", ...: base_period="varying", ...: est_method="dr", ...: ) ...: result_mp The function also supports repeated cross-section data where different units are sampled in each time period. Set ``panel=False`` to use this mode. .. ipython:: :okwarning: In [4]: dgp_rcs = gen_ddd_2periods(n=2000, dgp_type=1, panel=False, random_state=42) ...: result_rcs = ddd( ...: data=dgp_rcs["data"], ...: yname="y", ...: tname="time", ...: gname="state", ...: pname="partition", ...: xformla="~ cov1 + cov2 + cov3 + cov4", ...: est_method="dr", ...: panel=False, ...: ) ...: result_rcs For multi-period repeated cross-section data with staggered treatment adoption, set ``panel=False`` with multiple time periods. .. ipython:: :okwarning: In [5]: dgp_mp_rcs = gen_ddd_mult_periods(n=500, dgp_type=1, panel=False, random_state=42) ...: result_mp_rcs = ddd( ...: data=dgp_mp_rcs["data"], ...: yname="y", ...: tname="time", ...: gname="group", ...: pname="partition", ...: control_group="notyettreated", ...: base_period="universal", ...: est_method="dr", ...: panel=False, ...: ) ...: result_mp_rcs Notes ----- The DDD estimator identifies treatment effects in settings where units must satisfy two criteria to be treated: belonging to a group that enables treatment (e.g., a state that passes a policy) and being in an eligible partition (e.g., women eligible for maternity benefits). This allows for violations of standard DiD parallel trends assumptions, as long as these violations are stable across groups. When ``est_method="dr"`` (the default), the function implements doubly robust DDD estimators that combine outcome regression and inverse probability weighting. These estimators are consistent if either the outcome model or the propensity score model is correctly specified. See Also -------- ddd_panel : Two-period DDD estimator for panel data. ddd_rc : Two-period DDD estimator for repeated cross-section data. ddd_mp : Multi-period DDD estimator for staggered adoption with panel data. ddd_mp_rc : Multi-period DDD estimator for staggered adoption with RCS data. agg_ddd : Aggregate group-time DDD effects. References ---------- .. [1] Ortiz-Villavicencio, M., & Sant'Anna, P. H. C. (2025). *Better Understanding Triple Differences Estimators.* arXiv preprint arXiv:2505.09942. https://arxiv.org/abs/2505.09942 """ if backend is not None: from moderndid.cupy.backend import use_backend with use_backend(backend): return ddd( data=data, yname=yname, tname=tname, idname=idname, gname=gname, pname=pname, xformla=xformla, control_group=control_group, base_period=base_period, est_method=est_method, weightsname=weightsname, boot=boot, boot_type=boot_type, biters=biters, cluster=cluster, alpha=alpha, trim_level=trim_level, panel=panel, allow_unbalanced_panel=allow_unbalanced_panel, random_state=random_state, n_jobs=n_jobs, n_partitions=n_partitions, max_cohorts=max_cohorts, backend=None, ) from moderndid.dask._utils import is_dask_collection if is_dask_collection(data): from moderndid.dask._ddd import dask_ddd return dask_ddd( data, yname, tname, idname, gname, pname, xformla, control_group=control_group, base_period=base_period, est_method=est_method, weightsname=weightsname, boot=boot, biters=biters, cluster=cluster, alpha=alpha, trim_level=trim_level, panel=panel, allow_unbalanced_panel=allow_unbalanced_panel, random_state=random_state, n_partitions=n_partitions, max_cohorts=max_cohorts, backend=backend, ) from moderndid.spark._utils import is_spark_dataframe if is_spark_dataframe(data): from moderndid.spark._ddd import spark_ddd return spark_ddd( data, yname, tname, idname, gname, pname, xformla, control_group=control_group, base_period=base_period, est_method=est_method, weightsname=weightsname, boot=boot, biters=biters, cluster=cluster, alpha=alpha, trim_level=trim_level, panel=panel, allow_unbalanced_panel=allow_unbalanced_panel, random_state=random_state, n_partitions=n_partitions, max_cohorts=max_cohorts, backend=backend, ) if gname is None: raise ValueError("gname is required. Please specify the treatment group column.") if pname is None: raise ValueError("pname is required. Please specify the partition/eligibility column.") if panel and idname is None: raise ValueError("idname must be provided when panel=True.") if est_method not in ("dr", "reg", "ipw"): raise ValueError(f"est_method='{est_method}' is not valid. Must be 'dr', 'reg', or 'ipw'.") if control_group not in ("nevertreated", "notyettreated"): raise ValueError(f"control_group='{control_group}' is not valid. Must be 'nevertreated' or 'notyettreated'.") if base_period not in ("universal", "varying"): raise ValueError(f"base_period='{base_period}' is not valid. Must be 'universal' or 'varying'.") if not 0 < alpha < 1: raise ValueError(f"alpha={alpha} is not valid. Must be between 0 and 1 (exclusive).") if not isinstance(biters, int) or biters < 1: raise ValueError(f"biters={biters} is not valid. Must be a positive integer.") if boot_type not in ("weighted", "multiplier"): raise ValueError(f"boot_type='{boot_type}' is not valid. Must be 'weighted' or 'multiplier'.") if not 0 < trim_level < 1: raise ValueError(f"trim_level={trim_level} is not valid. Must be between 0 and 1 (exclusive).") if not isinstance(n_jobs, int) or (n_jobs < 1 and n_jobs != -1): raise ValueError(f"n_jobs={n_jobs} is not valid. Must be a positive integer or -1 for all cores.") is_rcs = detect_rcs_mode(data, tname, idname, panel, allow_unbalanced_panel) data = to_polars(data) if is_rcs and idname is None: data = data.with_columns(pl.Series("_row_id", np.arange(len(data)))) idname = "_row_id" multiple_periods = detect_multiple_periods(data, tname, gname) if multiple_periods: covariate_cols = get_covariate_names(xformla) if covariate_cols is not None: missing_covs = [c for c in covariate_cols if c not in data.columns] if missing_covs: raise ValueError(f"Covariates not found in data: {missing_covs}") use_panel = panel and idname is not None and idname != "_row_id" if not use_panel: return ddd_mp_rc( data=data, y_col=yname, time_col=tname, id_col=idname, group_col=gname, partition_col=pname, covariate_cols=covariate_cols, control_group=control_group, base_period=base_period, est_method=est_method, boot=boot, biters=biters, cband=False, cluster=cluster, alpha=alpha, trim_level=trim_level, random_state=random_state, n_jobs=n_jobs, ) return ddd_mp( data=data, y_col=yname, time_col=tname, id_col=idname, group_col=gname, partition_col=pname, covariate_cols=covariate_cols, control_group=control_group, base_period=base_period, est_method=est_method, boot=boot, biters=biters, cband=False, cluster=cluster, alpha=alpha, random_state=random_state, n_jobs=n_jobs, ) if is_rcs: return _ddd_rc_2period( data=data, yname=yname, tname=tname, gname=gname, pname=pname, xformla=xformla, weightsname=weightsname, est_method=est_method, boot=boot, boot_type=boot_type, biters=biters, alpha=alpha, trim_level=trim_level, random_state=random_state, ) ddd_data = preprocess_ddd_2periods( data=data, yname=yname, tname=tname, idname=idname, gname=gname, pname=pname, xformla=xformla, est_method=est_method, weightsname=weightsname, boot=boot, boot_type=boot_type, n_boot=biters, cluster=cluster, alp=alpha, inf_func=True, ) covariates_with_intercept = add_intercept(ddd_data.covariates) return ddd_panel( y1=ddd_data.y1, y0=ddd_data.y0, subgroup=ddd_data.subgroup, covariates=covariates_with_intercept, i_weights=ddd_data.weights, est_method=est_method, boot=ddd_data.config.boot, boot_type=ddd_data.config.boot_type.value, biters=ddd_data.config.n_boot, influence_func=True, alpha=ddd_data.config.alp, random_state=random_state, )