"""Doubly robust DDD estimator for multi-period repeated cross-section data with staggered adoption."""
from __future__ import annotations
import warnings
import numpy as np
import polars as pl
from scipy import stats
from moderndid.core.dataframe import to_polars
from moderndid.core.parallel import parallel_map
from ..bootstrap.mboot_ddd import mboot_ddd
from ..container import ATTgtRCResult, DDDMultiPeriodRCResult
from .ddd_mp import _gmm_aggregate
from .ddd_rc import ddd_rc
[docs]
def ddd_mp_rc(
    data,
    y_col,
    time_col,
    id_col,
    group_col,
    partition_col,
    covariate_cols=None,
    control_group="nevertreated",
    base_period="universal",
    est_method="dr",
    boot=False,
    biters=1000,
    cband=False,
    cluster=None,
    alpha=0.05,
    trim_level=0.995,
    random_state=None,
    n_jobs=1,
):
    r"""Compute the multi-period doubly robust DDD estimator for the ATT with repeated cross-section data.

    Implements the multi-period triple difference-in-differences estimator from [1]_
    for repeated cross-section data with staggered treatment adoption. Unlike panel
    data, different samples are observed in each period.

    The target parameters are the group-time average treatment effects

    .. math::

        ATT(g, t) = \mathbb{E}[Y_t(g) - Y_t(\infty) \mid S=g, Q=1]

    for all treatment cohorts :math:`g \in \mathcal{G}_{\mathrm{trt}}` and time periods
    :math:`t \in \{2, \ldots, T\}` such that :math:`t \geq g`.

    For each :math:`(g, t)` cell, the estimator compares outcomes at time :math:`t`
    to a base period. With ``base_period="universal"``, all comparisons use period
    :math:`g-1` (the last pre-treatment period for cohort :math:`g`). With
    ``base_period="varying"``, each comparison uses period :math:`t-1`.

    For repeated cross-sections, the estimator follows the approach of [2]_,
    extending the DDD framework from [1]_. Unlike panel data where outcomes are
    differenced within units, RCS fits separate outcome regression
    models for the target period :math:`t` and the base period for each subgroup.

    When multiple comparison groups are available (not-yet-treated setting), the
    estimator combines them using optimal GMM weights (Equation 4.11 from [1]_)

    .. math::

        \widehat{w}_{\mathrm{gmm}}^{g,t} = \frac{\widehat{\Omega}_{g,t}^{-1} \mathbf{1}}
        {\mathbf{1}' \widehat{\Omega}_{g,t}^{-1} \mathbf{1}}

    where :math:`\widehat{\Omega}_{g,t}` is the covariance matrix of
    :math:`\widehat{ATT}_{\mathrm{dr},g_c}(g,t)` across comparison groups. The GMM
    estimator (Equation 4.12 from [1]_) is then

    .. math::

        \widehat{ATT}_{\mathrm{dr,gmm}}(g,t) = \frac{\mathbf{1}' \widehat{\Omega}_{g,t}^{-1}}
        {\mathbf{1}' \widehat{\Omega}_{g,t}^{-1} \mathbf{1}}
        \widehat{ATT}_{\mathrm{dr}}(g,t).

    Parameters
    ----------
    data : DataFrame
        Repeated cross-section data in long format with columns for outcome, time,
        observation id, treatment group, and partition.
    y_col : str
        Name of the outcome variable column.
    time_col : str
        Name of the time period column.
    id_col : str
        Name of the observation identifier column. For RCS, this can be a row index
        since units are not tracked across periods.
    group_col : str
        Name of the treatment group column (first period when treatment enabled).
        Use 0 or np.inf for never-treated units.
    partition_col : str
        Name of the partition/eligibility column (1 = eligible, 0 = ineligible).
    covariate_cols : list of str or None, default None
        Names of covariate columns in the data. If None, uses intercept only.
    control_group : {"nevertreated", "notyettreated"}, default "nevertreated"
        Which units to use as controls. With "notyettreated", multiple comparison
        groups may be available, triggering GMM aggregation.
    base_period : {"universal", "varying"}, default "universal"
        Base period selection. "universal" uses period g-1 as baseline for all
        comparisons; "varying" uses period t-1 for each t.
    est_method : {"dr", "reg", "ipw"}, default "dr"
        Estimation method for each 2-period comparison.
    boot : bool, default False
        Whether to use multiplier bootstrap for inference.
    biters : int, default 1000
        Number of bootstrap repetitions (only used if boot=True).
    cband : bool, default False
        Whether to compute uniform confidence bands (only used if boot=True).
    cluster : str or None, default None
        Name of the column containing cluster identifiers for clustered
        standard errors. If provided, the bootstrap resamples at the cluster
        level (only used if boot=True).
    alpha : float, default 0.05
        Significance level for confidence intervals.
    trim_level : float, default 0.995
        Trimming level for propensity scores.
    random_state : int, Generator, or None, default None
        Controls random number generation for bootstrap reproducibility.
    n_jobs : int, default=1
        Number of parallel jobs for group-time estimation. 1 = sequential
        (default), -1 = all cores, >1 = that many workers.

    Returns
    -------
    DDDMultiPeriodRCResult
        A NamedTuple containing:

        - att: Array of ATT(g,t) point estimates
        - se: Standard errors for each ATT(g,t)
        - uci, lci: Confidence interval bounds
        - groups: Treatment cohort for each estimate
        - times: Time period for each estimate
        - glist, tlist: Unique cohorts and periods
        - inf_func_mat: Influence function matrix (n_obs x k)
        - n: Number of observations
        - args: Estimation arguments
        - unit_groups: Value of ``group_col`` for every observation

    Raises
    ------
    ValueError
        If no (g,t) cell produced an estimate.

    See Also
    --------
    ddd_rc : Two-period DDD estimator for repeated cross-section data.
    ddd_mp : Multi-period DDD estimator for panel data.

    References
    ----------

    .. [1] Ortiz-Villavicencio, M., & Sant'Anna, P. H. C. (2025).
           *Better Understanding Triple Differences Estimators.*
           arXiv preprint arXiv:2505.09942. https://arxiv.org/abs/2505.09942

    .. [2] Sant'Anna, P. H. C., & Zhao, J. (2020).
           *Doubly robust difference-in-differences estimators.*
           Journal of Econometrics, 219(1), 101-122.
           https://doi.org/10.1016/j.jeconom.2020.06.003
    """
    data = to_polars(data)

    tlist = np.sort(data[time_col].unique().to_numpy())
    glist_raw = data[group_col].unique().to_numpy()
    # Treated cohorts only: 0 and non-finite codes denote never-treated units.
    glist = np.sort([g for g in glist_raw if g > 0 and np.isfinite(g)])

    n_obs = len(data)
    n_periods = len(tlist)
    n_cohorts = len(glist)

    # With a varying base period the first period only ever serves as a base.
    tfac = 0 if base_period == "universal" else 1
    tlist_length = n_periods - tfac

    inf_func_mat = np.zeros((n_obs, n_cohorts * tlist_length))
    se_array = np.full(n_cohorts * tlist_length, np.nan)

    # Row position of every observation, used to scatter cell-level influence
    # functions back into the full-sample matrix.
    data_with_idx = data.with_columns(pl.Series("_obs_idx", np.arange(len(data))))

    args_list = [
        (
            data_with_idx,
            g,
            tlist[t_idx + tfac],
            t_idx,
            tlist,
            base_period,
            control_group,
            y_col,
            time_col,
            id_col,
            group_col,
            partition_col,
            covariate_cols,
            est_method,
            trim_level,
            n_obs,
        )
        for g in glist
        for t_idx in range(tlist_length)
    ]

    cell_results = parallel_map(_process_gt_cell_rc, args_list, n_jobs=n_jobs)

    attgt_list = []
    for result in cell_results:
        if result is None:
            continue
        att_entry, inf_data, se_val = result
        if att_entry is None:
            continue
        # Column k of the influence matrix (and slot k of se_array) must describe
        # attgt_list[k]. Indexing by the raw cell counter would leave gaps for
        # skipped cells, and those gaps would misalign everything once the matrix
        # is trimmed to its first len(attgt_list) columns.
        col = len(attgt_list)
        attgt_list.append(att_entry)
        if inf_data is not None:
            inf_func_scaled, obs_indices = inf_data
            _update_inf_func_matrix_rc(inf_func_mat, inf_func_scaled, obs_indices, col)
        if se_val is not None:
            se_array[col] = se_val

    if len(attgt_list) == 0:
        raise ValueError("No valid (g,t) cells found.")

    att_array = np.array([r.att for r in attgt_list])
    groups_array = np.array([r.group for r in attgt_list])
    times_array = np.array([r.time for r in attgt_list])

    inf_func_trimmed = inf_func_mat[:, : len(attgt_list)]

    cluster_vals = None
    if cluster is not None:
        cluster_vals = data_with_idx[cluster].to_numpy()

    normal_cv = stats.norm.ppf(1 - alpha / 2)
    if boot:
        boot_result = mboot_ddd(
            inf_func=inf_func_trimmed,
            biters=biters,
            alpha=alpha,
            cluster=cluster_vals,
            random_state=random_state,
        )
        se_computed = boot_result.se.copy()
        use_band = cband and np.isfinite(boot_result.crit_val)
        cv = boot_result.crit_val if use_band else normal_cv
    else:
        V = inf_func_trimmed.T @ inf_func_trimmed / n_obs
        se_computed = np.sqrt(np.diag(V) / n_obs)
        cv = normal_cv

    # Cells that returned their own standard error (the GMM-aggregated ones)
    # keep it instead of the influence-function/bootstrap value.
    cell_se = se_array[: len(se_computed)]
    has_cell_se = ~np.isnan(cell_se)
    se_computed[has_cell_se] = cell_se[has_cell_se]
    # Numerically zero standard errors are reported as missing.
    se_computed[se_computed <= np.sqrt(np.finfo(float).eps) * 10] = np.nan

    uci = att_array + cv * se_computed
    lci = att_array - cv * se_computed

    args = {
        "panel": False,
        "yname": y_col,
        "pname": partition_col,
        "control_group": control_group,
        "base_period": base_period,
        "est_method": est_method,
        "boot": boot,
        "biters": biters if boot else None,
        "cband": cband if boot else None,
        "cluster": cluster,
        "alpha": alpha,
        "trim_level": trim_level,
    }

    obs_groups = data[group_col].to_numpy()

    return DDDMultiPeriodRCResult(
        att=att_array,
        se=se_computed,
        uci=uci,
        lci=lci,
        groups=groups_array,
        times=times_array,
        glist=glist,
        tlist=tlist,
        inf_func_mat=inf_func_trimmed,
        n=n_obs,
        args=args,
        unit_groups=obs_groups,
    )
