"""Multi-period difference-in-differences group-time ATT computation."""
from __future__ import annotations
import warnings
from typing import NamedTuple
import numpy as np
import scipy.sparse as sp
import statsmodels.api as sm
from statsmodels.tools.sm_exceptions import PerfectSeparationError
from moderndid.core.parallel import parallel_map
from moderndid.core.preprocess import ControlGroup, DIDData, EstimationMethod
from moderndid.cupy.backend import get_backend
from moderndid.drdid.estimators.drdid_panel import drdid_panel
from moderndid.drdid.estimators.drdid_rc import drdid_rc
from moderndid.drdid.estimators.reg_did_panel import reg_did_panel
from moderndid.drdid.estimators.reg_did_rc import reg_did_rc
from moderndid.drdid.estimators.std_ipw_did_panel import std_ipw_did_panel
from moderndid.drdid.estimators.std_ipw_did_rc import std_ipw_did_rc
class ATTgtResult(NamedTuple):
    """Container for a single group-time ATT estimate.

    Attributes
    ----------
    att : float
        The estimated ATT for this group-time cell.
    group : float
        The treatment group identifier (first treatment period).
    year : float
        The time period for this estimate.
    post : int
        Indicator for whether this is a post-treatment period (1) or
        pre-treatment period (0).
    """

    #: Estimated ATT for this group-time cell.
    att: float
    #: Treatment group identifier (first treatment period).
    group: float
    #: Time period for this estimate.
    year: float
    #: Indicator for post-treatment (1) or pre-treatment (0).
    post: int
class ComputeATTgtResult(NamedTuple):
    """Container for compute_att_gt results.

    Attributes
    ----------
    attgt_list : list[ATTgtResult]
        List of group-time ATT results, one per (group, time) cell.
    influence_functions : scipy.sparse.csr_matrix
        Sparse matrix of influence functions with shape
        (n_units, n_group_time_cells).
    """

    #: List of group-time ATT results, one per (group, time) cell.
    attgt_list: list[ATTgtResult]
    #: Sparse matrix of influence functions.
    influence_functions: sp.csr_matrix
def compute_att_gt(data: DIDData, n_jobs=1):
    """Compute group-time average treatment effects.

    Every (treated-group index, time index) pair is dispatched to
    ``_process_gt_cell_did`` through ``parallel_map``. Cells for which the
    worker returns no estimate are discarded; the influence functions of the
    remaining cells are stacked column-wise into a sparse matrix.

    Parameters
    ----------
    data : DIDData
        Preprocessed DiD data object containing all necessary data and configuration.
    n_jobs : int, default=1
        Number of parallel jobs. 1 = sequential, -1 = all cores, >1 = that many workers.

    Returns
    -------
    ComputeATTgtResult
        NamedTuple with the list of retained ATT results and a sparse matrix
        holding one influence-function column per retained cell. When no cell
        is retained the matrix has shape ``(id_count, 0)``.
    """
    config = data.config
    n_units = config.id_count
    periods = config.time_periods

    # A non-universal base period yields one fewer time index per group.
    n_time_indices = len(periods) if config.base_period == "universal" else len(periods) - 1

    tasks = [
        (group_pos, period_pos, data)
        for group_pos in range(config.treated_groups_count)
        for period_pos in range(n_time_indices)
    ]
    outcomes = parallel_map(_process_gt_cell_did, tasks, n_jobs=n_jobs)

    # Keep estimate and influence function paired while dropping empty cells.
    retained = [(cell, inf_col) for cell, inf_col in outcomes if cell is not None]
    estimates = [cell for cell, _ in retained]
    inf_columns = [inf_col for _, inf_col in retained]

    if not inf_columns:
        return ComputeATTgtResult(
            attgt_list=estimates,
            influence_functions=sp.csr_matrix((n_units, 0)),
        )

    backend = get_backend()
    if backend is np:
        stacked = np.column_stack(inf_columns)
        sparse_inf = sp.csr_matrix(stacked)
    else:
        try:
            import cupyx.scipy.sparse as cusp

            stacked = backend.column_stack([backend.asarray(col) for col in inf_columns])
            sparse_inf = cusp.csr_matrix(stacked)
        except ImportError:
            # Fall back to the NumPy/SciPy representation when cupyx is unavailable.
            stacked = np.column_stack(inf_columns)
            sparse_inf = sp.csr_matrix(stacked)

    return ComputeATTgtResult(attgt_list=estimates, influence_functions=sparse_inf)
