"""Doubly robust DDD estimator for multi-period panel data with staggered adoption."""
from __future__ import annotations
import warnings
import numpy as np
import polars as pl
from scipy import stats
from moderndid.core.dataframe import to_polars
from moderndid.core.parallel import parallel_map
from ..bootstrap.mboot_ddd import mboot_ddd
from ..container import ATTgtResult, DDDMultiPeriodResult
from .ddd_panel import ddd_panel
[docs]
def ddd_mp(
    data,
    y_col,
    time_col,
    id_col,
    group_col,
    partition_col,
    covariate_cols=None,
    control_group="nevertreated",
    base_period="universal",
    est_method="dr",
    boot=False,
    biters=1000,
    cband=False,
    cluster=None,
    alpha=0.05,
    random_state=None,
    n_jobs=1,
):
    r"""Compute the multi-period doubly robust DDD estimator for the ATT with panel data.

    Implements the multi-period triple difference-in-differences estimator from [1]_.
    The target parameters are the group-time average treatment effects

    .. math::
        ATT(g, t) = \mathbb{E}[Y_t(g) - Y_t(\infty) \mid S=g, Q=1]

    for all treatment cohorts :math:`g \in \mathcal{G}_{trt}` and time periods
    :math:`t \in \{2, \ldots, T\}` such that :math:`t \geq g`.

    For each (g,t) cell with comparison group :math:`g_{\mathrm{c}}`, the doubly robust
    estimand (Equation 4.8 from [1]_) is

    .. math::
        \widehat{ATT}_{\mathrm{dr},g_{\mathrm{c}}}(g,t) &= \mathbb{E}_n\left[
        \left(\widehat{w}_{\mathrm{trt}}^{S=g,Q=1}(S,Q)
        - \widehat{w}_{g,0}^{S=g,Q=1}(S,Q,X)\right)
        \left(Y_t - Y_{g-1} - \widehat{m}_{Y_t-Y_{g-1}}^{S=g,Q=0}(X)\right)\right] \\
        &+ \mathbb{E}_n\left[
        \left(\widehat{w}_{\mathrm{trt}}^{S=g,Q=1}(S,Q)
        - \widehat{w}_{g_{\mathrm{c}},1}^{S=g,Q=1}(S,Q,X)\right)
        \left(Y_t - Y_{g-1} - \widehat{m}_{Y_t-Y_{g-1}}^{S=g_{\mathrm{c}},Q=1}(X)\right)\right] \\
        &- \mathbb{E}_n\left[
        \left(\widehat{w}_{\mathrm{trt}}^{S=g,Q=1}(S,Q)
        - \widehat{w}_{g_{\mathrm{c}},0}^{S=g,Q=1}(S,Q,X)\right)
        \left(Y_t - Y_{g-1} - \widehat{m}_{Y_t-Y_{g-1}}^{S=g_{\mathrm{c}},Q=0}(X)\right)\right].

    When multiple comparison groups are available (not-yet-treated setting), the
    estimator combines them using optimal GMM weights (Equation 4.11 from [1]_)

    .. math::
        \widehat{w}_{gmm}^{g,t} = \frac{\widehat{\Omega}_{g,t}^{-1} \mathbf{1}}
        {\mathbf{1}' \widehat{\Omega}_{g,t}^{-1} \mathbf{1}}

    where :math:`\widehat{\Omega}_{g,t}` is the covariance matrix of
    :math:`\widehat{ATT}_{dr,g_c}(g,t)` across comparison groups. The GMM
    estimator (Equation 4.12 from [1]_) is then

    .. math::
        \widehat{ATT}_{dr,gmm}(g,t) = \frac{\mathbf{1}' \widehat{\Omega}_{g,t}^{-1}}
        {\mathbf{1}' \widehat{\Omega}_{g,t}^{-1} \mathbf{1}}
        \widehat{ATT}_{dr}(g,t).

    Parameters
    ----------
    data : DataFrame
        Panel data in long format with columns for outcome, time, unit id,
        treatment group, and partition.
    y_col : str
        Name of the outcome variable column.
    time_col : str
        Name of the time period column.
    id_col : str
        Name of the unit identifier column.
    group_col : str
        Name of the treatment group column (first period when treatment enabled).
        Use 0 or np.inf for never-treated units.
    partition_col : str
        Name of the partition/eligibility column (1 = eligible, 0 = ineligible).
    covariate_cols : list of str or None, default None
        Names of covariate columns in the data. If None, uses intercept only.
    control_group : {"nevertreated", "notyettreated"}, default "nevertreated"
        Which units to use as controls. With "notyettreated", multiple comparison
        groups may be available, triggering GMM aggregation.
    base_period : {"universal", "varying"}, default "universal"
        Base period selection. "universal" uses period g-1 as baseline for all
        comparisons; "varying" uses period t-1 for each t.
    est_method : {"dr", "reg", "ipw"}, default "dr"
        Estimation method for each 2-period comparison.
    boot : bool, default False
        Whether to use multiplier bootstrap for inference.
    biters : int, default 1000
        Number of bootstrap repetitions (only used if boot=True).
    cband : bool, default False
        Whether to compute uniform confidence bands (only used if boot=True).
    cluster : str or None, default None
        Name of the column containing cluster identifiers for clustered
        standard errors. If provided, the bootstrap resamples at the cluster
        level (only used if boot=True).
    alpha : float, default 0.05
        Significance level for confidence intervals.
    random_state : int, Generator, or None, default None
        Controls random number generation for bootstrap reproducibility.
    n_jobs : int, default=1
        Number of parallel jobs for group-time estimation. 1 = sequential
        (default), -1 = all cores, >1 = that many workers.

    Returns
    -------
    DDDMultiPeriodResult
        A NamedTuple containing:

        - att: Array of ATT(g,t) point estimates
        - se: Standard errors for each ATT(g,t)
        - uci, lci: Confidence interval bounds
        - groups: Treatment cohort for each estimate
        - times: Time period for each estimate
        - glist, tlist: Unique cohorts and periods
        - inf_func_mat: Influence function matrix (n x k); column ``j``
          belongs to estimate ``j`` of ``att``
        - n: Number of units
        - args: Estimation arguments
        - unit_groups: Treatment group of each unit, ordered by unit id

    Raises
    ------
    ValueError
        If no (g,t) cell produced an estimate.

    See Also
    --------
    ddd_panel : Two-period DDD estimator for panel data.

    Notes
    -----
    The influence functions are rescaled by :math:`n / n_{g,t}` where :math:`n_{g,t}`
    is the number of units in each (g,t) cell, following the approach in [1]_.

    The standard errors are computed from the influence function matrix as

    .. math::
        \widehat{V} = \frac{1}{n} \widehat{\Psi}' \widehat{\Psi}, \quad
        \widehat{se}_{g,t} = \sqrt{\widehat{V}_{g,t,g,t} / n}

    where :math:`\widehat{\Psi}` is the :math:`n \times k` matrix of influence
    functions. For cells with GMM aggregation, the standard error formula from
    Equation 4.12 is used instead.

    References
    ----------
    .. [1] Ortiz-Villavicencio, M., & Sant'Anna, P. H. C. (2025).
       *Better Understanding Triple Differences Estimators.*
       arXiv preprint arXiv:2505.09942. https://arxiv.org/abs/2505.09942
    """
    data = to_polars(data)

    tlist = np.sort(data[time_col].unique().to_numpy())
    glist_raw = data[group_col].unique().to_numpy()
    # Treated cohorts only: 0 and non-finite values encode never-treated units.
    glist = np.sort([g for g in glist_raw if g > 0 and np.isfinite(g)])

    n_units = data[id_col].n_unique()
    n_periods = len(tlist)
    n_cohorts = len(glist)

    # With a varying base period the first period has no predecessor, so it
    # cannot be a target period.
    tfac = 0 if base_period == "universal" else 1
    tlist_length = n_periods - tfac
    n_cells = n_cohorts * tlist_length

    # Upper bound on the number of estimates; trimmed to the realised count below.
    inf_func_mat = np.zeros((n_units, n_cells))
    se_array = np.full(n_cells, np.nan)

    unique_ids = np.sort(data[id_col].unique().to_numpy())
    id_to_idx = {uid: idx for idx, uid in enumerate(unique_ids)}

    args_list = [
        (
            data,
            g,
            tlist[t_idx + tfac],
            t_idx,
            tlist,
            base_period,
            control_group,
            y_col,
            time_col,
            id_col,
            group_col,
            partition_col,
            covariate_cols,
            est_method,
            n_units,
        )
        for g in glist
        for t_idx in range(tlist_length)
    ]

    cell_results = parallel_map(_process_gt_cell, args_list, n_jobs=n_jobs)

    attgt_list = []
    for result in cell_results:
        if result is None:
            continue
        att_entry, inf_data, se_val = result
        if att_entry is None:
            continue

        # Index by the estimate's position in ``attgt_list`` rather than by the
        # raw cell counter: skipped cells would otherwise shift every later
        # influence-function column and GMM standard error away from its
        # point estimate once the matrix is trimmed.
        col = len(attgt_list)
        attgt_list.append(att_entry)

        if inf_data is not None:
            inf_func_scaled, cell_id_list = inf_data
            _update_inf_func_matrix(inf_func_mat, inf_func_scaled, cell_id_list, id_to_idx, col)
        if se_val is not None:
            se_array[col] = se_val

    if len(attgt_list) == 0:
        raise ValueError("No valid (g,t) cells found.")

    att_array = np.array([r.att for r in attgt_list])
    groups_array = np.array([r.group for r in attgt_list])
    times_array = np.array([r.time for r in attgt_list])
    inf_func_trimmed = inf_func_mat[:, : len(attgt_list)]

    # One row per unit (its earliest period), ordered by id to match ``id_to_idx``.
    unit_info = data.sort([id_col, time_col]).group_by(id_col, maintain_order=True).first().sort(id_col)
    cluster_vals = None
    if cluster is not None:
        cluster_vals = unit_info[cluster].to_numpy()
    unit_groups = unit_info[group_col].to_numpy()

    if boot:
        boot_result = mboot_ddd(
            inf_func=inf_func_trimmed,
            biters=biters,
            alpha=alpha,
            cluster=cluster_vals,
            random_state=random_state,
        )
        se_computed = boot_result.se.copy()
        cv = boot_result.crit_val if cband and np.isfinite(boot_result.crit_val) else stats.norm.ppf(1 - alpha / 2)
    else:
        V = inf_func_trimmed.T @ inf_func_trimmed / n_units
        se_computed = np.sqrt(np.diag(V) / n_units)
        cv = stats.norm.ppf(1 - alpha / 2)

    # Cells aggregated by GMM carry their own analytical standard error,
    # which takes precedence over the influence-function/bootstrap one.
    gmm_se = se_array[: len(se_computed)]
    valid_se_mask = ~np.isnan(gmm_se)
    se_computed[valid_se_mask] = gmm_se[valid_se_mask]
    # Numerically-zero standard errors (e.g. the universal reference period)
    # are reported as missing.
    se_computed[se_computed <= np.sqrt(np.finfo(float).eps) * 10] = np.nan

    uci = att_array + cv * se_computed
    lci = att_array - cv * se_computed

    args = {
        "panel": True,
        "yname": y_col,
        "pname": partition_col,
        "control_group": control_group,
        "base_period": base_period,
        "est_method": est_method,
        "boot": boot,
        "biters": biters if boot else None,
        "cband": cband if boot else None,
        "cluster": cluster,
        "alpha": alpha,
    }

    return DDDMultiPeriodResult(
        att=att_array,
        se=se_computed,
        uci=uci,
        lci=lci,
        groups=groups_array,
        times=times_array,
        glist=glist,
        tlist=tlist,
        inf_func_mat=inf_func_trimmed,
        n=n_units,
        args=args,
        unit_groups=unit_groups,
    )
