Distributed Estimation#
ModernDiD includes distributed backends for att_gt and
ddd on top of Dask and
Apache Spark. Both backends scale from one
machine to a full cluster and return the same result object types as local
estimation, so your post-estimation workflow stays the same.
This page focuses on practical usage. For internal implementation details, see Distributed Backend Architecture.
When to use distributed estimation#
Distributed estimation is most useful when the dataset does not fit in memory on one machine, when local runtime is too long for your iteration cycle, when you need to run many model specifications over the same large dataset, or when you already maintain a Dask or Spark cluster for analytics workloads.
Distributed execution has scheduler and communication overhead. If local estimation is fast and memory safe, local execution is often simpler and easier to debug.
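A back-of-the-envelope size estimate is one rough way to judge whether a dataset fits in memory. The sketch below is illustrative only. The helper names, the 8 bytes per value, and the 3x working-memory multiplier are assumptions, not ModernDiD rules.

```python
def estimate_panel_bytes(n_units, n_periods, n_columns, bytes_per_value=8):
    """Rough in-memory size of a balanced long-format panel (hypothetical helper)."""
    return n_units * n_periods * n_columns * bytes_per_value


def fits_locally(data_bytes, available_bytes, overhead_factor=3.0):
    """Assume estimation needs about `overhead_factor` times the raw data size."""
    return data_bytes * overhead_factor <= available_bytes


GIB = 1024 ** 3
size = estimate_panel_bytes(n_units=10_000_000, n_periods=10, n_columns=35)
print(f"estimated size: {size / GIB:.1f} GiB")
print("fits on a 64 GiB machine:", fits_locally(size, 64 * GIB))
```

If the estimate is well below available memory, local estimation is usually the simpler starting point.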
Requirements#
Install the extra for the backend you plan to use.
uv pip install "moderndid[dask]" # Dask backend
uv pip install "moderndid[spark]" # Spark backend
The all extra includes Dask but not Spark (which requires a JVM),
so Spark users should install moderndid[spark] explicitly. ModernDiD cannot set up your cluster environment for you.
Creating clusters, managing workers or executors, and building distributed
DataFrames are your responsibility. Once you have a Dask or Spark
DataFrame, ModernDiD handles the rest.
If you are new to either framework, start with the official Dask and PySpark documentation.
Quick start#
The distributed backend activates automatically when data is a Dask
or PySpark DataFrame. If data is any other type, the local estimator
runs instead. This lets you move from local development to cluster execution
without rewriting estimator arguments.
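Type-based dispatch of this kind can be sketched without importing either framework, by looking at the module that defines the input object. This is not ModernDiD's actual implementation. `detect_backend` and the stand-in classes are hypothetical names used only for illustration.

```python
def detect_backend(data):
    """Return 'dask', 'spark', or 'local' from the module that defines type(data)."""
    root = (type(data).__module__ or "").split(".")[0]
    # Assumption: some Dask releases define DataFrames in a separate `dask_expr` package.
    if root in ("dask", "dask_expr"):
        return "dask"
    if root == "pyspark":
        return "spark"
    return "local"


# Stand-in classes that mimic where the real DataFrame types are defined.
class FakeDaskFrame:
    pass


class FakeSparkFrame:
    pass


FakeDaskFrame.__module__ = "dask.dataframe.core"
FakeSparkFrame.__module__ = "pyspark.sql.dataframe"

print(detect_backend(FakeDaskFrame()))
print(detect_backend(FakeSparkFrame()))
print(detect_backend([1, 2, 3]))
```

Checking the module name instead of calling `isinstance` keeps Dask and PySpark optional, since neither has to be imported to make the decision.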
Dask
import dask.dataframe as dd
import moderndid as did
ddf = dd.read_parquet("panel_data.parquet")
result = did.att_gt(
data=ddf,
yname="y",
tname="time",
idname="id",
gname="group",
xformla="~ x1 + x2",
est_method="dr",
)
# Post-estimation code is unchanged
event_study = did.aggte(result, type="dynamic")
did.plot_event_study(event_study)
Spark
from pyspark.sql import SparkSession
import moderndid as did
spark = SparkSession.builder.master("local[*]").getOrCreate()
sdf = spark.read.parquet("panel_data.parquet")
result = did.att_gt(
data=sdf,
yname="y",
tname="time",
idname="id",
gname="group",
xformla="~ x1 + x2",
est_method="dr",
)
event_study = did.aggte(result, type="dynamic")
did.plot_event_study(event_study)
The same pattern works for triple differences with both backends.
result = did.ddd(
data=ddf, # or sdf
yname="y",
tname="time",
idname="id",
gname="group",
pname="partition",
est_method="dr",
)
Return objects are the same classes as local estimation, so downstream
utilities like aggte, agg_ddd, plot_event_study, and related
plotting functions continue to work without changes.
Interfaces#
You can use distributed estimation through two API layers. The high-level
wrappers att_gt and ddd are the recommended entry points. The
low-level functions give you explicit control over the client or session.
Dask
High-level wrappers do not expose a client argument. Creating a
Client registers it as the global default, so estimator calls pick it
up automatically. The low-level functions dask_att_gt and dask_ddd
in moderndid.dask accept a client argument directly.
import moderndid as did
from moderndid.dask import dask_att_gt
from dask.distributed import Client
# High-level (uses the global default client)
client = Client("scheduler-address:8786")
result_a = did.att_gt(data=ddf, yname="y", tname="time", idname="id", gname="group")
# Low-level (accepts client explicitly)
result_b = dask_att_gt(data=ddf, yname="y", tname="time", idname="id", gname="group", client=client)
Spark
High-level wrappers do not expose a spark argument. ModernDiD uses the
active session or creates a local one automatically. The low-level functions
spark_att_gt and spark_ddd in moderndid.spark accept a
spark argument directly.
import moderndid as did
from moderndid.spark import spark_att_gt
from pyspark.sql import SparkSession
# High-level (uses the active session)
spark = SparkSession.builder.master("local[*]").getOrCreate()
result_a = did.att_gt(data=sdf, yname="y", tname="time", idname="id", gname="group")
# Low-level (accepts spark session explicitly)
result_b = spark_att_gt(data=sdf, yname="y", tname="time", idname="id", gname="group", spark=spark)
Current support and limits#
The distributed path supports both panel and repeated cross-section data
for any number of time periods. All standard estimation options work in
distributed mode, including control_group, anticipation,
base_period, and est_method, except that a callable
est_method is not supported. When boot=True, the multiplier
bootstrap runs fully distributed, with Mammen two-point weights generated
on workers and tree-reduced to the driver.
The n_jobs parameter is not used in distributed mode. For ddd,
boot_type is also ignored.
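To make the bootstrap description concrete, here is a minimal standard-library sketch of Mammen two-point weights and a pairwise tree reduction. It shows the general technique only, not ModernDiD's internals. The function names and the toy per-partition values are assumptions.

```python
import math
import random

SQRT5 = math.sqrt(5.0)
LOW = -(SQRT5 - 1) / 2             # drawn with probability P_LOW
HIGH = (SQRT5 + 1) / 2             # drawn with probability 1 - P_LOW
P_LOW = (SQRT5 + 1) / (2 * SQRT5)


def mammen_weights(n, rng):
    """Draw n Mammen two-point weights (mean 0, variance 1)."""
    return [LOW if rng.random() < P_LOW else HIGH for _ in range(n)]


def partition_contribution(influence_values, rng):
    """One worker's partial sum of weight * influence-function value."""
    weights = mammen_weights(len(influence_values), rng)
    return sum(w * v for w, v in zip(weights, influence_values))


def tree_reduce(values):
    """Combine partial results pairwise, level by level (needs a non-empty input)."""
    values = list(values)
    while len(values) > 1:
        paired = [values[i] + values[i + 1] for i in range(0, len(values) - 1, 2)]
        if len(values) % 2:
            paired.append(values[-1])  # carry the unpaired element to the next level
        values = paired
    return values[0]


# Toy data: three partitions, each with its own seeded generator.
partitions = [[0.5, -1.0], [2.0], [0.25, 0.75, -0.5]]
partials = [partition_contribution(p, random.Random(seed))
            for seed, p in enumerate(partitions)]
bootstrap_draw = tree_reduce(partials)
print(bootstrap_draw)
```

Each worker returns only a small partial sum, so the driver receives a few numbers per bootstrap draw and never the weights themselves.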
Preparing input data#
The distributed estimators expect a Dask or PySpark DataFrame in long format.
Include all required columns in the same frame and keep time and treatment
group columns numeric. For panel data (panel=True), ensure unit
identifiers are stable across periods and prefer one record per unit-period.
For repeated cross-section data (panel=False), each row is an independent
observation and unit identifiers are not required.
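These expectations can be checked cheaply on a small sample pulled to the driver before launching a distributed fit. The sketch below works on plain dictionaries so it runs without either framework. `check_panel_rows` is a hypothetical helper, and the column names match the examples on this page.

```python
from collections import Counter
from numbers import Real


def check_panel_rows(rows, idname="id", tname="time", gname="group"):
    """Return a list of problems found in a small sample of long-format rows."""
    problems = []
    for col in (tname, gname):
        # bool is a subclass of int, so exclude it explicitly.
        if any(isinstance(r[col], bool) or not isinstance(r[col], Real) for r in rows):
            problems.append(f"column {col!r} has non-numeric values")
    counts = Counter((r[idname], r[tname]) for r in rows)
    duplicates = [key for key, n in counts.items() if n > 1]
    if duplicates:
        problems.append(f"{len(duplicates)} duplicated unit-period pair(s)")
    return problems


sample_rows = [
    {"id": 1, "time": 2001, "group": 2003},
    {"id": 1, "time": 2002, "group": 2003},
    {"id": 1, "time": 2002, "group": "2003"},  # duplicate pair, string group
]
problems = check_panel_rows(sample_rows)
print(problems)
```

With Dask or Spark, the sample could come from `ddf.head(...)` or `sdf.limit(...)` converted to row dictionaries. A clean sample does not rule out problems elsewhere in the data.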
Reading large datasets#
Dask. When the full dataset is too large to build in driver memory, write it to Parquet in batches and read it back as a Dask DataFrame. This keeps driver memory constant regardless of total dataset size.
import gc
import dask
import dask.dataframe as dd
import moderndid as did
from dask.distributed import Client
client = Client() # or connect to an existing scheduler
n_workers = len(client.scheduler_info()["workers"])
N_TOTAL = 100_000_000
CHUNK_SIZE = 500_000
N_CHUNKS = N_TOTAL // CHUNK_SIZE
BATCH_SIZE = n_workers * 2 # chunks written per round
PARQUET_PATH = "/tmp/panel_data"
# Define a delayed function that builds one chunk
@dask.delayed
def _generate_chunk(chunk_id, n):
    dgp = did.gen_did_scalable(
        n=n, dgp_type=1, n_periods=10, n_cohorts=6,
        n_covariates=30, panel=True, random_state=chunk_id,
    )
    df = dgp["data"].to_pandas()
    df["id"] = df["id"] + chunk_id * n  # ensure globally unique IDs
    return df
# Write in batches so the driver never holds the full dataset
meta = _generate_chunk(0, 10).compute()  # schema for Dask
for batch_start in range(0, N_CHUNKS, BATCH_SIZE):
    batch_end = min(batch_start + BATCH_SIZE, N_CHUNKS)
    chunks = [_generate_chunk(i, CHUNK_SIZE)
              for i in range(batch_start, batch_end)]
    ddf_batch = dd.from_delayed(chunks, meta=meta)
    ddf_batch.to_parquet(
        PARQUET_PATH,
        append=(batch_start > 0),
        engine="pyarrow",
    )
    del ddf_batch, chunks
    gc.collect()
# Read back as a distributed DataFrame
ddf = dd.read_parquet(PARQUET_PATH, engine="pyarrow")
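The batching arithmetic in the Dask example can be checked on its own. The sketch below reproduces the chunk ranges and the ID offsets for a hypothetical 4-worker cluster. It assumes each generated chunk numbers its units within a window of width `CHUNK_SIZE`, which is what makes the `chunk_id * n` offset produce unique IDs.

```python
N_TOTAL = 100_000_000
CHUNK_SIZE = 500_000
N_CHUNKS = N_TOTAL // CHUNK_SIZE
N_WORKERS = 4                      # assumed cluster size for illustration
BATCH_SIZE = N_WORKERS * 2


def batch_ranges(n_chunks, batch_size):
    """Half-open (start, end) chunk ranges, one per write round."""
    return [(start, min(start + batch_size, n_chunks))
            for start in range(0, n_chunks, batch_size)]


def id_offset(chunk_id, chunk_size):
    """Offset added to a chunk's local IDs to make them globally unique."""
    return chunk_id * chunk_size


batches = batch_ranges(N_CHUNKS, BATCH_SIZE)
offsets = [id_offset(i, CHUNK_SIZE) for i in range(N_CHUNKS)]
print(len(batches), "write rounds; last batch:", batches[-1])
print("peak rows held per round:", BATCH_SIZE * CHUNK_SIZE)
```

Only one round of chunks exists at a time, so peak memory scales with `BATCH_SIZE * CHUNK_SIZE` rows and not with `N_TOTAL`.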
Spark. Spark reads data lazily, so you can point directly at Parquet, CSV, or any Spark-supported source and the data stays distributed across executors without staging batches on the driver.
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
sdf = spark.read.parquet("hdfs:///data/panel_data.parquet")
Partition layout#
Partition layout affects both runtime and memory. Extremely small partitions increase scheduler overhead, and extremely large partitions increase worker/executor memory pressure. If you plan to run multiple model specifications over the same data, repartition once and persist/cache before the first fit.
# Dask
ddf = ddf.repartition(npartitions=64)
# Spark
sdf = sdf.repartition(64)
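A partition count such as 64 can be derived from the data size instead of guessed. The heuristic below is a sketch under stated assumptions: a target of about 128 MiB per partition, and rounding up to a multiple of the worker thread count so work divides evenly. `suggest_npartitions` is a hypothetical helper, not part of ModernDiD, Dask, or Spark.

```python
import math

MIB = 1024 ** 2
GIB = 1024 ** 3


def suggest_npartitions(total_bytes, n_worker_threads, target_partition_bytes=128 * MIB):
    """Pick a partition count near the target size, in whole rounds of threads."""
    by_size = math.ceil(total_bytes / target_partition_bytes)
    rounds = max(1, math.ceil(by_size / n_worker_threads))
    return rounds * n_worker_threads


print(suggest_npartitions(8 * GIB, n_worker_threads=16))
print(suggest_npartitions(10 * GIB, n_worker_threads=16))
print(suggest_npartitions(1 * MIB, n_worker_threads=8))
```

Treat the result as a starting point and adjust after observing task durations and worker memory on the dashboard.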
Persist and sanity-check#
Both frameworks evaluate lazily, so heavy work can execute later than expected, including inside estimator calls. A good pattern is to persist/cache and sanity-check the input before fitting. This boundary separates data-pipeline failures from estimator failures, which makes debugging faster.
Dask
from dask.distributed import wait
ddf = ddf.persist()
wait(ddf)
print(ddf.columns)
print(ddf[["id", "time", "group"]].head())
Spark
sdf = sdf.cache()
sdf.count() # force materialization
print(sdf.columns)
sdf.select("id", "time", "group").show(5)
From local to distributed#
A practical workflow is to develop and validate your specification locally on a sample, then scale up to the full dataset by swapping in a distributed DataFrame. The estimator arguments stay the same. Only the input type changes.
import polars as pl
import moderndid as did
shared_args = dict(
yname="y", tname="time", idname="id", gname="group",
xformla="~ x1 + x2", est_method="dr", control_group="nevertreated",
)
# Develop locally on a sample
sample = pl.read_parquet("panel_data.parquet").sample(n=10_000, seed=42)
local_result = did.att_gt(data=sample, **shared_args)
# Scale with Dask
import dask.dataframe as dd
ddf = dd.read_parquet("panel_data.parquet")
dask_result = did.att_gt(data=ddf, **shared_args)
# Or scale with Spark
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
sdf = spark.read.parquet("panel_data.parquet")
spark_result = did.att_gt(data=sdf, **shared_args)
Connecting to a cluster#
Dask. If no client exists, ModernDiD creates a local client
automatically. For multi-node runs, create a Client pointing at your
scheduler. Creating the client registers it as the global default.
import moderndid as did
from dask.distributed import Client
client = Client("tcp://scheduler-host:8786")
result = did.att_gt(
data=ddf,
yname="y",
tname="time",
idname="id",
gname="group",
)
This works with any Dask-compatible scheduler endpoint. If you need to
direct calls to a specific client (for example, when multiple clients are
active), wrap the estimator call in the client.as_current() context manager.
For local development with controlled resources, create a LocalCluster
explicitly to set worker count and memory limits.
from dask.distributed import Client, LocalCluster
cluster = LocalCluster(n_workers=4, threads_per_worker=2, memory_limit="4GB")
client = Client(cluster)
result = did.att_gt(data=ddf, yname="y", tname="time", idname="id", gname="group")
Spark. If no active session exists, ModernDiD creates a local Spark
session automatically. For cluster runs, create a SparkSession pointing
at your cluster manager.
from pyspark.sql import SparkSession
import moderndid as did
spark = (
SparkSession.builder
.master("yarn")
.appName("moderndid-estimation")
.config("spark.executor.memory", "8g")
.config("spark.executor.cores", "4")
.getOrCreate()
)
sdf = spark.read.parquet("hdfs:///data/panel_data.parquet")
result = did.att_gt(
data=sdf,
yname="y",
tname="time",
idname="id",
gname="group",
)
This works with any Spark-compatible cluster manager (standalone, YARN,
Mesos, Kubernetes). For Databricks, the SparkSession is pre-configured
and available as spark in notebooks, so you can pass Spark DataFrames
directly to estimators without additional setup.
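For local development with controlled resources, set the core count and memory explicitly. In local mode the executors run inside the driver JVM, so spark.driver.memory is the setting that bounds the JVM heap.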
from pyspark.sql import SparkSession
spark = (
SparkSession.builder
.master("local[4]")
.config("spark.driver.memory", "4g")
.config("spark.executor.memory", "4g")
.getOrCreate()
)
result = did.att_gt(data=sdf, yname="y", tname="time", idname="id", gname="group")
Distributed-specific parameters#
In addition to standard estimator arguments, distributed execution uses a small set of controls.
client (Dask only): A distributed.Client for low-level calls (dask_att_gt, dask_ddd). If omitted, ModernDiD first tries Client.current() and creates a local client when none is active.
spark (Spark only): A pyspark.sql.SparkSession for low-level calls (spark_att_gt, spark_ddd). If omitted, ModernDiD first tries SparkSession.getActiveSession() and creates a local session when none is active.
n_partitions: Number of partitions per cell computation. The default equals total worker threads (Dask) or Spark's default parallelism (Spark), and the backend can increase it when estimated partition design matrices would be too large.
max_cohorts: Maximum number of treatment cohorts processed concurrently. Defaults to the number of workers (Dask) or executor cores (Spark). Lower values reduce peak memory because fewer cohort-wide tables are active at once. Higher values can improve throughput on memory-rich clusters.
backend: Set to "cupy" to run partition-level linear algebra on worker GPUs. See Combining GPU and Dask and Combining GPU and Spark for setup details.
from moderndid.dask import dask_att_gt
result = dask_att_gt(
data=ddf,
yname="y",
tname="time",
idname="id",
gname="group",
xformla="~ x1 + x2",
est_method="dr",
client=client,
n_partitions=64,
max_cohorts=4,
)
from moderndid.spark import spark_att_gt
result = spark_att_gt(
data=sdf,
yname="y",
tname="time",
idname="id",
gname="group",
xformla="~ x1 + x2",
est_method="dr",
spark=spark,
n_partitions=64,
max_cohorts=4,
)
Start with defaults and record runtime and peak memory. Increase
n_partitions if workers are idle for long stretches, or reduce it if
scheduler overhead dominates task time. Reduce max_cohorts when workers
approach memory limits, and increase it gradually when memory headroom is
large.
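The tuning guidance above can be written as a small decision rule. The sketch below is illustrative only. The function name, the input metrics, and every threshold (0.85, 0.40, 0.30) are assumptions that you would calibrate against your own dashboard readings.

```python
def adjust_settings(n_partitions, max_cohorts, *, idle_fraction,
                    overhead_fraction, peak_memory_fraction):
    """Suggest the next (n_partitions, max_cohorts) from one run's observations."""
    # Memory first: back off quickly, grow slowly.
    if peak_memory_fraction > 0.85:
        max_cohorts = max(1, max_cohorts // 2)
    elif peak_memory_fraction < 0.40:
        max_cohorts += 1
    # Then parallelism: more partitions if workers idle, fewer if overhead dominates.
    if idle_fraction > 0.30:
        n_partitions *= 2
    elif overhead_fraction > 0.30:
        n_partitions = max(1, n_partitions // 2)
    return n_partitions, max_cohorts


print(adjust_settings(64, 4, idle_fraction=0.5,
                      overhead_fraction=0.05, peak_memory_fraction=0.9))
```

Changing one setting per run makes it easier to attribute a runtime or memory change to a specific adjustment.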
Supported estimation features#
All standard estimator arguments work in distributed mode with the same interface as local estimation.
Bootstrap (boot=True): Mammen two-point weights generated on workers and tree-reduced to the driver. cband defaults to False (local defaults to True); set cband=True explicitly for uniform bands.
Clustered SEs (clustervars): one-way and two-way clustering supported. Pass a list, not a bare string.
Repeated cross-sections (panel=False): fully supported for both att_gt and ddd.
Unbalanced panels (allow_unbalanced_panel): supported. The default, False, logs a warning with the number of dropped units.
Sampling weights (weightsname): supported. The weight column must be present in the distributed DataFrame.
GPU on workers (backend="cupy"): runs partition-level linear algebra on worker GPUs. See Combining GPU and Dask and Combining GPU and Spark for setup.
Running multiple specifications#
When running multiple specifications over the same data, persist/cache the distributed DataFrame once and reuse it across calls. This avoids re-reading and re-shuffling the data for each specification.
import dask.dataframe as dd
import moderndid as did
from dask.distributed import Client, wait
client = Client("tcp://scheduler-host:8786")
# Persist once so every specification reuses the same in-memory partitions
ddf = dd.read_parquet("panel_data.parquet").persist()
wait(ddf)
specifications = [
    {"est_method": "dr", "control_group": "nevertreated"},
    {"est_method": "dr", "control_group": "notyettreated"},
    {"est_method": "reg", "control_group": "nevertreated"},
]
results = {}
for spec in specifications:
    results[str(spec)] = did.att_gt(
        data=ddf,
        yname="y",
        tname="time",
        idname="id",
        gname="group",
        xformla="~ x1 + x2",
        **spec,
    )
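For larger grids, the list of specifications can be generated instead of written by hand. The sketch below uses only the standard library. `spec_grid` and `spec_key` are hypothetical helpers, and the option values are the ones used on this page.

```python
from itertools import product


def spec_grid(**options):
    """Cartesian product of option values as a list of keyword dictionaries."""
    names = list(options)
    return [dict(zip(names, combo)) for combo in product(*options.values())]


def spec_key(spec):
    """Stable, readable key that does not depend on dictionary ordering."""
    return "|".join(f"{name}={spec[name]}" for name in sorted(spec))


grid = spec_grid(
    est_method=["dr", "reg"],
    control_group=["nevertreated", "notyettreated"],
)
for spec in grid:
    print(spec_key(spec), spec)
```

Each dictionary in `grid` can be unpacked into an estimator call with `**spec`, and `spec_key(spec)` gives a results key that is easier to read than `str(spec)`.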
Monitoring the cluster#
For long-running jobs, the cluster dashboard provides real-time visibility into task progress, worker memory, and the task stream.
Dask: Access the dashboard at http://scheduler-host:8787/status after creating a client.
Spark: Access the Spark UI at http://driver-host:4040 after creating a session. On Databricks, the Spark UI is available directly from the cluster page.
Reproducibility#
Distributed reproducibility is best-effort. Set random_state whenever you
need stable bootstrap draws, then keep cluster conditions as constant as
possible. Results can still vary across runs when worker count, threads per
worker, partition placement, floating-point reduction order, or concurrent
cluster workloads change.
For reproducibility-sensitive comparisons, fix the cluster size and worker
hardware, fix the partition count and input file layout, set the estimator
random_state explicitly, and run on a quiet cluster.
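Floating-point reduction order matters because addition in floating point is not associative. The short demonstration below sums the same three values in two orders. It illustrates the general effect and does not reproduce ModernDiD's reductions.

```python
import math

values = [1e16, 1.0, -1e16]

# Same numbers, two association orders, as different partition layouts might produce.
left_to_right = (values[0] + values[1]) + values[2]
reordered = (values[0] + values[2]) + values[1]
exact = math.fsum(values)  # correctly rounded sum, for reference

print(left_to_right, reordered, exact)
```

In the first order the 1.0 is absorbed by rounding when added to 1e16. In the second it survives. Differences of this kind are usually tiny in practice, but they explain why results need not match bit for bit when the partition layout or worker count changes.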
Next steps#
Quickstart covers estimation options, aggregation types, and visualization for local workflows.
GPU Acceleration with CuPy describes GPU acceleration for local and distributed workloads.
Estimator Overview surveys all available estimators and their distributed support.
The Examples section walks through each estimator end-to-end with real and simulated data.
For architecture-level details on reduction patterns, memory strategy, and execution decomposition, see Distributed Backend Architecture.