def run_att_gt_estimation(
    group_idx,
    time_idx,
    data,
):
    """Run ATT estimation for a given group-time pair.

    The function selects the comparison (pre-treatment) period, builds the
    treated/control membership vector via ``get_did_cohort_index``, assembles
    the outcome data for the 2x2 comparison, runs a few feasibility checks for
    the built-in estimation methods, and finally delegates to ``run_drdid``.

    Parameters
    ----------
    group_idx : int
        Index of the treated group.
    time_idx : int
        Index of the time period.
    data : DIDData
        Preprocessed DiD data object.

    Returns
    -------
    dict or None
        Dictionary with ATT and influence function (as returned by
        ``run_drdid``), or None if estimation is not feasible: no usable
        pre-treatment period, the cell is the universal reference period,
        a treated/control cell is empty, the overlap or regression
        pre-checks fail, or the estimator raises.

    Warns
    -----
    UserWarning
        Emitted whenever a cell is skipped for a data-driven reason or the
        estimator raises ``ValueError``, ``RuntimeError`` or ``LinAlgError``.
    """
    config = data.config
    universal = config.base_period == "universal"
    # With a varying base period, time_idx refers to the period *before* the
    # one being estimated, hence the +1 offset.
    time_factor = 0 if universal else 1

    g_val = config.treated_groups[group_idx]
    t_val = config.time_periods[time_idx + time_factor]
    is_post_treatment = g_val <= t_val

    if universal or is_post_treatment:
        # Base period is the last period strictly before (first treatment - anticipation).
        pre_periods = np.where(config.time_periods < (g_val - config.anticipation))[0]
        if len(pre_periods) == 0:
            # Without this guard the universal case would continue with
            # pre_treatment_idx=None and fail on the index arithmetic below.
            warnings.warn(
                f"No pre-treatment periods for group first treated at {g_val}. "
                "Units from this group are dropped.",
                UserWarning,
            )
            return None
        pre_treatment_idx = pre_periods[-1]
    else:
        # Varying base period, pre-treatment cell: compare with the preceding period.
        pre_treatment_idx = time_idx

    # Under a universal base period the reference period has no estimate of its own.
    if universal and config.time_periods[pre_treatment_idx] == t_val:
        return None

    cohort_index = get_did_cohort_index(group_idx, time_idx, time_factor, pre_treatment_idx, data)
    if not (np.any(cohort_index == 1) and np.any(cohort_index == 0)):
        return None

    if config.panel:
        cohort_data = {
            "D": cohort_index,
            "y1": data.outcomes_tensor[time_idx + time_factor],
            "y0": data.outcomes_tensor[pre_treatment_idx],
            "weights": data.weights,
        }
        covariates = data.covariates_tensor[min(pre_treatment_idx, time_idx)]
    else:
        post_mask = (data.data[config.tname] == t_val).to_numpy()
        cohort_data = {
            "D": cohort_index,
            "y": data.data[config.yname].to_numpy(),
            "post": post_mask.astype(int),
            "weights": data.data["weights"].to_numpy(),
        }
        if config.allow_unbalanced_panel:
            cohort_data["rowid"] = data.data[".rowid"].to_numpy()
        covariates = data.covariates_matrix

    # Feasibility pre-checks only apply to the built-in estimation methods;
    # a user-supplied callable is invoked without them.
    if not callable(config.est_method):
        valid_obs = ~np.isnan(cohort_index)
        d_valid = cohort_index[valid_obs]
        treated = d_valid == 1
        control = d_valid == 0

        if not config.panel:
            # Each of the four (treated/control) x (pre/post) cells must be populated.
            post = cohort_data["post"][valid_obs]
            rc_cells = (
                (treated & (post == 1), f"No treated units in group {g_val} in time period {t_val}"),
                (treated & (post == 0), f"No treated units in group {g_val} in pre-treatment period"),
                (control & (post == 1), f"No control units for group {g_val} in time period {t_val}"),
                (control & (post == 0), f"No control units for group {g_val} in pre-treatment period"),
            )
            empty_cell = False
            for cell_mask, message in rc_cells:
                if not cell_mask.any():
                    warnings.warn(message, UserWarning)
                    empty_cell = True
            if empty_cell:
                return None

        uses_pscore = config.est_method in (EstimationMethod.DOUBLY_ROBUST, EstimationMethod.IPW)
        uses_outcome_reg = config.est_method in (EstimationMethod.DOUBLY_ROBUST, EstimationMethod.REGRESSION)
        overlap_violated = False
        reg_ill_conditioned = False

        if uses_pscore or uses_outcome_reg:
            # 1-D covariates are replaced by a single column of ones for these checks.
            cov_valid = covariates[valid_obs] if covariates.ndim > 1 else np.ones((valid_obs.sum(), 1))

        if uses_pscore:
            try:
                logit_model = sm.GLM(treated.astype(float), cov_valid, family=sm.families.Binomial())
                logit_result = logit_model.fit(maxiter=100, disp=False)
                # A fitted propensity score at or above 0.999 is treated as an overlap violation.
                if np.max(logit_result.fittedvalues) >= 0.999:
                    overlap_violated = True
                    warnings.warn(
                        f"Overlap condition violated for {g_val} in time period {t_val}",
                        UserWarning,
                    )
            except (np.linalg.LinAlgError, PerfectSeparationError):
                # Deliberate best effort: if this diagnostic fit itself fails,
                # the cell is not flagged here and estimation is still attempted.
                pass

        if uses_outcome_reg:
            control_covs = cov_valid[control]
            if control_covs.shape[0] > 0:
                gram = control_covs.T @ control_covs
                try:
                    # Flag a Gram matrix whose condition number exceeds 1/machine-epsilon.
                    if np.linalg.cond(gram) > 1.0 / np.finfo(float).eps:
                        reg_ill_conditioned = True
                        warnings.warn(
                            f"Not enough control units for group {g_val} in time period {t_val} "
                            "to run specified regression",
                            UserWarning,
                        )
                except np.linalg.LinAlgError:
                    reg_ill_conditioned = True
                    warnings.warn(
                        f"Singular covariate matrix for group {g_val} in time period {t_val}",
                        UserWarning,
                    )

        if overlap_violated or reg_ill_conditioned:
            return None

    try:
        return run_drdid(cohort_data, covariates, data)
    except (ValueError, RuntimeError, np.linalg.LinAlgError) as e:
        # Report the period actually being estimated (time_idx + time_factor).
        warnings.warn(
            f"Error in computing 2x2 DiD for (g,t) = ({g_val},{t_val}): {e}",
            UserWarning,
        )
        return None