def _process_gt_cell(
    data,
    g,
    t,
    t_idx,
    tlist,
    base_period,
    control_group,
    y_col,
    time_col,
    id_col,
    group_col,
    partition_col,
    covariate_cols,
    est_method,
    n_units,
):
    """Estimate the ATT for one (g,t) cell.

    Picks the base period, gathers the cell's treated and control rows, and
    dispatches to the single-control or multi-control (GMM) routine depending
    on how many comparison cohorts are available.

    Returns
    -------
    tuple or None
        ``(ATTgtResult, (inf_func_scaled, cell_id_list) or None, se or None)``.
        ``None`` is returned when the cell is skipped entirely: no
        pre-treatment period, no data or controls, or estimation failed.
    """
    base = _get_base_period(g, t_idx, tlist, base_period)
    if base is None:
        warnings.warn(f"No pre-treatment periods for group {g}. Skipping.", UserWarning)
        return None

    is_post = int(g <= t)
    if is_post:
        # Post-treatment cells always compare against the last period before g.
        earlier = tlist[tlist < g]
        if len(earlier) == 0:
            return None
        base = earlier[-1]

    if base_period == "universal" and base == t:
        # The reference period itself: zero by construction, no influence function.
        return (ATTgtResult(att=0.0, group=int(g), time=int(t), post=0), None, None)

    cell_data, controls = _get_cell_data(data, g, t, base, control_group, time_col, group_col)
    if cell_data is None or len(controls) == 0:
        return None

    n_cell = cell_data[id_col].n_unique()
    # Positional arguments shared by both estimation routines.
    common_args = (
        y_col,
        time_col,
        id_col,
        group_col,
        partition_col,
        g,
        t,
        base,
        covariate_cols,
        est_method,
        n_units,
        n_cell,
    )

    if len(controls) == 1:
        att, inf_scaled, ids = _process_single_control(cell_data, *common_args)
        se = None
    else:
        att, inf_scaled, ids, se = _process_multiple_controls(cell_data, controls, *common_args)

    if att is None:
        return None

    return (
        ATTgtResult(att=att, group=int(g), time=int(t), post=is_post),
        (inf_scaled, ids),
        se,
    )
