"""Dynamic covariate balancing treatment effect estimation for panel data with time-varying treatments."""
from __future__ import annotations
import warnings
import numpy as np
import polars as pl
from moderndid.core.dataframe import to_polars
from moderndid.core.parallel import parallel_map
from moderndid.core.preprocess import DynBalancingConfig, PreprocessDataBuilder
from .container import DynBalancingHetResult, DynBalancingHistoryResult, DynBalancingResult
from .estimation.inference import compute_quantiles, compute_variance, compute_variance_clustered
from .estimation.weights_dcb import compute_dcb_estimator
from .estimation.weights_ipw import compute_ipw_estimator
def dyn_balancing(
    data,
    yname: str,
    tname: str,
    idname: str,
    treatment_name: str,
    ds1: list[int],
    ds2: list[int],
    xformla: str | None = None,
    fixed_effects: list[str] | None = None,
    pooled: bool = False,
    clustervars: list[str] | None = None,
    balancing: str = "dcb",
    method: str = "lasso_plain",
    alp: float = 0.05,
    final_period: int | None = None,
    initial_period: int | None = None,
    adaptive_balancing: bool = True,
    debias: bool = False,
    continuous_treatment: bool = False,
    lb: float = 0.0005,
    ub: float = 2.0,
    regularization: bool = True,
    fast_adaptive: bool = False,
    grid_length: int = 1000,
    n_beta_nonsparse: float = 1e-4,
    ratio_coefficients: float = 1 / 3,
    nfolds: int = 10,
    lags: int | None = None,
    robust_quantile: bool = True,
    demeaned_fe: bool = False,
    histories_length: list[int] | None = None,
    final_periods: list[int] | None = None,
    impulse_response: bool = False,
    n_jobs: int = 1,
) -> DynBalancingResult | DynBalancingHistoryResult | DynBalancingHetResult:
    r"""Estimate treatment effects under dynamic treatment regimes.

    Implements the dynamic covariate balancing (DCB) estimator of [1]_ for
    comparing potential outcomes under two treatment
    histories :math:`d_{1:T}` and :math:`d'_{1:T}`. The average treatment
    effect is defined as

    .. math::

        \text{ATE}(d_{1:T}, d'_{1:T}) = \mu_T(d_{1:T}) - \mu_T(d'_{1:T}),

    where :math:`\mu_T(d_{1:T}) = \mathbb{E}[Y_T(d_{1:T})]` is the
    potential outcome under treatment history :math:`d_{1:T}`.

    Identification relies on a sequential conditional independence assumption
    and overlap. For each period :math:`t`, the DCB estimator solves a
    quadratic program to find balancing weights :math:`\hat{\gamma}_t` that
    satisfy dynamic covariate balance constraints while minimising the
    :math:`\ell_2` norm. The potential outcome is then estimated as a
    bias-corrected weighted average of outcomes in the final period. IPW,
    AIPW, and IPW-MSM alternatives are also available as benchmarks.

    Parameters
    ----------
    data : DataFrame
        Panel data in long format. Accepts any object implementing the Arrow
        PyCapsule Interface (``__arrow_c_stream__``), including polars, pandas,
        pyarrow Table, and cudf DataFrames.
    yname : str
        The name of the outcome variable.
    tname : str
        The name of the column containing the time periods.
    idname : str
        The individual (cross-sectional unit) id name.
    treatment_name : str
        The name of the binary treatment column.
    ds1 : list[int]
        Target treatment history for the first potential outcome.
        Length must equal the number of time periods.
    ds2 : list[int]
        Target treatment history for the second potential outcome.
        Must have the same length as ``ds1``.
    xformla : str or None, default=None
        A formula for the covariates to include in the model. It should be of
        the form ``"~ X1 + X2"``.
    fixed_effects : list[str] or None, default=None
        Column names to include as fixed-effect dummies.
    pooled : bool, default=False
        If True, pool observations across periods for coefficient estimation.
    clustervars : list[str] or None, default=None
        Column names on which to cluster standard errors.
    balancing : {'dcb', 'aipw', 'ipw', 'ipw_msm'}, default='dcb'
        Weighting strategy. ``'dcb'`` uses dynamic covariate balancing,
        ``'ipw'`` uses inverse probability weighting, ``'aipw'`` uses
        augmented IPW, and ``'ipw_msm'`` uses stabilised marginal structural
        model weights.
    method : {'lasso_plain', 'lasso_subsample'}, default='lasso_plain'
        LASSO estimation strategy for the coefficient stage.
    alp : float, default=0.05
        Significance level for confidence intervals.
    final_period : int or None, default=None
        Last time period to include. Defaults to the maximum in the data.
    initial_period : int or None, default=None
        First time period to include. Defaults to the minimum in the data.
    adaptive_balancing : bool, default=True
        If True, use tighter balance constraints on covariates with large
        estimated coefficients.
    debias : bool, default=False
        If True, apply bootstrap debiasing with 20 replicates.
    continuous_treatment : bool, default=False
        If True, treat the treatment variable as continuous.
    lb : float, default=0.0005
        Lower bound for tuning constant grid search.
    ub : float, default=2.0
        Upper bound for tuning constant grid search.
    regularization : bool, default=True
        If True use cross-validated LASSO, otherwise ridge.
    fast_adaptive : bool, default=False
        If True, use flat grid search instead of three-segment nested search.
    grid_length : int, default=1000
        Number of grid points for tuning constant search.
    n_beta_nonsparse : float, default=1e-4
        Threshold below which a rescaled coefficient is treated as zero.
    ratio_coefficients : float, default=1/3
        Fraction of largest coefficients to prioritise when sparsity is low.
    nfolds : int, default=10
        Cross-validation folds for LASSO.
    lags : int or None, default=None
        Treatment lags for the coefficient stage.
    robust_quantile : bool, default=True
        If True, use chi-squared critical values for inference.
    demeaned_fe : bool, default=False
        If True, demean fixed effects before estimation.
    histories_length : list[int] or None, default=None
        If provided, estimate ATEs for varying treatment history lengths.
        Each entry ``k`` must satisfy ``1 <= k <= len(ds1)``. For each ``k``,
        the last ``k`` elements of ``ds1`` and ``ds2`` are used. Returns a
        :class:`DynBalancingHistoryResult`. Mutually exclusive with
        ``final_periods``.
    final_periods : list[int] or None, default=None
        If provided, estimate ATEs at each specified final period. Returns a
        :class:`DynBalancingHetResult`. Mutually exclusive with
        ``histories_length``.
    impulse_response : bool, default=False
        If True (requires ``histories_length``), estimate impulse responses
        instead of cumulative effects. For each history length ``k``, the
        treatment sequences are set to ``ds1 = [1, 0, ..., 0]`` and
        ``ds2 = [0, 0, ..., 0]`` (both length ``k``), measuring the effect
        of a one-period treatment shock at varying horizons.
    n_jobs : int, default=1
        Number of parallel workers for ``histories_length`` and
        ``final_periods`` modes. 1 = sequential, -1 = all cores,
        >1 = that many threads.

    Returns
    -------
    DynBalancingResult or DynBalancingHistoryResult or DynBalancingHetResult
        When neither ``histories_length`` nor ``final_periods`` is set,
        returns a single :class:`DynBalancingResult`. Otherwise returns the
        corresponding multi-result container.

        - **att**: The ATE point estimate (:math:`\mu_1 - \mu_2`)
        - **var_att**: Variance of the ATE
        - **mu1**: Potential outcome estimate under ``ds1``
        - **mu2**: Potential outcome estimate under ``ds2``
        - **var_mu1**: Variance of ``mu1``
        - **var_mu2**: Variance of ``mu2``
        - **robust_quantile**: Chi-squared critical value for inference
        - **gaussian_quantile**: Gaussian critical value for inference
        - **gammas**: Balancing weights per treatment history
        - **coefficients**: LASSO coefficients per treatment history
        - **imbalances**: Covariate imbalance measures
        - **estimation_params**: Metadata (observation count, variable names, etc.)

    Raises
    ------
    ValueError
        If ``histories_length`` and ``final_periods`` are both given, if
        ``impulse_response`` is set without ``histories_length``, if ``ds1``
        or ``ds2`` is empty or their lengths differ, if ``balancing`` or
        ``method`` is not one of the supported values, if ``alp`` is not
        strictly between 0 and 1, or if ``lb > ub``.
    NotImplementedError
        If ``continuous_treatment`` is True.

    Warns
    -----
    UserWarning
        If ``alp > 0.1``, or if a one-element treatment history is supplied
        while neither ``histories_length`` nor ``final_periods`` is set.

    References
    ----------
    .. [1] Viviano, D. and Bradic, J. (2026). "Dynamic covariate balancing:
           estimating treatment effects over time with potential local projections."
           *Biometrika*, asag016. https://doi.org/10.1093/biomet/asag016
    .. [2] Acemoglu, D., Naidu, S., Restrepo, P., and Robinson, J.A. (2019).
           "Democracy does cause growth." *Journal of Political Economy*, 127(1),
           47-100. https://doi.org/10.1086/700936

    Examples
    --------
    The dataset below contains 141 countries observed across six five-year
    periods (1989--2010) from the democracy and growth study of
    Acemoglu et al. (2019) [2]_. The treatment ``D`` is a binary democracy
    indicator that can switch on and off across periods, and the outcome
    ``Y`` is log GDP per capita. This is the same application used in [1]_.

    We estimate the effect of being democratic for two consecutive periods
    compared to not being democratic, controlling for five country-level
    covariates and region fixed effects. The treatment histories
    ``ds1=[1, 1]`` and ``ds2=[0, 0]`` specify the two sequences to compare,
    read left to right from the earliest to the most recent period:

    .. code-block:: python

        from moderndid import load_acemoglu, dyn_balancing

        df = load_acemoglu()
        result = dyn_balancing(
            data=df,
            yname="Y",
            tname="Time",
            idname="Unit",
            treatment_name="D",
            ds1=[1, 1],
            ds2=[0, 0],
            xformla="~ V1 + V2 + V3 + V4 + V5",
            fixed_effects=["region"],
        )
        print(result)

    .. code-block:: text

        ==============================================================================
        Dynamic Covariate Balancing Estimation
        ==============================================================================
        DCB estimation for the ATE:
        ┌────────┬────────────┬──────────┬────────────────────────┐
        │    ATE │ Std. Error │ Pr(>|t|) │  [95% Conf. Interval]  │
        ├────────┼────────────┼──────────┼────────────────────────┤
        │ 0.3011 │     0.2032 │   0.1383 │  [ -0.0971,  0.6993]   │
        └────────┴────────────┴──────────┴────────────────────────┘
        ------------------------------------------------------------------------------
        Signif. codes: '*' confidence interval does not cover 0
        ------------------------------------------------------------------------------
        Potential Outcomes
        ------------------------------------------------------------------------------
        mu(ds1): 8.0044 (0.1397)
        mu(ds2): 7.7033 (0.1476)
        ------------------------------------------------------------------------------
        Data Info
        ------------------------------------------------------------------------------
        Treatment history ds1: [1, 1]
        Treatment history ds2: [0, 0]
        Outcome variable: Y
        Units: 137
        Observations: 274
        ------------------------------------------------------------------------------
        Estimation Details
        ------------------------------------------------------------------------------
        Balancing: DCB
        Coefficient estimation: lasso_plain
        ------------------------------------------------------------------------------
        Inference
        ------------------------------------------------------------------------------
        Significance level: 0.05
        Analytical standard errors
        Robust (chi-squared) critical values
        ==============================================================================
        Viviano and Bradic (2026)
    """
    # ------------------------------------------------------------------
    # Argument validation (runs before any data is touched).
    # ------------------------------------------------------------------
    if histories_length is not None and final_periods is not None:
        raise ValueError("histories_length and final_periods are mutually exclusive.")
    if impulse_response and histories_length is None:
        raise ValueError("impulse_response=True requires histories_length.")
    # Emptiness is tested via len() rather than truthiness so that array-like
    # histories (e.g. NumPy arrays) reach the intended checks instead of
    # failing with "truth value of an array is ambiguous".
    if _is_empty_history(ds1):
        raise ValueError("ds1 must be a non-empty list of treatment values.")
    if _is_empty_history(ds2):
        raise ValueError("ds2 must be a non-empty list of treatment values.")
    if len(ds1) != len(ds2):
        raise ValueError(f"ds1 and ds2 must have the same length, got {len(ds1)} and {len(ds2)}.")
    if balancing not in ("dcb", "aipw", "ipw", "ipw_msm"):
        raise ValueError(f"balancing must be one of 'dcb', 'aipw', 'ipw', 'ipw_msm', got {balancing!r}.")
    if method not in ("lasso_plain", "lasso_subsample"):
        raise ValueError(f"method must be one of 'lasso_plain', 'lasso_subsample', got {method!r}.")
    if not 0 < alp < 1:
        raise ValueError(f"alp must be between 0 and 1 (exclusive), got {alp}.")
    if lb > ub:
        raise ValueError(f"lb ({lb}) must be less than or equal to ub ({ub}).")
    if continuous_treatment:
        raise NotImplementedError(
            "Continuous treatment estimation is not yet implemented. "
            "The reference R package (DynBalancing) also lacks this feature."
        )

    if alp > 0.1:
        warnings.warn("Significance level larger than 0.1 selected.", stacklevel=2)
    if histories_length is None and final_periods is None and len(ds1) == 1:
        warnings.warn("ds1 contains one element. No dynamics will be considered.", stacklevel=2)

    # Pooled estimation defaults to clustering on the unit identifier.
    if pooled and clustervars is None:
        clustervars = [idname]

    # Options forwarded unchanged to every downstream entry point (the two
    # multi-result dispatchers and the preprocessing config). Keeping them in
    # one place prevents the three call sites from drifting apart.
    common_opts = {
        "yname": yname,
        "tname": tname,
        "idname": idname,
        "treatment_name": treatment_name,
        "xformla": xformla,
        "fixed_effects": fixed_effects,
        "pooled": pooled,
        "clustervars": clustervars,
        "balancing": balancing,
        "method": method,
        "alp": alp,
        "initial_period": initial_period,
        "adaptive_balancing": adaptive_balancing,
        "debias": debias,
        "continuous_treatment": continuous_treatment,
        "lb": lb,
        "ub": ub,
        "regularization": regularization,
        "fast_adaptive": fast_adaptive,
        "grid_length": grid_length,
        "n_beta_nonsparse": n_beta_nonsparse,
        "ratio_coefficients": ratio_coefficients,
        "nfolds": nfolds,
        "lags": lags,
        "robust_quantile": robust_quantile,
        "demeaned_fe": demeaned_fe,
    }

    # ------------------------------------------------------------------
    # Multi-result modes: delegate and return their container directly.
    # ------------------------------------------------------------------
    if histories_length is not None:
        return _run_history(
            data=data,
            ds1=ds1,
            ds2=ds2,
            histories_length=histories_length,
            final_period=final_period,
            impulse_response=impulse_response,
            n_jobs=n_jobs,
            **common_opts,
        )
    if final_periods is not None:
        # final_period is deliberately not forwarded: the dispatcher sets it
        # itself, once per requested period.
        return _run_het(
            data=data,
            ds1=ds1,
            ds2=ds2,
            final_periods=final_periods,
            n_jobs=n_jobs,
            **common_opts,
        )

    # ------------------------------------------------------------------
    # Single-result mode: preprocess, estimate both histories, combine.
    # ------------------------------------------------------------------
    df = to_polars(data)
    config = DynBalancingConfig(
        ds1=list(ds1),
        ds2=list(ds2),
        final_period=final_period,
        **common_opts,
    )
    dp = PreprocessDataBuilder().with_data(df).with_config(config).validate().transform().build()

    n_periods = config.n_periods
    outcome = dp.outcome_vector
    treatment_matrix = dp.treatment_matrix
    cluster = dp.cluster
    dim_fe = dp.dim_fe
    covariates_t = _reindex_covariates(dp.covariate_dict, config.time_periods)
    ds1_arr = np.array(ds1, dtype=float)
    ds2_arr = np.array(ds2, dtype=float)

    # Positional inputs shared by both histories and both estimator families.
    estimator_inputs = (n_periods, outcome, treatment_matrix, covariates_t)

    if balancing == "dcb":
        dcb_opts = {
            "method": method,
            "adaptive_balancing": adaptive_balancing,
            "debias": debias,
            "regularization": regularization,
            "nfolds": nfolds,
            "lb": lb,
            "ub": ub,
            "grid_length": grid_length,
            "n_beta_nonsparse": n_beta_nonsparse,
            "ratio_coefficients": ratio_coefficients,
            "lags": lags,
            "dim_fe": dim_fe,
            "fast_adaptive": fast_adaptive,
        }
        res1 = compute_dcb_estimator(*estimator_inputs, ds1_arr, **dcb_opts)
        res2 = compute_dcb_estimator(*estimator_inputs, ds2_arr, **dcb_opts)

        mu1 = res1.mu_hat
        mu2 = res2.mu_hat
        if debias:
            # Rebind rather than use `-=`: an augmented assignment would
            # modify the estimator result in place if mu_hat is an ndarray.
            mu1 = mu1 - res1.bias
            mu2 = mu2 - res2.bias

        if cluster is not None:
            var1 = compute_variance_clustered(res1.gammas, res1.predictions, res1.not_nas, outcome, cluster)
            var2 = compute_variance_clustered(res2.gammas, res2.predictions, res2.not_nas, outcome, cluster)
        else:
            var1 = compute_variance(res1.gammas, res1.predictions, res1.not_nas, outcome)
            var2 = compute_variance(res2.gammas, res2.predictions, res2.not_nas, outcome)

        gammas_out = {"ds1": res1.gammas, "ds2": res2.gammas}
        coefficients_out = {"ds1": res1.coef_t, "ds2": res2.coef_t}
        imbalances_out: dict = {}
    else:
        # The validated `balancing` value ('aipw', 'ipw' or 'ipw_msm') is
        # passed straight through as the IPW estimator's method name.
        ipw_opts = {
            "method": balancing,
            "regularization": regularization,
            "lags": lags,
            "dim_fe": dim_fe,
        }
        res1_ipw = compute_ipw_estimator(*estimator_inputs, ds1_arr, **ipw_opts)
        res2_ipw = compute_ipw_estimator(*estimator_inputs, ds2_arr, **ipw_opts)

        mu1 = res1_ipw.mu_hat
        mu2 = res2_ipw.mu_hat
        var1 = res1_ipw.variance
        var2 = res2_ipw.variance
        # The IPW branch reports no weights, coefficients or imbalances.
        gammas_out = {}
        coefficients_out = {}
        imbalances_out = {}

    ate = mu1 - mu2
    # The ATE variance is taken as the sum of the two potential-outcome variances.
    var_ate = var1 + var2
    quantiles = compute_quantiles(alp, n_periods, robust_quantile)

    estimation_params = {
        "yname": yname,
        "tname": tname,
        "idname": idname,
        "treatment_name": treatment_name,
        "balancing": balancing,
        "method": method,
        "n_units": config.n_units,
        "n_obs": len(dp.panel),
        "n_periods": n_periods,
        "ds1": list(ds1),
        "ds2": list(ds2),
        "alp": alp,
        "adaptive_balancing": adaptive_balancing,
        "debias": debias,
        "clustervars": clustervars,
    }

    return DynBalancingResult(
        att=ate,
        var_att=var_ate,
        mu1=mu1,
        mu2=mu2,
        var_mu1=var1,
        var_mu2=var2,
        robust_quantile=quantiles.robust_quantile_ate,
        gaussian_quantile=quantiles.gaussian_quantile_ate,
        gammas=gammas_out,
        coefficients=coefficients_out,
        imbalances=imbalances_out,
        estimation_params=estimation_params,
    )


