Source code for moderndid.plots.plots

"""Plotnine-based plotting functions for moderndid."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

from plotnine import (
    aes,
    element_text,
    facet_wrap,
    geom_errorbar,
    geom_hline,
    geom_line,
    geom_point,
    geom_ribbon,
    geom_vline,
    ggplot,
    labs,
    position_dodge,
    scale_color_manual,
    scale_x_continuous,
    theme,
    theme_gray,
)

from moderndid.core.converters import (
    aggteresult_to_polars,
    dddaggresult_to_polars,
    dddmpresult_to_polars,
    didinterresult_to_polars,
    doseresult_to_polars,
    honestdid_to_polars,
    mpresult_to_polars,
    pteresult_to_polars,
)
from moderndid.did.container import AGGTEResult, MPResult
from moderndid.didinter.container import DIDInterResult
from moderndid.didtriple.container import DDDAggResult, DDDMultiPeriodRCResult, DDDMultiPeriodResult
from moderndid.plots.themes import COLORS

if TYPE_CHECKING:
    from moderndid.didcont.container import DoseResult, PTEResult
    from moderndid.didhonest.honest_did import HonestDiDResult


def plot_gt(
    result: MPResult | DDDMultiPeriodResult | DDDMultiPeriodRCResult,
    show_ci: bool = True,
    ref_line: float | None = 0,
    title: str = "Group",
    xlab: str | None = None,
    ylab: str | None = None,
    ncol: int = 1,
    **_kwargs: Any,
) -> ggplot:
    """Plot group-time average treatment effects.

    Draws one facet per treatment group. Each group-time estimate is a point
    colored by its pre/post treatment status.

    Parameters
    ----------
    result : MPResult, DDDMultiPeriodResult, or DDDMultiPeriodRCResult
        Multi-period result object containing group-time ATT estimates.
        This should be the output from ``att_gt()`` or ``ddd()``.
    show_ci : bool, default=True
        Whether to show confidence intervals as error bars.
    ref_line : float or None, default=0
        Y-value for a dashed horizontal reference line. Set to None to hide.
    title : str, default="Group"
        Prefix for each facet strip label, rendered as ``"<title> <group>"``.
    xlab : str, optional
        X-axis label. Defaults to "Time".
    ylab : str, optional
        Y-axis label. Defaults to "ATT".
    ncol : int, default=1
        Number of columns in the facet grid. Use 1 for vertical stacking.
    **_kwargs : Any
        Accepted for call-signature compatibility; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.

    Raises
    ------
    TypeError
        If ``result`` is not one of the supported multi-period result types.
    """
    if isinstance(result, MPResult):
        frame = mpresult_to_polars(result)
        heading = "Group-Time Average Treatment Effects"
    elif isinstance(result, (DDDMultiPeriodResult, DDDMultiPeriodRCResult)):
        frame = dddmpresult_to_polars(result)
        heading = "Group-Time DDD Treatment Effects"
    else:
        raise TypeError(
            f"plot_gt requires MPResult, DDDMultiPeriodResult, or DDDMultiPeriodRCResult, got {type(result).__name__}"
        )

    # Cast through int before str so that a float-typed group column (if the
    # converter yields one) produces facet labels without a fractional part.
    group_label = frame["group"].cast(int).cast(str).alias("group_label")
    frame = frame.with_columns([group_label])
    time_breaks = sorted(frame["time"].unique().to_list())
    status_palette = {"Pre": COLORS["pre_treatment"], "Post": COLORS["post_treatment"]}

    # Components are added in list order; later geoms are drawn on top.
    components = [
        geom_point(size=3, alpha=0.8),
        scale_color_manual(
            values=status_palette,
            limits=["Pre", "Post"],
            name="Treatment Status",
        ),
        scale_x_continuous(breaks=time_breaks),
        facet_wrap("~group_label", ncol=ncol, labeller=lambda x: f"{title} {x}", scales="free_x"),
        labs(x=xlab or "Time", y=ylab or "ATT", title=heading),
        theme_gray(),
        theme(
            strip_text=element_text(size=11, weight="bold"),
            plot_title=element_text(margin={"b": 25}),
            legend_position="bottom",
        ),
    ]
    if show_ci:
        components.append(geom_errorbar(aes(ymin="ci_lower", ymax="ci_upper"), width=0.3, alpha=0.7))
    if ref_line is not None:
        components.append(geom_hline(yintercept=ref_line, linetype="dashed", color="black", alpha=0.5))

    plot = ggplot(frame, aes(x="time", y="att", color="treatment_status"))
    for component in components:
        plot = plot + component
    return plot
def plot_event_study(
    result: AGGTEResult | PTEResult | DDDAggResult,
    show_ci: bool = True,
    ref_line: float | None = 0,
    ref_period: float | None = -1,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str | None = None,
    **_kwargs: Any,
) -> ggplot:
    """Create event study plot for dynamic treatment effects.

    Each estimate is a point colored by pre/post treatment status. When
    ``ref_period`` is given, a dashed vertical line marks it and it is always
    included among the x-axis ticks. When ``ref_period`` is None, the
    estimates are instead joined by a dotted line.

    Parameters
    ----------
    result : AGGTEResult, PTEResult, or DDDAggResult
        Aggregated treatment effect result with dynamic/eventstudy aggregation,
        or PTEResult with event_study attribute.
    show_ci : bool, default=True
        Whether to show confidence intervals as error bars.
    ref_line : float or None, default=0
        Y-value for a dashed horizontal reference line. Set to None to hide.
    ref_period : float or None, default=-1
        X-value for vertical reference period line. Set to None to hide.
    xlab : str, optional
        X-axis label. Defaults to "Event Time".
    ylab : str, optional
        Y-axis label. Defaults to "ATT".
    title : str, optional
        Plot title. Defaults based on result type.
    **_kwargs : Any
        Accepted for call-signature compatibility; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.

    Raises
    ------
    ValueError
        If an AGGTEResult is not a "dynamic" aggregation, or a DDDAggResult is
        not an "eventstudy" aggregation.
    TypeError
        If ``result`` is none of the supported result types.
    """
    # PTEResult is imported only under TYPE_CHECKING, so it is recognised by
    # its ``event_study`` attribute rather than by an isinstance check.
    is_pte = hasattr(result, "event_study") and not isinstance(result, (AGGTEResult, DDDAggResult))
    if is_pte:
        frame = pteresult_to_polars(result)
        fallback_title = "Event Study"
    elif isinstance(result, DDDAggResult):
        if result.aggregation_type != "eventstudy":
            raise ValueError(f"Event study plot requires eventstudy aggregation, got {result.aggregation_type}")
        frame = dddaggresult_to_polars(result)
        fallback_title = "DDD Event Study"
    elif isinstance(result, AGGTEResult):
        if result.aggregation_type != "dynamic":
            raise ValueError(f"Event study plot requires dynamic aggregation, got {result.aggregation_type}")
        frame = aggteresult_to_polars(result)
        fallback_title = "Event Study"
    else:
        raise TypeError(
            f"plot_event_study requires AGGTEResult, PTEResult, or DDDAggResult, got {type(result).__name__}"
        )

    # Layers are added in list order; later geoms are drawn on top.
    layers = []
    if show_ci:
        layers.append(
            geom_errorbar(
                aes(ymin="ci_lower", ymax="ci_upper", color="treatment_status"),
                width=0.2,
                size=0.8,
            )
        )
    if ref_line is not None:
        layers.append(geom_hline(yintercept=ref_line, linetype="dashed", color="#7f8c8d", alpha=0.7))
    if ref_period is None:
        # No reference-period marker: connect the estimates with a dotted line.
        layers.append(geom_line(color=COLORS["line"], size=0.8, alpha=0.6, linetype="dotted"))
    else:
        layers.append(geom_vline(xintercept=ref_period, linetype="dashed", color="gray", size=0.4))

    breaks = sorted(frame["event_time"].unique().to_list())
    if ref_period is not None and ref_period not in breaks:
        # Give the reference period its own tick even if it has no estimate.
        breaks.append(ref_period)
        breaks.sort()

    layers += [
        geom_point(aes(color="treatment_status"), size=3.5),
        scale_color_manual(
            values={"Pre": COLORS["pre_treatment"], "Post": COLORS["post_treatment"]},
            limits=["Pre", "Post"],
            name="Treatment Status",
        ),
        scale_x_continuous(breaks=breaks),
        labs(x=xlab or "Event Time", y=ylab or "ATT", title=title or fallback_title),
        theme_gray(),
        theme(legend_position="bottom"),
    ]

    plot = ggplot(frame, aes(x="event_time", y="att"))
    for layer in layers:
        plot = plot + layer
    return plot
def plot_agg(
    result: AGGTEResult | DDDAggResult,
    show_ci: bool = True,
    ref_line: float | None = 0,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str | None = None,
    **_kwargs: Any,
) -> ggplot:
    """Create plot for aggregated treatment effects by group or calendar time.

    Estimates are drawn as points in the post-treatment color, joined by a
    dotted line, against the ``event_time`` column of the converted result.

    Parameters
    ----------
    result : AGGTEResult or DDDAggResult
        Aggregated treatment effect result with group or calendar aggregation.
    show_ci : bool, default=True
        Whether to show confidence intervals as error bars.
    ref_line : float or None, default=0
        Y-value for a dashed horizontal reference line. Set to None to hide.
    xlab : str, optional
        X-axis label. Defaults based on aggregation type.
    ylab : str, optional
        Y-axis label. Defaults to "ATT".
    title : str, optional
        Plot title. Defaults based on aggregation type.
    **_kwargs : Any
        Accepted for call-signature compatibility; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.

    Raises
    ------
    TypeError
        If ``result`` is neither an AGGTEResult nor a DDDAggResult.
    ValueError
        If the aggregation type is not "group" or "calendar".
    """
    if not isinstance(result, (AGGTEResult, DDDAggResult)):
        raise TypeError(f"plot_agg requires AGGTEResult or DDDAggResult, got {type(result).__name__}")

    is_ddd = isinstance(result, DDDAggResult)
    if result.aggregation_type not in ("group", "calendar"):
        # The hint names the event-study aggregation used by each result family.
        if is_ddd:
            raise ValueError(
                f"plot_agg requires group or calendar aggregation, got {result.aggregation_type}. "
                f"Use plot_event_study for eventstudy aggregation."
            )
        raise ValueError(
            f"plot_agg requires group or calendar aggregation, got {result.aggregation_type}. "
            f"Use plot_event_study for dynamic aggregation."
        )
    frame = dddaggresult_to_polars(result) if is_ddd else aggteresult_to_polars(result)

    if result.aggregation_type == "group":
        fallback_xlab = "Treatment Cohort"
        fallback_title = "DDD Effects by Treatment Cohort" if is_ddd else "Effects by Treatment Cohort"
    else:
        fallback_xlab = "Calendar Time"
        fallback_title = "DDD Effects by Calendar Time" if is_ddd else "Effects by Calendar Time"

    marker_color = COLORS["post_treatment"]

    # Layers are added in list order; later geoms are drawn on top.
    layers = []
    if show_ci:
        layers.append(
            geom_errorbar(
                aes(ymin="ci_lower", ymax="ci_upper"),
                width=0.2,
                size=0.8,
                color=marker_color,
            )
        )
    if ref_line is not None:
        layers.append(geom_hline(yintercept=ref_line, linetype="dashed", color="#7f8c8d", alpha=0.7))
    layers.extend(
        [
            geom_line(color=COLORS["line"], size=0.8, alpha=0.6, linetype="dotted"),
            geom_point(color=marker_color, size=3.5),
            scale_x_continuous(breaks=sorted(frame["event_time"].unique().to_list())),
            labs(x=xlab or fallback_xlab, y=ylab or "ATT", title=title or fallback_title),
            theme_gray(),
        ]
    )

    plot = ggplot(frame, aes(x="event_time", y="att"))
    for layer in layers:
        plot = plot + layer
    return plot
def plot_dose_response(
    result: DoseResult,
    effect_type: Literal["att", "acrt"] = "att",
    show_ci: bool = True,
    ref_line: float | None = 0,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str | None = None,
    **_kwargs: Any,
) -> ggplot:
    """Plot dose-response function for continuous treatment.

    The estimated curve is drawn as a line with a point at each dose. When
    ``show_ci`` is true, a shaded confidence band is drawn beneath it.

    Parameters
    ----------
    result : DoseResult
        Continuous treatment dose-response result.
    effect_type : {'att', 'acrt'}, default='att'
        Type of effect to plot:

        - 'att': Average Treatment Effect on Treated
        - 'acrt': Average Causal Response on Treated (marginal effect)
    show_ci : bool, default=True
        Whether to show confidence bands.
    ref_line : float or None, default=0
        Y-value for a dashed horizontal reference line. Set to None to hide.
    xlab : str, optional
        X-axis label. Defaults to "Dose".
    ylab : str, optional
        Y-axis label. Defaults to "ATT(d)" for 'att', otherwise "ACRT(d)".
    title : str, optional
        Plot title. Defaults to "Dose-Response: " followed by the default
        y-axis label.
    **_kwargs : Any
        Accepted for call-signature compatibility; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.
    """
    frame = doseresult_to_polars(result, effect_type=effect_type)

    y_default = "ATT(d)" if effect_type == "att" else "ACRT(d)"
    curve_color = "#2c3e50"
    band_color = "#95a5a6"

    # Layers are added in list order; later geoms are drawn on top.
    layers = []
    if show_ci:
        layers.append(geom_ribbon(aes(ymin="ci_lower", ymax="ci_upper"), fill=band_color, alpha=0.25))
    if ref_line is not None:
        layers.append(geom_hline(yintercept=ref_line, linetype="dashed", color="#7f8c8d", alpha=0.7))
    layers += [
        geom_line(color=curve_color, size=1.2),
        geom_point(color=curve_color, size=2.5),
        labs(
            x=xlab or "Dose",
            y=ylab or y_default,
            title=title or f"Dose-Response: {y_default}",
        ),
        theme_gray(),
    ]

    plot = ggplot(frame, aes(x="dose", y="effect"))
    for layer in layers:
        plot = plot + layer
    return plot
def plot_sensitivity(
    result: HonestDiDResult,
    ref_line: float | None = 0,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str | None = None,
    **_kwargs: Any,
) -> ggplot:
    """Create sensitivity analysis plot for HonestDiD results.

    Each method's interval is drawn as an error bar from ``lb`` to ``ub``,
    with a point at its ``midpoint``, at every sensitivity-parameter value.
    When several methods are present they are dodged horizontally so their
    intervals do not overlap.

    Parameters
    ----------
    result : HonestDiDResult
        Honest DiD sensitivity analysis result.
    ref_line : float or None, default=0
        Y-value for a solid horizontal reference line. Set to None to hide.
    xlab : str, optional
        X-axis label. Defaults to "M" for smoothness restrictions and
        ``$\\bar{M}$`` otherwise.
    ylab : str, optional
        Y-axis label. Defaults to "Confidence Interval".
    title : str, optional
        Plot title. Defaults to "Sensitivity Analysis (<Type>)", built from
        ``result.sensitivity_type``.
    **_kwargs : Any
        Accepted for call-signature compatibility; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.
    """
    df = honestdid_to_polars(result)

    is_smoothness = result.sensitivity_type == "smoothness"
    default_xlab = "M" if is_smoothness else r"$\bar{M}$"
    default_title = f"Sensitivity Analysis ({result.sensitivity_type.replace('_', ' ').title()})"

    method_colors = {
        "Original": COLORS["original"],
        "FLCI": COLORS["flci"],
        "Conditional": COLORS["conditional"],
        "C-F": COLORS["c_f"],
        "C-LF": COLORS["c_lf"],
    }
    methods = df["method"].unique().to_list()
    # Methods without a dedicated palette entry fall back to a neutral slate color.
    available_colors = {m: method_colors.get(m, "#34495e") for m in methods}
    n_methods = len(methods)

    # Dodge offsets and error-bar cap widths are proportional to the spread of
    # the sensitivity-parameter grid. A single-value grid has zero spread, and
    # an empty/all-null column makes min()/max() return None. Either way there
    # is no usable spread: caps would have zero width and all methods would be
    # stacked on the same x. Fall back to a unit span in that case.
    x_min = df["param_value"].min()
    x_max = df["param_value"].max()
    x_span = x_max - x_min if x_min is not None and x_max is not None else 0.0
    if not x_span > 0:  # also true for NaN
        x_span = 1.0
    dodge_width = 0.05 * x_span if n_methods > 1 else 0

    plot = (
        ggplot(df, aes(x="param_value", y="midpoint", color="method"))
        + geom_point(size=3, position=position_dodge(width=dodge_width))
        + geom_errorbar(
            aes(ymin="lb", ymax="ub"),
            width=0.02 * x_span,
            size=0.8,
            position=position_dodge(width=dodge_width),
        )
        + scale_color_manual(values=available_colors, name="Method")
        + labs(
            x=xlab or default_xlab,
            y=ylab or "Confidence Interval",
            title=title or default_title,
        )
        + theme_gray()
        + theme(legend_position="bottom")
    )

    if ref_line is not None:
        plot = plot + geom_hline(yintercept=ref_line, linetype="solid", color="black", alpha=0.4)

    return plot
def plot_multiplegt(
    result: DIDInterResult,
    show_ci: bool = True,
    ref_line: float | None = 0,
    xlab: str | None = None,
    ylab: str | None = None,
    title: str | None = None,
    **_kwargs: Any,
) -> ggplot:
    """Create event study plot for intertemporal treatment effects.

    Each estimate is a point colored by pre/post treatment status. A dashed
    vertical line is always drawn at horizon 0, and 0 is always included
    among the x-axis ticks.

    Parameters
    ----------
    result : DIDInterResult
        Intertemporal treatment effects result from did_multiplegt().
    show_ci : bool, default=True
        Whether to show confidence intervals as error bars.
    ref_line : float or None, default=0
        Y-value for a dashed horizontal reference line. Set to None to hide.
    xlab : str, optional
        X-axis label. Defaults to "Horizon".
    ylab : str, optional
        Y-axis label. Defaults to "Effect".
    title : str, optional
        Plot title. Defaults to "Intertemporal Treatment Effects".
    **_kwargs : Any
        Accepted for call-signature compatibility; ignored.

    Returns
    -------
    ggplot
        A plotnine ggplot object that can be further customized.
    """
    frame = didinterresult_to_polars(result)

    horizons = sorted(frame["horizon"].unique().to_list())
    if 0 not in horizons:
        # Keep a tick at horizon 0, where the vertical marker line is drawn.
        horizons.append(0)
        horizons.sort()

    # Layers are added in list order; later geoms are drawn on top.
    layers = []
    if show_ci:
        layers.append(
            geom_errorbar(
                aes(ymin="ci_lower", ymax="ci_upper", color="treatment_status"),
                width=0.2,
                size=0.8,
            )
        )
    if ref_line is not None:
        layers.append(geom_hline(yintercept=ref_line, linetype="dashed", color="#7f8c8d", alpha=0.7))
    layers += [
        geom_vline(xintercept=0, linetype="dashed", color="gray", size=0.4),
        geom_point(aes(color="treatment_status"), size=3.5),
        scale_color_manual(
            values={"Pre": COLORS["pre_treatment"], "Post": COLORS["post_treatment"]},
            limits=["Pre", "Post"],
            name="Treatment Status",
        ),
        scale_x_continuous(breaks=horizons),
        labs(
            x=xlab or "Horizon",
            y=ylab or "Effect",
            title=title or "Intertemporal Treatment Effects",
        ),
        theme_gray(),
        theme(legend_position="bottom"),
    ]

    plot = ggplot(frame, aes(x="horizon", y="att"))
    for layer in layers:
        plot = plot + layer
    return plot