def _get_base_period(g, t_idx, tlist, base_period):
"""Get the base (pre-treatment) period for comparison."""
if base_period == "universal":
pre_periods = tlist[tlist < g]
if len(pre_periods) == 0:
return None
return pre_periods[-1]
return tlist[t_idx]
def _get_cell_data(data, g, t, pret, control_group, time_col, group_col):
    """Select the rows for one (g,t) cell and list its comparison cohorts.

    Keeps rows observed at ``t`` or ``pret`` that belong either to cohort ``g``
    or to an admissible control cohort. With ``control_group ==
    "nevertreated"`` controls are cohorts coded 0 or non-finite; otherwise
    cohorts first treated after ``max(t, pret)`` are admitted as well.

    Returns
    -------
    tuple
        ``(cell_data, available_controls)``, or ``(None, [])`` when no rows match.
    """
    cohort = pl.col(group_col)
    never_treated = (cohort == 0) | (~cohort.is_finite())

    if control_group == "nevertreated":
        is_control = never_treated
    else:
        latest = max(t, pret)
        is_control = (never_treated | (cohort > latest)) & (cohort != g)

    in_cell = (cohort == g) | is_control
    in_window = pl.col(time_col).is_in([t, pret])

    cell_data = data.filter(in_cell & in_window)
    if len(cell_data) == 0:
        return None, []

    comparison_rows = cell_data.filter(~cohort.is_in([g]))
    available_controls = [c for c in comparison_rows[group_col].unique().to_list() if c != g]
    return cell_data, available_controls