def get_did_cohort_index(
    group_idx,
    time_idx,
    time_factor,
    pre_treatment_idx,
    data,
):
    """Get cohort indices for current group-time pair.

    Parameters
    ----------
    group_idx : int
        Index of the treated group.
    time_idx : int
        Index of the time period.
    time_factor : int
        Time factor (1 for varying base period, 0 for universal).
    pre_treatment_idx : int
        Index of the pre-treatment period.
    data : DIDData
        Preprocessed DiD data object.

    Returns
    -------
    np.ndarray
        Array of 1s (treated), 0s (control), and NaNs indicating cohort membership.
        For panel data there is one entry per unit; otherwise one entry per
        row of ``data.data``.
    """
    config = data.config
    target_cohort = config.treated_groups[group_idx]

    if not config.panel:
        # Row-level membership: only rows observed in one of the two
        # compared periods receive a 0/1 label.
        frame = data.data
        group_col = frame[config.gname]
        membership = np.full(len(frame), np.nan)

        treated_rows = (group_col == target_cohort).to_numpy()
        if config.control_group == ControlGroup.NEVER_TREATED:
            control_rows = (group_col == np.inf).to_numpy()
        else:
            # Not-yet-treated: never-treated rows plus other cohorts first
            # treated after the later compared period (shifted by anticipation).
            horizon = config.time_periods[max(time_idx, pre_treatment_idx) + time_factor] + config.anticipation
            later_cohort = (group_col > horizon) & (group_col != target_cohort)
            control_rows = ((group_col == np.inf) | later_cohort).to_numpy()

        compared_periods = [
            config.time_periods[time_idx + time_factor],
            config.time_periods[pre_treatment_idx],
        ]
        in_window = frame[config.tname].is_in(compared_periods).to_numpy()

        membership[in_window & control_rows] = 0
        membership[in_window & treated_rows] = 1
        return membership

    cohorts = data.cohort_counts["cohort"]

    # Lower bound of the control cohorts.
    if config.control_group == ControlGroup.NOT_YET_TREATED:
        horizon = config.time_periods[max(time_idx, pre_treatment_idx) + time_factor] + config.anticipation
        later = data.cohort_counts.filter(cohorts > horizon)
        lower_bound = later["cohort"][0] if len(later) > 0 else np.inf
    else:
        lower_bound = np.inf

    unit_total = len(data.time_invariant_data) if config.allow_unbalanced_panel else config.id_count
    membership = np.full(unit_total, np.nan)

    # Upper bound: infinity when that cohort exists, else the last listed cohort.
    cohort_values = cohorts.to_numpy()
    upper_bound = np.inf if np.inf in cohort_values else cohort_values[-1]

    # Units are addressed as contiguous runs, one run per cohort, in the
    # order of data.cohort_counts; run_ends[k] is one past cohort k's last unit.
    run_ends = np.cumsum(data.cohort_counts["cohort_size"].to_numpy())

    def _unit_run(position):
        first = int(run_ends[position - 1]) if position > 0 else 0
        return slice(first, int(run_ends[position]))

    is_control = ((cohorts >= lower_bound) & (cohorts <= upper_bound)).to_numpy()
    for position in np.flatnonzero(is_control):
        membership[_unit_run(position)] = 0

    # Treated labels are written last, so they override a control label.
    is_treated = (cohorts == target_cohort).to_numpy()
    if is_treated.any():
        membership[_unit_run(int(np.argmax(is_treated)))] = 1

    return membership