def _is_empty_history(history) -> bool:
    """Return True when a treatment history holds no elements.

    Sized inputs (lists, tuples, NumPy arrays) are judged by ``len``. Inputs
    without a length, such as ``None``, fall back to plain truthiness.
    """
    try:
        return len(history) == 0
    except TypeError:
        return not history
def _run_history(
*, ds1, ds2, histories_length, impulse_response=False, n_jobs=1, **kwargs
) -> DynBalancingHistoryResult:
"""Dispatch for histories_length mode."""
if not histories_length:
raise ValueError("histories_length must be a non-empty list.")
t_all = len(ds1)
for h in histories_length:
if h < 1 or h > t_all:
raise ValueError(f"All entries in histories_length must be between 1 and {t_all} (len(ds1)), got {h}.")
sorted_lengths = sorted(histories_length)
if impulse_response:
args_list = [([1] + [0] * (h - 1), [0] * h, kwargs) for h in sorted_lengths]
else:
args_list = [(ds1[-h:], ds2[-h:], kwargs) for h in sorted_lengths]
results = parallel_map(_call_dyn_balancing, args_list, n_jobs=n_jobs)
summary = pl.DataFrame(
{
"period_length": sorted_lengths,
"att": [r.att for r in results],
"var_att": [r.var_att for r in results],
"mu1": [r.mu1 for r in results],
"var_mu1": [r.var_mu1 for r in results],
"mu2": [r.mu2 for r in results],
"var_mu2": [r.var_mu2 for r in results],
"robust_quantile": [r.robust_quantile for r in results],
"gaussian_quantile": [r.gaussian_quantile for r in results],
}
)
return DynBalancingHistoryResult(summary=summary, results=results)
def _run_het(*, ds1, ds2, final_periods, n_jobs=1, **kwargs) -> DynBalancingHetResult:
"""Dispatch for final_periods mode."""
if not final_periods:
raise ValueError("final_periods must be a non-empty list.")
sorted_periods = sorted(final_periods)
args_list = [(ds1, ds2, {**kwargs, "final_period": p}) for p in sorted_periods]
results = parallel_map(_call_dyn_balancing, args_list, n_jobs=n_jobs)
summary = pl.DataFrame(
{
"final_period": sorted_periods,
"att": [r.att for r in results],
"var_att": [r.var_att for r in results],
"mu1": [r.mu1 for r in results],
"var_mu1": [r.var_mu1 for r in results],
"mu2": [r.mu2 for r in results],
"var_mu2": [r.var_mu2 for r in results],
"robust_quantile": [r.robust_quantile for r in results],
"gaussian_quantile": [r.gaussian_quantile for r in results],
}
)
return DynBalancingHetResult(summary=summary, results=results)
def _call_dyn_balancing(dd1, dd2, kwargs):
    """Run ``dyn_balancing`` for one ``(ds1, ds2, kwargs)`` work item.

    Adapter used with ``parallel_map``: ``dd1`` and ``dd2`` become the
    ``ds1`` and ``ds2`` arguments, and ``kwargs`` is expanded into the
    remaining keyword arguments.
    """
    histories = {"ds1": dd1, "ds2": dd2}
    return dyn_balancing(**histories, **kwargs)
def _reindex_covariates(covariate_dict: dict[int, np.ndarray], time_periods: np.ndarray) -> dict[int, np.ndarray]:
"""Re-key covariate dict from actual period values to 0-based indices."""
sorted_periods = sorted(time_periods)
return {i: covariate_dict[p] for i, p in enumerate(sorted_periods) if p in covariate_dict}