def _update_inf_func_matrix(inf_func_mat, inf_func_scaled, cell_id_list, id_to_idx, counter):
"""Update influence function matrix with scaled values for a cell."""
for i, uid in enumerate(cell_id_list):
if uid in id_to_idx and i < len(inf_func_scaled):
inf_func_mat[id_to_idx[uid], counter] = inf_func_scaled[i]
def _process_single_control(
    cell_data,
    y_col,
    time_col,
    id_col,
    group_col,
    partition_col,
    g,
    t,
    pret,
    covariate_cols,
    est_method,
    n_units,
    n_cell,
):
    """Estimate a (g,t) cell that has exactly one comparison cohort.

    Returns
    -------
    tuple
        ``(att, inf_func_scaled, ids)`` where the influence function is
        rescaled by ``n_units / len(ids)``, or ``(None, None, None)`` when the
        underlying two-period estimate is unavailable. ``n_cell`` is accepted
        but not used here.
    """
    att, inf_func, ids = _compute_single_ddd(
        cell_data, y_col, time_col, id_col, group_col, partition_col, g, t, pret, covariate_cols, est_method
    )
    if att is None:
        return None, None, None

    # Rescale from the cell's sample to the full sample.
    rescale = n_units / len(ids)
    return att, rescale * inf_func, ids
def _process_multiple_controls(
    cell_data,
    available_controls,
    y_col,
    time_col,
    id_col,
    group_col,
    partition_col,
    g,
    t,
    pret,
    covariate_cols,
    est_method,
    n_units,
    n_cell,
):
    """Estimate a (g,t) cell with several comparison cohorts and combine them by GMM.

    A separate two-period DDD is computed against each control cohort. The
    per-cohort influence functions are placed on a common unit index (zeros
    for units not used by that comparison) and passed to ``_gmm_aggregate``.

    Returns
    -------
    tuple
        ``(att_gmm, inf_func_scaled, cell_id_list, se_gmm)``, or four ``None``
        values when no comparison produced an estimate. ``n_cell`` is accepted
        but not used here.
    """
    estimates = []
    for ctrl in available_controls:
        pair_data = cell_data.filter((pl.col(group_col) == g) | (pl.col(group_col) == ctrl))
        att, inf_func, ids = _compute_single_ddd(
            pair_data, y_col, time_col, id_col, group_col, partition_col, g, t, pret, covariate_cols, est_method
        )
        if att is not None:
            estimates.append((att, inf_func, ids))

    if not estimates:
        return None, None, None, None

    # Union of units that entered at least one comparison, in sorted order.
    pooled_ids = set()
    for _, _, ids in estimates:
        pooled_ids.update(ids)
    cell_id_list = np.sort(np.array(list(pooled_ids)))
    position_of = {uid: pos for pos, uid in enumerate(cell_id_list)}
    n_cell_actual = len(cell_id_list)

    columns = []
    for _, inf_func, ids in estimates:
        rescaled = (n_cell_actual / len(ids)) * inf_func
        padded = np.zeros(n_cell_actual)
        for uid, value in zip(ids, rescaled):
            if uid in position_of:
                padded[position_of[uid]] = value
        columns.append(padded)

    att_values = np.array([att for att, _, _ in estimates])
    att_gmm, if_gmm, se_gmm = _gmm_aggregate(att_values, np.column_stack(columns), n_units)

    # Rescale from the cell's sample to the full sample.
    return att_gmm, (n_units / n_cell_actual) * if_gmm, cell_id_list, se_gmm