def _process_gt_cell_rc(
    data,
    g,
    t,
    t_idx,
    tlist,
    base_period,
    control_group,
    y_col,
    time_col,
    _id_col,
    group_col,
    partition_col,
    covariate_cols,
    est_method,
    trim_level,
    n_obs,
):
    """Estimate a single (g,t) cell for repeated cross-section data.

    Returns
    -------
    tuple or None
        ``(ATTgtRCResult, (inf_func_scaled, obs_indices) or None, se or None)``.
        ``None`` is returned when the cell has to be skipped altogether.
    """
    base = _get_base_period_rc(g, t_idx, tlist, base_period)
    if base is None:
        warnings.warn(f"No pre-treatment periods for group {g}. Skipping.", UserWarning)
        return None

    is_post = int(g <= t)
    if is_post:
        # Post-treatment cells always compare against the last period before g.
        earlier = tlist[tlist < g]
        if len(earlier) == 0:
            return None
        base = earlier[-1]

    # Under a universal base period the base cell itself is a zero by construction.
    if base_period == "universal" and base == t:
        return (ATTgtRCResult(att=0.0, group=int(g), time=int(t), post=0), None, None)

    cell_data, controls = _get_cell_data_rc(data, g, t, base, control_group, time_col, group_col)
    if cell_data is None or len(controls) == 0:
        return None

    n_cell = len(cell_data)
    shared = (
        y_col,
        time_col,
        group_col,
        partition_col,
        g,
        t,
        base,
        covariate_cols,
        est_method,
        trim_level,
        n_obs,
        n_cell,
    )

    if len(controls) == 1:
        att, inf_scaled, obs_idx = _process_single_control_rc(cell_data, *shared)
        se = None
    else:
        # Several comparison cohorts: the helper aggregates them via GMM and
        # supplies its own standard error.
        att, inf_scaled, obs_idx, se = _process_multiple_controls_rc(cell_data, controls, *shared)

    if att is None:
        return None

    return (
        ATTgtRCResult(att=att, group=int(g), time=int(t), post=is_post),
        (inf_scaled, obs_idx),
        se,
    )
