"""Variance estimation."""
import numpy as np
import polars as pl
from scipy import stats
from .numba import compute_cluster_sums
def build_treatment_paths(df, horizon, config):
r"""Build hierarchical cohort identifiers based on treatment trajectories.
Constructs treatment path variables that track each group's treatment history
from period :math:`F_g` (first switch) through :math:`F_g - 1 + \ell`. Groups
are assigned to cohorts based on their baseline treatment :math:`D_{g,1}`,
switch timing :math:`F_g`, and subsequent treatment values. This enables
comparison of switchers only to non-switchers with the same baseline treatment,
which is required for the parallel trends assumption in [1]_.
Parameters
----------
df : pl.DataFrame
Data with F_g (first switch period) and d_sq (baseline treatment) columns.
horizon : int
Current horizon :math:`\ell` being computed.
config : DIDInterConfig
Configuration object with column names.
Returns
-------
pl.DataFrame
DataFrame with path_0, path_1, ..., path_h columns identifying treatment
trajectories, and validity flags for cohorts with sufficient observations.
References
----------
.. [1] de Chaisemartin, C., & D'Haultfoeuille, X. (2024). Difference-in-
Differences Estimators of Intertemporal Treatment Effects.
*Review of Economics and Statistics*, 106(6), 1723-1736.
"""
gname = config.gname
tname = config.tname
dname = config.dname
h = abs(horizon)
df = df.with_columns(
pl.when(pl.col(tname) == pl.col("F_g") + h - 1)
.then(pl.col(dname))
.otherwise(pl.lit(None))
.alias("_treat_at_horizon")
)
df = df.with_columns(pl.col("_treat_at_horizon").mean().over(gname).alias(f"treat_h{h}"))
if h == 1:
df = df.with_columns(pl.col("d_sq").alias("treat_h0"))
df = df.with_columns(pl.struct(["treat_h0", "F_g"]).hash(seed=42).alias("path_0"))
if h > 1 and f"treat_h{h - 1}" in df.columns:
df = df.with_columns(
pl.when(pl.col(f"treat_h{h}").is_null())
.then(pl.col(f"treat_h{h - 1}"))
.otherwise(pl.col(f"treat_h{h}"))
.alias(f"treat_h{h}")
)
prev_path = f"path_{h - 1}" if h > 1 else "path_0"
if prev_path in df.columns:
df = df.with_columns(pl.struct([prev_path, f"treat_h{h}"]).hash(seed=42).alias(f"path_{h}"))
if h == 1 and "path_0" in df.columns:
df = df.with_columns(pl.col(gname).n_unique().over("path_0").alias("n_groups_path_0"))
df = df.with_columns((pl.col("n_groups_path_0") > 1).cast(pl.Int64).alias("valid_cohort_0"))
if f"path_{h}" in df.columns:
df = df.with_columns(pl.col(gname).n_unique().over(f"path_{h}").alias(f"n_groups_path_{h}"))
df = df.with_columns((pl.col(f"n_groups_path_{h}") > 1).cast(pl.Int64).alias(f"valid_cohort_{h}"))
df = df.drop("_treat_at_horizon")
return df
def compute_cohort_dof(df, horizon, config, cluster_col=None):
"""Compute DOF-adjusted cohort means using hierarchical path groupings.
Parameters
----------
df : pl.DataFrame
Data with path columns and switcher flags.
horizon : int
Current horizon.
config : DIDInterConfig
Configuration object.
cluster_col : str or None
Column name for cluster-robust DOF counting.
Returns
-------
pl.DataFrame
DataFrame with dof_switcher_{h} and cohort_mean_{h} columns.
"""
h = abs(horizon)
trends = config.trends_nonparam or []
switcher_flag = f"is_switcher_{h}"
weighted_diff = f"weighted_diff_{h}"
dist_col = f"dist_to_switch_{h}"
is_switcher = pl.col(switcher_flag) == 1
base_group_vars = ["d_sq", "F_g", "d_fg", dist_col]
group_vars = base_group_vars + list(trends)
group_vars = [c for c in group_vars if c in df.columns]
weight_sum_col = f"weight_sum_{h}_switcher"
diff_sum_col = f"diff_sum_{h}_switcher"
val_weight = pl.when(is_switcher).then(pl.col("weight_gt")).otherwise(None)
val_diff = pl.when(is_switcher).then(pl.col(weighted_diff)).otherwise(None)
df = df.with_columns(
pl.when(is_switcher).then(val_weight.sum().over(group_vars)).otherwise(None).alias(weight_sum_col),
pl.when(is_switcher).then(val_diff.sum().over(group_vars)).otherwise(None).alias(diff_sum_col),
)
dof_col = f"dof_switcher_{h}"
if cluster_col is None:
val_dof = pl.when(is_switcher).then(pl.col(switcher_flag)).otherwise(None)
df = df.with_columns(pl.when(is_switcher).then(val_dof.sum().over(group_vars)).otherwise(None).alias(dof_col))
else:
cluster_flag = f"_cluster_flag_{h}"
df = df.with_columns(pl.when(is_switcher).then(pl.col(cluster_col)).otherwise(None).alias(cluster_flag))
df = df.with_columns(
pl.when(pl.col(cluster_flag).is_not_null())
.then(pl.col(cluster_flag).n_unique().over(group_vars))
.otherwise(None)
.alias(dof_col)
)
df = df.drop(cluster_flag)
ws = pl.col(weight_sum_col).fill_null(1.0)
ds = pl.col(diff_sum_col).fill_null(0.0)
df = df.with_columns((ds / ws).alias(f"cohort_mean_{h}"))
return df
def compute_control_dof(df, horizon, config, cluster_col=None):
"""Compute DOF for control units (non-switchers).
Parameters
----------
df : pl.DataFrame
Data with never_change column.
horizon : int
Current horizon.
config : DIDInterConfig
Configuration object.
cluster_col : str or None
Column name for cluster-robust DOF counting.
Returns
-------
pl.DataFrame
DataFrame with dof_control_{h} and control_mean_{h} columns.
"""
h = abs(horizon)
tname = config.tname
trends = config.trends_nonparam or []
never_col = f"never_change_{h}"
weighted_diff = f"weighted_diff_{h}"
if never_col not in df.columns:
return df
is_control = pl.col(never_col) == 1.0
group_vars = [tname, "d_sq", *list(trends)]
weight_sum_col = f"control_weight_sum_{h}"
diff_sum_col = f"control_diff_sum_{h}"
dof_col = f"dof_control_{h}"
mean_col = f"control_mean_{h}"
val_weight = pl.when(is_control).then(pl.col("weight_gt")).otherwise(None)
val_diff = pl.when(is_control).then(pl.col(weighted_diff)).otherwise(None)
df = df.with_columns(
pl.when(is_control).then(val_weight.sum().over(group_vars)).otherwise(None).alias(weight_sum_col),
pl.when(is_control).then(val_diff.sum().over(group_vars)).otherwise(None).alias(diff_sum_col),
)
if cluster_col is None:
val_dof = pl.when(is_control).then(pl.lit(1)).otherwise(None)
df = df.with_columns(pl.when(is_control).then(val_dof.sum().over(group_vars)).otherwise(None).alias(dof_col))
else:
cluster_flag = f"_control_cluster_{h}"
df = df.with_columns(pl.when(is_control).then(pl.col(cluster_col)).otherwise(None).alias(cluster_flag))
df = df.with_columns(
pl.when(pl.col(cluster_flag).is_not_null())
.then(pl.col(cluster_flag).n_unique().over(group_vars))
.otherwise(None)
.alias(dof_col)
)
df = df.drop(cluster_flag)
df = df.with_columns(
pl.when(pl.col(weight_sum_col) > 0)
.then(pl.col(diff_sum_col) / pl.col(weight_sum_col))
.otherwise(pl.lit(0.0))
.alias(mean_col)
)
return df
def compute_union_dof(df, horizon, config, cluster_col=None):
"""Compute DOF for the union of switchers and controls.
Parameters
----------
df : pl.DataFrame
Data with switcher and control flags.
horizon : int
Current horizon.
config : DIDInterConfig
Configuration object.
cluster_col : str or None
Column name for cluster-robust DOF counting.
Returns
-------
pl.DataFrame
DataFrame with dof_union_{h} and union_mean_{h} columns.
"""
h = abs(horizon)
tname = config.tname
trends = config.trends_nonparam or []
switcher_flag = f"is_switcher_{h}"
never_col = f"never_change_{h}"
weighted_diff = f"weighted_diff_{h}"
if switcher_flag not in df.columns or never_col not in df.columns:
return df
is_union = (pl.col(switcher_flag) == 1) | (pl.col(never_col) == 1.0)
group_vars = [tname, "d_sq", *list(trends)]
union_flag = f"is_union_{h}"
weight_sum_col = f"union_weight_sum_{h}"
diff_sum_col = f"union_diff_sum_{h}"
dof_col = f"dof_union_{h}"
mean_col = f"union_mean_{h}"
df = df.with_columns(is_union.cast(pl.Int64).alias(union_flag))
val_weight = pl.when(is_union).then(pl.col("weight_gt")).otherwise(None)
val_diff = pl.when(is_union).then(pl.col(weighted_diff)).otherwise(None)
df = df.with_columns(
pl.when(is_union).then(val_weight.sum().over(group_vars)).otherwise(None).alias(weight_sum_col),
pl.when(is_union).then(val_diff.sum().over(group_vars)).otherwise(None).alias(diff_sum_col),
)
if cluster_col is None:
val_dof = pl.when(is_union).then(pl.col(union_flag)).otherwise(None)
df = df.with_columns(pl.when(is_union).then(val_dof.sum().over(group_vars)).otherwise(None).alias(dof_col))
else:
cluster_flag = f"_union_cluster_{h}"
df = df.with_columns(pl.when(is_union).then(pl.col(cluster_col)).otherwise(None).alias(cluster_flag))
df = df.with_columns(
pl.when(pl.col(cluster_flag).is_not_null())
.then(pl.col(cluster_flag).n_unique().over(group_vars))
.otherwise(None)
.alias(dof_col)
)
df = df.drop(cluster_flag)
df = df.with_columns(
pl.when(pl.col(weight_sum_col) > 0)
.then(pl.col(diff_sum_col) / pl.col(weight_sum_col))
.otherwise(pl.lit(0.0))
.alias(mean_col)
)
return df
def compute_e_hat(df, horizon, config):
"""Compute cohort mean with 3-level DOF fallback.
Parameters
----------
df : pl.DataFrame
Data with DOF and mean columns.
horizon : int
Current horizon.
config : DIDInterConfig
Configuration object.
Returns
-------
pl.DataFrame
DataFrame with E_hat_{h} column.
"""
h = abs(horizon)
tname = config.tname
e_hat_col = f"E_hat_{h}"
dof_s_col = f"dof_switcher_{h}"
dof_ns_col = f"dof_control_{h}"
dof_union_col = f"dof_union_{h}"
mean_s_col = f"cohort_mean_{h}"
mean_ns_col = f"control_mean_{h}"
mean_union_col = f"union_mean_{h}"
time = pl.col(tname)
fg = pl.col("F_g")
at_target = (fg - 1 + h) == time
before_switch = time < fg
relevant = at_target | before_switch
dof_s = pl.col(dof_s_col) if dof_s_col in df.columns else pl.lit(None)
dof_ns = pl.col(dof_ns_col) if dof_ns_col in df.columns else pl.lit(None)
dof_union = pl.col(dof_union_col) if dof_union_col in df.columns else pl.lit(None)
mean_s = pl.col(mean_s_col) if mean_s_col in df.columns else pl.lit(0.0)
mean_ns = pl.col(mean_ns_col) if mean_ns_col in df.columns else pl.lit(0.0)
mean_union = pl.col(mean_union_col) if mean_union_col in df.columns else pl.lit(0.0)
s_safe = dof_s.fill_null(9999)
ns_safe = dof_ns.fill_null(9999)
union_safe = dof_union.fill_null(9999)
use_switcher_mean = at_target & (s_safe >= 2)
use_control_mean = before_switch & (ns_safe >= 2)
use_union_mean = (union_safe >= 2) & ((at_target & (s_safe == 1)) | (before_switch & (ns_safe == 1)))
df = df.with_columns(
pl.when(~relevant)
.then(pl.lit(None))
.when(use_switcher_mean)
.then(mean_s)
.when(use_control_mean)
.then(mean_ns)
.when(use_union_mean)
.then(mean_union)
.otherwise(pl.lit(0.0))
.alias(e_hat_col)
)
return df
def compute_dof_scaling(df, horizon, config):
"""Compute DOF scaling factor for variance correction with 3-level fallback.
Parameters
----------
df : pl.DataFrame
Data with dof_switcher, dof_control, and dof_union columns.
horizon : int
Current horizon.
config : DIDInterConfig
Configuration object. Uses ``less_conservative_se`` to control whether
DOF adjustments are applied.
Returns
-------
pl.DataFrame
DataFrame with dof_scale_{h} column.
"""
h = abs(horizon)
tname = config.tname
dof_col = f"dof_scale_{h}"
dof_s_col = f"dof_switcher_{h}"
dof_ns_col = f"dof_control_{h}"
dof_union_col = f"dof_union_{h}"
df = df.with_columns(pl.lit(1.0).alias(dof_col))
if getattr(config, "less_conservative_se", False):
return df
time = pl.col(tname)
fg = pl.col("F_g")
at_target_time = (fg - 1 + h) == time
before_switch = time < fg
dof_s = pl.col(dof_s_col) if dof_s_col in df.columns else pl.lit(9999)
dof_ns = pl.col(dof_ns_col) if dof_ns_col in df.columns else pl.lit(9999)
dof_union = pl.col(dof_union_col) if dof_union_col in df.columns else pl.lit(9999)
s_safe = dof_s.fill_null(9999)
ns_safe = dof_ns.fill_null(9999)
union_safe = dof_union.fill_null(9999)
use_s_dof = at_target_time & (s_safe > 1)
use_ns_dof = before_switch & (ns_safe > 1)
use_union_s = at_target_time & (s_safe == 1) & (union_safe >= 2)
use_union_ns = before_switch & (ns_safe == 1) & (union_safe >= 2)
df = df.with_columns(
pl.when(use_s_dof)
.then((dof_s / (dof_s - 1)).sqrt())
.when(use_ns_dof)
.then((dof_ns / (dof_ns - 1)).sqrt())
.when(use_union_s | use_union_ns)
.then((dof_union / (dof_union - 1)).sqrt())
.otherwise(pl.col(dof_col))
.alias(dof_col)
)
return df
[docs]
def compute_clustered_variance(influence_func, cluster_ids, n_groups):
r"""Compute clustered standard error from influence function.
Computes standard errors for :math:`\text{DID}_\ell` estimators using the
influence function approach with optional clustering. The variance is computed as
.. math::
\widehat{\text{Var}}(\text{DID}_\ell) = \frac{1}{G^2}
\sum_{c=1}^{C} \left(\sum_{g \in c} \psi_g\right)^2
where :math:`\psi_g` is the influence function for group :math:`g`, :math:`G`
is the total number of groups, and :math:`C` is the number of clusters. When
not clustering, each group is its own cluster.
Parameters
----------
influence_func : ndarray
Influence function values :math:`\psi_g` for each group.
cluster_ids : ndarray
Cluster identifiers for each group.
n_groups : int
Total number of groups :math:`G`.
Returns
-------
float
Clustered standard error.
References
----------
.. [1] de Chaisemartin, C., & D'Haultfoeuille, X. (2024). Difference-in-
Differences Estimators of Intertemporal Treatment Effects.
*Review of Economics and Statistics*, 106(6), 1723-1736.
"""
cluster_sums, unique_clusters = compute_cluster_sums(influence_func, cluster_ids)
n_clusters = len(unique_clusters)
if n_clusters <= 1:
return np.sqrt(np.sum(influence_func**2)) / n_groups
std_error = np.sqrt(np.sum(cluster_sums**2)) / n_groups
return std_error
[docs]
def compute_joint_test(estimates, vcov):
r"""Compute joint Wald test that all estimates are zero.
Computes a chi-squared test statistic for the null hypothesis
:math:`H_0: \delta_1 = \delta_2 = \cdots = \delta_L = 0`. For placebo
effects, this tests the parallel trends assumption by checking whether
pre-treatment outcome trends differ between switchers and non-switchers.
The test statistic is
.. math::
W = \hat{\boldsymbol{\delta}}' \widehat{\mathbf{V}}^{-1} \hat{\boldsymbol{\delta}}
\sim \chi^2_L
where :math:`\hat{\boldsymbol{\delta}}` is the vector of estimates and
:math:`\widehat{\mathbf{V}}` is the variance-covariance matrix.
Parameters
----------
estimates : ndarray
Point estimates :math:`\hat{\boldsymbol{\delta}}`.
vcov : ndarray
Variance-covariance matrix :math:`\widehat{\mathbf{V}}`.
Returns
-------
dict or None
Dictionary with chi2_stat, df, p_value, and warnings list,
or None if computation fails.
References
----------
.. [1] de Chaisemartin, C., & D'Haultfoeuille, X. (2024). Difference-in-
Differences Estimators of Intertemporal Treatment Effects.
*Review of Economics and Statistics*, 106(6), 1723-1736.
"""
if vcov is None:
return None
valid_mask = ~np.isnan(estimates)
if np.sum(valid_mask) < 1:
return None
valid_estimates = estimates[valid_mask]
valid_vcov = vcov[np.ix_(valid_mask, valid_mask)]
warnings_list = []
eigenvalues = np.linalg.eigvalsh(valid_vcov)
positive_eigenvalues = eigenvalues[eigenvalues > 1e-10]
if len(positive_eigenvalues) < len(valid_estimates):
warnings_list.append(
"The variance-covariance matrix of the effects tested is not invertible. The test cannot be computed."
)
return {
"chi2_stat": np.nan,
"df": len(valid_estimates),
"p_value": np.nan,
"warnings": warnings_list,
}
condition_ratio = positive_eigenvalues.max() / positive_eigenvalues.min()
if condition_ratio >= 1000:
warnings_list.append(
"The variance-covariance matrix of the effects tested is close "
f"to singular (condition ratio: {condition_ratio:.1f}). The chi-squared test "
"may be unreliable."
)
try:
chi2_stat = float(valid_estimates @ np.linalg.pinv(valid_vcov) @ valid_estimates)
df = len(valid_estimates)
p_value = 1 - stats.chi2.cdf(chi2_stat, df)
return {
"chi2_stat": chi2_stat,
"df": df,
"p_value": p_value,
"warnings": warnings_list,
}
except np.linalg.LinAlgError:
return None