def _compute_single_ddd(
    cell_data, y_col, time_col, id_col, group_col, partition_col, g, t, pret, covariate_cols, est_method
):
    """Run the two-period DDD for cohort ``g`` against the controls in ``cell_data``.

    Units are labelled 4 (treated, partition 1), 3 (treated, partition 0),
    2 (control, partition 1) or 1 (control, partition 0). Only units observed
    in both ``t`` and ``pret`` are used, and covariates are read from the
    period-``t`` rows.

    Returns
    -------
    tuple
        ``(att, att_inf_func, ids)`` with ``ids`` sorted to match the
        influence function, or ``(None, None, None)`` when there are no
        common units, no subgroup-4 units, or ``ddd_panel`` raises
        ``ValueError`` / ``LinAlgError``.
    """
    treated = (pl.col("treat") == 1).cast(pl.Int64)
    untreated = (pl.col("treat") == 0).cast(pl.Int64)
    eligible = (pl.col(partition_col) == 1).cast(pl.Int64)
    ineligible = (pl.col(partition_col) == 0).cast(pl.Int64)
    subgroup_expr = (
        4 * treated * eligible + 3 * treated * ineligible + 2 * untreated * eligible + 1 * untreated * ineligible
    ).alias("subgroup")

    labelled = cell_data.with_columns((pl.col(group_col) == g).cast(pl.Int64).alias("treat")).with_columns(
        subgroup_expr
    )

    post_data = labelled.filter(pl.col(time_col) == t).sort(id_col)
    pre_data = labelled.filter(pl.col(time_col) == pret).sort(id_col)

    # Balance the two periods on units present in both.
    shared_ids = set(post_data[id_col].to_list()) & set(pre_data[id_col].to_list())
    if not shared_ids:
        return None, None, None
    keep = list(shared_ids)
    post_data = post_data.filter(pl.col(id_col).is_in(keep)).sort(id_col)
    pre_data = pre_data.filter(pl.col(id_col).is_in(keep)).sort(id_col)

    ids = post_data[id_col].to_numpy()
    y1 = post_data[y_col].to_numpy()
    y0 = pre_data[y_col].to_numpy()
    subgroup = post_data["subgroup"].to_numpy()

    # Nothing to estimate without treated, eligible units.
    if not np.any(subgroup == 4):
        return None, None, None

    design = [np.ones((len(y1), 1))]
    if covariate_cols is not None:
        design.append(post_data.select(covariate_cols).to_numpy())
    X = design[0] if len(design) == 1 else np.hstack(design)

    try:
        result = ddd_panel(y1=y1, y0=y0, subgroup=subgroup, covariates=X, est_method=est_method, influence_func=True)
        return result.att, result.att_inf_func, ids
    except (ValueError, np.linalg.LinAlgError):
        return None, None, None
def _gmm_aggregate(att_vals, inf_mat, n_total):
"""Compute GMM-weighted aggregate of ATT estimates across control groups."""
omega = np.cov(inf_mat, rowvar=False)
if omega.ndim == 0:
omega = np.array([[omega]])
try:
inv_omega = np.linalg.inv(omega)
except np.linalg.LinAlgError:
inv_omega = np.linalg.pinv(omega)
ones = np.ones(len(att_vals))
w = inv_omega @ ones / (ones @ inv_omega @ ones)
att_gmm = np.sum(w * att_vals)
if_gmm = inf_mat @ w
se_gmm = np.sqrt(1 / (n_total * np.sum(inv_omega)))
return att_gmm, if_gmm, se_gmm