def _get_base_period_rc(g, t_idx, tlist, base_period):
    """Return the comparison (base) period for a cell, or None if there is none.

    For ``base_period == "universal"`` this is the last entry of ``tlist`` that is
    strictly below ``g``; for any other setting it is ``tlist[t_idx]``.
    """
    if base_period != "universal":
        return tlist[t_idx]

    earlier = tlist[tlist < g]
    return earlier[-1] if len(earlier) else None
def _get_cell_data_rc(data, g, t, pret, control_group, time_col, group_col):
    """Select the rows of one (g,t) cell and list its comparison cohorts (RCS).

    Keeps observations from periods ``t`` and ``pret`` that belong either to
    cohort ``g`` or to an admissible control cohort. Returns
    ``(cell_data, available_controls)``, or ``(None, [])`` when no row qualifies.
    """
    cohort = pl.col(group_col)
    never_treated = (cohort == 0) | (~cohort.is_finite())

    if control_group == "nevertreated":
        is_control = never_treated
    else:
        # Not-yet-treated: additionally admit cohorts first treated after both periods.
        latest = max(t, pret)
        is_control = (never_treated | (cohort > latest)) & (cohort != g)

    in_cell = (cohort == g) | is_control
    in_window = pl.col(time_col).is_in([t, pret])

    cell_data = data.filter(in_cell & in_window)
    if len(cell_data) == 0:
        return None, []

    other_cohorts = cell_data.filter(~cohort.is_in([g]))[group_col].unique().to_list()
    return cell_data, [c for c in other_cohorts if c != g]
