"""Plotnine-based plotting functions for moderndid."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal
from plotnine import (
aes,
element_text,
facet_wrap,
geom_errorbar,
geom_hline,
geom_line,
geom_point,
geom_ribbon,
geom_vline,
ggplot,
labs,
position_dodge,
scale_color_manual,
scale_x_continuous,
theme,
theme_gray,
)
from moderndid.core.converters import (
aggteresult_to_polars,
dddaggresult_to_polars,
dddmpresult_to_polars,
didinterresult_to_polars,
doseresult_to_polars,
honestdid_to_polars,
mpresult_to_polars,
pteresult_to_polars,
)
from moderndid.did.container import AGGTEResult, MPResult
from moderndid.didinter.container import DIDInterResult
from moderndid.didtriple.container import DDDAggResult, DDDMultiPeriodRCResult, DDDMultiPeriodResult
from moderndid.plots.themes import COLORS
if TYPE_CHECKING:
from moderndid.didcont.container import DoseResult, PTEResult
from moderndid.didhonest.honest_did import HonestDiDResult
[docs]
def plot_gt(
    result: MPResult | DDDMultiPeriodResult | DDDMultiPeriodRCResult,
    show_ci: bool = True,
    ref_line: float | None = 0,
    title: str = "Group",
    xlab: str | None = None,
    ylab: str | None = None,
    ncol: int = 1,
    **_kwargs: Any,
) -> ggplot:
    """Plot group-time average treatment effects.

    Draws one facet panel per treatment group. Panels are ordered by the
    numeric group value, so group 2 precedes group 10. Points are coloured by
    their ``treatment_status`` ("Pre" or "Post").

    Parameters
    ----------
    result : MPResult, DDDMultiPeriodResult, or DDDMultiPeriodRCResult
        Multi-period result object containing group-time ATT estimates.
        This should be the output from ``att_gt()`` or ``ddd()``.
    show_ci : bool, default=True
        Whether to show confidence intervals as error bars.
    ref_line : float or None, default=0
        Y-value for reference line. Set to None to hide.
    title : str, default="Group"
        Title prefix for each facet panel (e.g. ``"Group 3"``).
    xlab : str, optional
        X-axis label. Defaults to "Time".
    ylab : str, optional
        Y-axis label. Defaults to "ATT".
    ncol : int, default=1
        Number of columns in the facet grid. Use 1 for vertical stacking.
    **_kwargs
        Accepted for a uniform plotting interface; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.

    Raises
    ------
    TypeError
        If ``result`` is not one of the supported result types.
    """
    if isinstance(result, MPResult):
        df = mpresult_to_polars(result)
        plot_title = "Group-Time Average Treatment Effects"
    elif isinstance(result, (DDDMultiPeriodResult, DDDMultiPeriodRCResult)):
        df = dddmpresult_to_polars(result)
        plot_title = "Group-Time DDD Treatment Effects"
    else:
        raise TypeError(
            f"plot_gt requires MPResult, DDDMultiPeriodResult, or DDDMultiPeriodRCResult, got {type(result).__name__}"
        )

    # Kept in the plot data so callers that re-facet or relabel the returned
    # plot on ``group_label`` keep working.
    df = df.with_columns([df["group"].cast(int).cast(str).alias("group_label")])
    x_breaks = sorted(df["time"].unique().to_list())

    def _group_strip_text(value: str) -> str:
        # Facets are keyed on the numeric ``group`` column so panels sort
        # numerically (string labels would put "10" before "2"). plotnine
        # hands the labeller the stringified value, which may look like "3"
        # or "3.0", so normalise it to an integer for the strip text.
        return f"{title} {int(float(value))}"

    plot = (
        ggplot(df, aes(x="time", y="att", color="treatment_status"))
        + geom_point(size=3, alpha=0.8)
        + scale_color_manual(
            values={"Pre": COLORS["pre_treatment"], "Post": COLORS["post_treatment"]},
            limits=["Pre", "Post"],
            name="Treatment Status",
        )
        + scale_x_continuous(breaks=x_breaks)
        + facet_wrap("~group", ncol=ncol, labeller=_group_strip_text, scales="free_x")
        + labs(
            x=xlab or "Time",
            y=ylab or "ATT",
            title=plot_title,
        )
        + theme_gray()
        + theme(
            strip_text=element_text(size=11, weight="bold"),
            plot_title=element_text(margin={"b": 25}),
            legend_position="bottom",
        )
    )

    if show_ci:
        # Error bars inherit the colour mapping from the base aesthetics.
        plot = plot + geom_errorbar(
            aes(ymin="ci_lower", ymax="ci_upper"),
            width=0.3,
            alpha=0.7,
        )

    if ref_line is not None:
        plot = plot + geom_hline(yintercept=ref_line, linetype="dashed", color="black", alpha=0.5)

    return plot
[docs]
def plot_event_study(
    result: AGGTEResult | PTEResult | DDDAggResult,
    show_ci: bool = True,
    ref_line: float | None = 0,
    ref_period: float | None = -1,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str | None = None,
    **_kwargs: Any,
) -> ggplot:
    """Build an event-study plot of dynamic treatment effects.

    Estimates are placed at their ``event_time`` and coloured by
    ``treatment_status`` ("Pre" or "Post"). When ``ref_period`` is given, a
    dashed vertical line marks it and it is forced into the x-axis breaks.
    When ``ref_period`` is None, a dotted line connecting the estimates is
    drawn instead.

    Parameters
    ----------
    result : AGGTEResult, PTEResult, or DDDAggResult
        An ``AGGTEResult`` with ``"dynamic"`` aggregation, a ``DDDAggResult``
        with ``"eventstudy"`` aggregation, or any other object exposing an
        ``event_study`` attribute (such as ``PTEResult``).
    show_ci : bool, default=True
        Whether to show confidence intervals as error bars.
    ref_line : float or None, default=0
        Y-value for the horizontal reference line. Set to None to hide.
    ref_period : float or None, default=-1
        X-value for the vertical reference period line. Set to None to hide.
    xlab : str, optional
        X-axis label. Defaults to "Event Time".
    ylab : str, optional
        Y-axis label. Defaults to "ATT".
    title : str, optional
        Plot title. Defaults to "DDD Event Study" for ``DDDAggResult`` and
        "Event Study" otherwise.
    **_kwargs
        Accepted for a uniform plotting interface; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.

    Raises
    ------
    ValueError
        If an aggregated result was not built with event-time aggregation.
    TypeError
        If ``result`` is none of the supported types.
    """
    is_ddd_agg = isinstance(result, DDDAggResult)
    is_agg = isinstance(result, AGGTEResult)

    if hasattr(result, "event_study") and not (is_agg or is_ddd_agg):
        # Duck-typed path for PTEResult-like objects.
        frame = pteresult_to_polars(result)
        fallback_title = "Event Study"
    elif is_ddd_agg:
        if result.aggregation_type != "eventstudy":
            raise ValueError(f"Event study plot requires eventstudy aggregation, got {result.aggregation_type}")
        frame = dddaggresult_to_polars(result)
        fallback_title = "DDD Event Study"
    elif is_agg:
        if result.aggregation_type != "dynamic":
            raise ValueError(f"Event study plot requires dynamic aggregation, got {result.aggregation_type}")
        frame = aggteresult_to_polars(result)
        fallback_title = "Event Study"
    else:
        raise TypeError(
            f"plot_event_study requires AGGTEResult, PTEResult, or DDDAggResult, got {type(result).__name__}"
        )

    plot = ggplot(frame, aes(x="event_time", y="att"))

    components = []
    if show_ci:
        components.append(
            geom_errorbar(
                aes(ymin="ci_lower", ymax="ci_upper", color="treatment_status"),
                width=0.2,
                size=0.8,
            )
        )
    if ref_line is not None:
        components.append(geom_hline(yintercept=ref_line, linetype="dashed", color="#7f8c8d", alpha=0.7))
    if ref_period is None:
        # Without a reference marker, connect the estimates with a dotted guide.
        components.append(geom_line(color=COLORS["line"], size=0.8, alpha=0.6, linetype="dotted"))
    else:
        components.append(geom_vline(xintercept=ref_period, linetype="dashed", color="gray", size=0.4))

    tick_values = frame["event_time"].unique().to_list()
    if ref_period is not None and ref_period not in tick_values:
        # Make sure the reference period gets its own labelled tick.
        tick_values.append(ref_period)
    tick_values.sort()

    components.extend(
        [
            geom_point(aes(color="treatment_status"), size=3.5),
            scale_color_manual(
                values={"Pre": COLORS["pre_treatment"], "Post": COLORS["post_treatment"]},
                limits=["Pre", "Post"],
                name="Treatment Status",
            ),
            scale_x_continuous(breaks=tick_values),
            labs(
                x=xlab or "Event Time",
                y=ylab or "ATT",
                title=title or fallback_title,
            ),
            theme_gray(),
            theme(legend_position="bottom"),
        ]
    )

    for component in components:
        plot = plot + component
    return plot
[docs]
def plot_agg(
    result: AGGTEResult | DDDAggResult,
    show_ci: bool = True,
    ref_line: float | None = 0,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str | None = None,
    **_kwargs: Any,
) -> ggplot:
    """Plot treatment effects aggregated by treatment cohort or calendar time.

    Estimates are drawn at the ``event_time`` column of the converted frame,
    joined by a dotted guide line. Error bars and points share the
    post-treatment colour.

    Parameters
    ----------
    result : AGGTEResult or DDDAggResult
        Aggregated treatment effect result with ``"group"`` or ``"calendar"``
        aggregation.
    show_ci : bool, default=True
        Whether to show confidence intervals as error bars.
    ref_line : float or None, default=0
        Y-value for the horizontal reference line. Set to None to hide.
    xlab : str, optional
        X-axis label. Defaults to "Treatment Cohort" for group aggregation and
        "Calendar Time" for calendar aggregation.
    ylab : str, optional
        Y-axis label. Defaults to "ATT".
    title : str, optional
        Plot title. Defaults based on aggregation type, prefixed with "DDD"
        for ``DDDAggResult``.
    **_kwargs
        Accepted for a uniform plotting interface; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.

    Raises
    ------
    TypeError
        If ``result`` is neither an ``AGGTEResult`` nor a ``DDDAggResult``.
    ValueError
        If the aggregation type is not ``"group"`` or ``"calendar"``.
    """
    if isinstance(result, DDDAggResult):
        is_ddd, dynamic_name = True, "eventstudy"
    elif isinstance(result, AGGTEResult):
        is_ddd, dynamic_name = False, "dynamic"
    else:
        raise TypeError(f"plot_agg requires AGGTEResult or DDDAggResult, got {type(result).__name__}")

    if result.aggregation_type not in ("group", "calendar"):
        # Point users at the event-study plot, naming the aggregation label
        # their result type actually uses.
        raise ValueError(
            f"plot_agg requires group or calendar aggregation, got {result.aggregation_type}. "
            f"Use plot_event_study for {dynamic_name} aggregation."
        )

    frame = dddaggresult_to_polars(result) if is_ddd else aggteresult_to_polars(result)

    by_cohort = result.aggregation_type == "group"
    default_xlab = "Treatment Cohort" if by_cohort else "Calendar Time"
    subject = "Effects by Treatment Cohort" if by_cohort else "Effects by Calendar Time"
    default_title = f"DDD {subject}" if is_ddd else subject

    plot = ggplot(frame, aes(x="event_time", y="att"))
    if show_ci:
        plot = plot + geom_errorbar(
            aes(ymin="ci_lower", ymax="ci_upper"),
            width=0.2,
            size=0.8,
            color=COLORS["post_treatment"],
        )
    if ref_line is not None:
        plot = plot + geom_hline(yintercept=ref_line, linetype="dashed", color="#7f8c8d", alpha=0.7)

    tick_values = sorted(frame["event_time"].unique().to_list())
    for component in (
        geom_line(color=COLORS["line"], size=0.8, alpha=0.6, linetype="dotted"),
        geom_point(color=COLORS["post_treatment"], size=3.5),
        scale_x_continuous(breaks=tick_values),
        labs(
            x=xlab or default_xlab,
            y=ylab or "ATT",
            title=title or default_title,
        ),
        theme_gray(),
    ):
        plot = plot + component
    return plot
[docs]
def plot_dose_response(
    result: DoseResult,
    effect_type: Literal["att", "acrt"] = "att",
    show_ci: bool = True,
    ref_line: float | None = 0,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str | None = None,
    **_kwargs: Any,
) -> ggplot:
    """Plot the dose-response curve of a continuous-treatment estimate.

    The estimated effect is drawn as a line with points over ``dose``. An
    optional shaded band shows the ``ci_lower``/``ci_upper`` interval.

    Parameters
    ----------
    result : DoseResult
        Continuous treatment dose-response result.
    effect_type : {'att', 'acrt'}, default='att'
        Which effect to plot:

        - 'att': Average Treatment Effect on Treated
        - 'acrt': Average Causal Response on Treated (marginal effect)
    show_ci : bool, default=True
        Whether to show confidence bands.
    ref_line : float or None, default=0
        Y-value for the horizontal reference line. Set to None to hide.
    xlab : str, optional
        X-axis label. Defaults to "Dose".
    ylab : str, optional
        Y-axis label. Defaults to "ATT(d)" for 'att' and "ACRT(d)" otherwise.
    title : str, optional
        Plot title. Defaults to "Dose-Response: " followed by the default
        y-axis label.
    **_kwargs
        Accepted for a uniform plotting interface; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.
    """
    frame = doseresult_to_polars(result, effect_type=effect_type)

    curve_name = "ATT(d)" if effect_type == "att" else "ACRT(d)"
    curve_color, band_color = "#2c3e50", "#95a5a6"

    plot = ggplot(frame, aes(x="dose", y="effect"))
    if show_ci:
        plot = plot + geom_ribbon(
            aes(ymin="ci_lower", ymax="ci_upper"),
            fill=band_color,
            alpha=0.25,
        )
    if ref_line is not None:
        plot = plot + geom_hline(yintercept=ref_line, linetype="dashed", color="#7f8c8d", alpha=0.7)

    for component in (
        geom_line(color=curve_color, size=1.2),
        geom_point(color=curve_color, size=2.5),
        labs(
            x=xlab or "Dose",
            y=ylab or curve_name,
            title=title or f"Dose-Response: {curve_name}",
        ),
        theme_gray(),
    ):
        plot = plot + component
    return plot
[docs]
def plot_sensitivity(
    result: HonestDiDResult,
    ref_line: float | None = 0,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str | None = None,
    **_kwargs: Any,
) -> ggplot:
    """Create sensitivity analysis plot for HonestDiD results.

    Each method's interval (``lb`` to ``ub``) is drawn as an error bar with a
    point at its ``midpoint`` for every sensitivity parameter value. When more
    than one method is present, the methods are dodged horizontally so their
    intervals do not overlap.

    Parameters
    ----------
    result : HonestDiDResult
        Honest DiD sensitivity analysis result.
    ref_line : float or None, default=0
        Y-value for the horizontal reference line. Set to None to hide.
    xlab : str, optional
        X-axis label. Defaults to "M" for smoothness restrictions and
        ``$\\bar{M}$`` otherwise.
    ylab : str, optional
        Y-axis label. Defaults to "Confidence Interval".
    title : str, optional
        Plot title. Defaults to "Sensitivity Analysis (<Sensitivity Type>)".
    **_kwargs
        Accepted for a uniform plotting interface; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.
    """
    df = honestdid_to_polars(result)

    is_smoothness = result.sensitivity_type == "smoothness"
    default_xlab = "M" if is_smoothness else r"$\bar{M}$"
    default_title = f"Sensitivity Analysis ({result.sensitivity_type.replace('_', ' ').title()})"

    method_colors = {
        "Original": COLORS["original"],
        "FLCI": COLORS["flci"],
        "Conditional": COLORS["conditional"],
        "C-F": COLORS["c_f"],
        "C-LF": COLORS["c_lf"],
    }
    methods = df["method"].unique().to_list()
    # Methods without a dedicated palette entry fall back to a neutral slate.
    available_colors = {m: method_colors.get(m, "#34495e") for m in methods}
    n_methods = len(methods)

    # Horizontal extent of the parameter grid. It sizes both the dodge offset
    # and the error-bar caps. A single distinct value has zero extent, and an
    # empty frame has no min/max (polars returns None). Either would collapse
    # the dodge and caps to nothing (or fail on None arithmetic), stacking
    # every method's interval on the same x, so fall back to a unit span.
    lo = df["param_value"].min()
    hi = df["param_value"].max()
    x_span = float(hi - lo) if lo is not None and hi is not None and hi > lo else 1.0

    dodge_width = 0.05 * x_span if n_methods > 1 else 0

    plot = (
        ggplot(df, aes(x="param_value", y="midpoint", color="method"))
        + geom_point(size=3, position=position_dodge(width=dodge_width))
        + geom_errorbar(
            aes(ymin="lb", ymax="ub"),
            width=0.02 * x_span,
            size=0.8,
            position=position_dodge(width=dodge_width),
        )
        + scale_color_manual(values=available_colors, name="Method")
        + labs(
            x=xlab or default_xlab,
            y=ylab or "Confidence Interval",
            title=title or default_title,
        )
        + theme_gray()
        + theme(legend_position="bottom")
    )

    if ref_line is not None:
        plot = plot + geom_hline(yintercept=ref_line, linetype="solid", color="black", alpha=0.4)

    return plot
[docs]
def plot_multiplegt(
    result: DIDInterResult,
    show_ci: bool = True,
    ref_line: float | None = 0,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str | None = None,
    **_kwargs: Any,
) -> ggplot:
    """Build an event-study style plot of intertemporal treatment effects.

    Estimates are placed at their ``horizon`` and coloured by
    ``treatment_status`` ("Pre" or "Post"). A dashed vertical line always
    marks horizon 0, and 0 is always included among the x-axis breaks.

    Parameters
    ----------
    result : DIDInterResult
        Intertemporal treatment effects result from did_multiplegt().
    show_ci : bool, default=True
        Whether to show confidence intervals as error bars.
    ref_line : float or None, default=0
        Y-value for the horizontal reference line. Set to None to hide.
    xlab : str, optional
        X-axis label. Defaults to "Horizon".
    ylab : str, optional
        Y-axis label. Defaults to "Effect".
    title : str, optional
        Plot title. Defaults to "Intertemporal Treatment Effects".
    **_kwargs
        Accepted for a uniform plotting interface; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.
    """
    frame = didinterresult_to_polars(result)

    horizon_breaks = sorted(frame["horizon"].unique().to_list())
    if 0 not in horizon_breaks:
        # Horizon 0 always gets a tick so the vertical marker is labelled.
        horizon_breaks.append(0)
        horizon_breaks.sort()

    plot = ggplot(frame, aes(x="horizon", y="att"))

    layers = []
    if show_ci:
        layers.append(
            geom_errorbar(
                aes(ymin="ci_lower", ymax="ci_upper", color="treatment_status"),
                width=0.2,
                size=0.8,
            )
        )
    if ref_line is not None:
        layers.append(geom_hline(yintercept=ref_line, linetype="dashed", color="#7f8c8d", alpha=0.7))
    layers.append(geom_vline(xintercept=0, linetype="dashed", color="gray", size=0.4))
    layers.extend(
        [
            geom_point(aes(color="treatment_status"), size=3.5),
            scale_color_manual(
                values={"Pre": COLORS["pre_treatment"], "Post": COLORS["post_treatment"]},
                limits=["Pre", "Post"],
                name="Treatment Status",
            ),
            scale_x_continuous(breaks=horizon_breaks),
            labs(
                x=xlab or "Horizon",
                y=ylab or "Effect",
                title=title or "Intertemporal Treatment Effects",
            ),
            theme_gray(),
            theme(legend_position="bottom"),
        ]
    )

    for layer in layers:
        plot = plot + layer
    return plot