def run_drdid(
cohort_data,
covariates,
data,
):
"""Run DR-DiD estimation for current group-time pair.
Parameters
----------
cohort_data : dict
Dictionary containing outcome and treatment data for the cohort.
covariates : ndarray
Covariate matrix for the estimation.
data : DIDData
Preprocessed DiD data object.
Returns
-------
dict
Dictionary with ATT estimate and influence function.
"""
n = len(cohort_data["D"])
est_method = data.config.est_method
valid_obs = ~np.isnan(cohort_data["D"])
if valid_obs.sum() == 0:
return {"att": np.nan, "inf_func": np.full(n, np.nan)}
if data.config.panel:
y1 = cohort_data["y1"][valid_obs]
y0 = cohort_data["y0"][valid_obs]
d = cohort_data["D"][valid_obs]
weights = cohort_data["weights"][valid_obs]
if covariates.ndim > 1:
cov_valid = covariates[valid_obs]
else:
cov_valid = covariates[valid_obs] if len(covariates) > 1 else np.ones(valid_obs.sum())
if callable(est_method):
result = est_method(y1=y1, y0=y0, d=d, covariates=cov_valid, i_weights=weights, influence_func=True)
elif est_method == EstimationMethod.IPW:
result = std_ipw_did_panel(
y1=y1, y0=y0, d=d, covariates=cov_valid, i_weights=weights, boot=False, influence_func=True
)
elif est_method == EstimationMethod.REGRESSION:
result = reg_did_panel(
y1=y1, y0=y0, d=d, covariates=cov_valid, i_weights=weights, boot=False, influence_func=True
)
else: # DOUBLY_ROBUST (default)
result = drdid_panel(
y1=y1, y0=y0, d=d, covariates=cov_valid, i_weights=weights, boot=False, influence_func=True
)
influence_func = np.zeros(n)
influence_func[valid_obs] = (n / valid_obs.sum()) * result.att_inf_func
else:
y = cohort_data["y"][valid_obs]
post = cohort_data["post"][valid_obs]
d = cohort_data["D"][valid_obs]
weights = cohort_data["weights"][valid_obs]
if covariates.ndim > 1:
cov_valid = covariates[valid_obs]
else:
cov_valid = covariates if len(covariates) == n else covariates[valid_obs]
if callable(est_method):
result = est_method(y=y, post=post, d=d, covariates=cov_valid, i_weights=weights, influence_func=True)
elif est_method == EstimationMethod.IPW:
result = std_ipw_did_rc(
y=y, post=post, d=d, covariates=cov_valid, i_weights=weights, boot=False, influence_func=True
)
elif est_method == EstimationMethod.REGRESSION:
result = reg_did_rc(
y=y, post=post, d=d, covariates=cov_valid, i_weights=weights, boot=False, influence_func=True
)
else: # DOUBLY_ROBUST (default)
result = drdid_rc(
y=y, post=post, d=d, covariates=cov_valid, i_weights=weights, boot=False, influence_func=True
)
# Handle influence function for unbalanced panel
if data.config.allow_unbalanced_panel and "rowid" in cohort_data:
inf_func_long = np.zeros(n)
inf_func_long[valid_obs] = (data.config.id_count / valid_obs.sum()) * result.att_inf_func
unique_ids = np.unique(cohort_data["rowid"])
influence_func = np.zeros(len(unique_ids))
for i, uid in enumerate(unique_ids):
mask = cohort_data["rowid"] == uid
influence_func[i] = inf_func_long[mask].sum()
else:
influence_func = np.zeros(n)
influence_func[valid_obs] = (n / valid_obs.sum()) * result.att_inf_func
if np.isnan(result.att):
return {"att": np.nan, "inf_func": np.full(n, np.nan)}
return {"att": result.att, "inf_func": influence_func}
def _process_gt_cell_did(group_idx, time_idx, data):
    """Process a single (group, time) cell for DiD estimation.

    Calls ``run_att_gt_estimation`` and packages its output. When no estimate
    is available, a varying base period yields ``(None, None)``; a universal
    base period yields a placeholder cell: ATT 0 with a zero influence
    function for the reference period, NaN otherwise.

    Parameters
    ----------
    group_idx : int
        Index of the treated group.
    time_idx : int
        Index of the time period.
    data : DIDData
        Preprocessed DiD data object.

    Returns
    -------
    tuple of (ATTgtResult or None, ndarray or None)
    """
    config = data.config
    n_units = config.id_count
    universal = config.base_period == "universal"
    offset = 0 if universal else 1

    estimate = run_att_gt_estimation(group_idx, time_idx, data)

    cohort = config.treated_groups[group_idx]
    period = config.time_periods[time_idx + offset]
    post_flag = int(cohort <= period)

    def _cell(att_value):
        return ATTgtResult(att=att_value, group=cohort, year=period, post=post_flag)

    if estimate is None or estimate["att"] is None:
        if not universal:
            return (None, None)

        # Reference period = last period before (first treatment - anticipation).
        earlier = np.where(config.time_periods < (cohort - config.anticipation))[0]
        is_reference = len(earlier) > 0 and config.time_periods[earlier[-1]] == period
        if is_reference:
            return (_cell(0.0), np.zeros(n_units))
        return (_cell(np.nan), np.full(n_units, np.nan))

    att_value = estimate["att"]
    if np.isnan(att_value):
        # A NaN estimate always carries an all-NaN influence function of length id_count.
        inf_values = np.full(n_units, np.nan)
    else:
        inf_values = estimate["inf_func"]

    return (_cell(att_value), inf_values)