def _update_inf_func_matrix_rc(inf_func_mat, inf_func_scaled, obs_indices, counter):
    """Write one cell's influence function into column ``counter`` of the matrix.

    Parameters
    ----------
    inf_func_mat : ndarray
        Two-dimensional matrix that is modified in place.
    inf_func_scaled : array-like
        Influence-function values; entry ``i`` belongs to row ``obs_indices[i]``.
    obs_indices : array-like of int
        Row positions in ``inf_func_mat`` that receive the values.
    counter : int
        Column to write to.

    Notes
    -----
    Only the first ``min(len(obs_indices), len(inf_func_scaled))`` pairs are
    written; surplus entries on either side are ignored. Rows that are not
    listed keep their current value.
    """
    rows = np.asarray(obs_indices, dtype=np.intp)
    values = np.asarray(inf_func_scaled)

    n_pairs = min(len(rows), len(values))
    if n_pairs == 0:
        return

    # A single indexed assignment replaces the per-element Python loop. For
    # repeated row indices NumPy keeps the last value, as the loop did.
    inf_func_mat[rows[:n_pairs], counter] = values[:n_pairs]
def _process_single_control_rc(
    cell_data,
    y_col,
    time_col,
    group_col,
    partition_col,
    g,
    t,
    pret,
    covariate_cols,
    est_method,
    trim_level,
    n_obs,
    n_cell,
):
    """Estimate a (g,t) cell that has exactly one comparison cohort (RCS).

    Returns ``(att, inf_func_scaled, obs_indices)``, where the influence function
    is multiplied by ``n_obs / n_cell``, or ``(None, None, None)`` when the
    underlying estimate is unavailable.
    """
    estimate = _compute_single_ddd_rc(
        cell_data,
        y_col,
        time_col,
        group_col,
        partition_col,
        g,
        t,
        pret,
        covariate_cols,
        est_method,
        trim_level,
    )
    att, raw_inf, obs_idx = estimate
    if att is None:
        return None, None, None

    # Rescale from the cell sample to the full sample.
    rescaled = (n_obs / n_cell) * raw_inf
    return att, rescaled, obs_idx
def _process_multiple_controls_rc(
    cell_data,
    available_controls,
    y_col,
    time_col,
    group_col,
    partition_col,
    g,
    t,
    pret,
    covariate_cols,
    est_method,
    trim_level,
    n_obs,
    n_cell,
):
    """Estimate a (g,t) cell with several comparison cohorts and combine them via GMM (RCS).

    One DDD estimate is computed per comparison cohort on the rows of cohort
    ``g`` plus that cohort. Each influence function is rescaled by
    ``n_cell / n_subset`` and zero-padded to the length of the cell before being
    handed to ``_gmm_aggregate``. Returns
    ``(att_gmm, inf_func_scaled, cell_obs_indices, se_gmm)`` or four ``None``
    values if no comparison cohort yielded an estimate.
    """
    cell_obs_indices = cell_data["_obs_idx"].to_numpy()
    # Map a global observation index to its position inside this cell.
    position_of = dict(zip(cell_obs_indices, range(len(cell_obs_indices))))

    estimates = []
    padded_infs = []
    for ctrl in available_controls:
        pair = cell_data.filter((pl.col(group_col) == g) | (pl.col(group_col) == ctrl))

        att, raw_inf, pair_obs = _compute_single_ddd_rc(
            pair, y_col, time_col, group_col, partition_col, g, t, pret, covariate_cols, est_method, trim_level
        )
        if att is None:
            continue

        rescaled = (n_cell / len(pair)) * raw_inf

        padded = np.zeros(n_cell)
        # zip stops at the shorter sequence, so surplus indices are ignored.
        for obs, value in zip(pair_obs, rescaled):
            pos = position_of.get(obs)
            if pos is not None:
                padded[pos] = value

        estimates.append(att)
        padded_infs.append(padded)

    if not estimates:
        return None, None, None, None

    att_gmm, if_gmm, se_gmm = _gmm_aggregate(np.array(estimates), np.column_stack(padded_infs), n_obs)
    return att_gmm, (n_obs / n_cell) * if_gmm, cell_obs_indices, se_gmm
def _compute_single_ddd_rc(
    cell_data, y_col, time_col, group_col, partition_col, g, t, _pret, covariate_cols, est_method, trim_level
):
    """Compute DDD for a single (g,t) cell with a single control group using RCS.

    Rows are classified into four subgroups before calling ``ddd_rc``:

    - 4: cohort ``g`` with ``partition_col == 1``
    - 3: cohort ``g`` with ``partition_col == 0``
    - 2: other cohort with ``partition_col == 1``
    - 1: other cohort with ``partition_col == 0``

    A row counts as post-treatment when ``time_col == t``; ``_pret`` is accepted
    for signature compatibility and is not used.

    Returns
    -------
    tuple
        ``(att, att_inf_func, obs_indices)``, or ``(None, None, None)`` when the
        cell has no subgroup-4 rows or ``ddd_rc`` raises ``ValueError`` /
        ``numpy.linalg.LinAlgError``.
    """
    treat = (pl.col(group_col) == g).cast(pl.Int64)
    eligible = pl.col(partition_col)
    subgroup_expr = (
        4 * (treat == 1).cast(pl.Int64) * (eligible == 1).cast(pl.Int64)
        + 3 * (treat == 1).cast(pl.Int64) * (eligible == 0).cast(pl.Int64)
        + 2 * (treat == 0).cast(pl.Int64) * (eligible == 1).cast(pl.Int64)
        + 1 * (treat == 0).cast(pl.Int64) * (eligible == 0).cast(pl.Int64)
    )
    post_expr = (pl.col(time_col) == t).cast(pl.Int64)

    # Evaluate the derived indicators in a separate frame. Attaching them to
    # cell_data under fixed names ("treat", "subgroup", "_post") would silently
    # replace same-named user columns (e.g. a covariate called "treat") before
    # the outcome, partition and covariates are read.
    derived = cell_data.select([subgroup_expr.alias("subgroup"), post_expr.alias("post")])
    subgroup = derived["subgroup"].to_numpy()
    post = derived["post"].to_numpy()

    y = cell_data[y_col].to_numpy()
    obs_indices = cell_data["_obs_idx"].to_numpy()

    # Nothing to estimate without rows from cohort g with partition_col == 1.
    if not np.any(subgroup == 4):
        return None, None, None

    intercept = np.ones((len(y), 1))
    if covariate_cols is None:
        X = intercept
    else:
        X = np.hstack([intercept, cell_data.select(covariate_cols).to_numpy()])

    try:
        result = ddd_rc(
            y=y,
            post=post,
            subgroup=subgroup,
            covariates=X,
            est_method=est_method,
            trim_level=trim_level,
            influence_func=True,
        )
    except (ValueError, np.linalg.LinAlgError):
        # An estimation failure in one cell must not abort the whole run.
        return None, None, None

    return result.att, result.att_inf_func, obs_indices