Source code for moderndid.didcont.npiv.prodspline

"""Multivariate spline construction for continuous treatment DiD estimation."""

import warnings
from itertools import combinations, product
from typing import NamedTuple

import numpy as np

from ...cupy.backend import get_backend, to_numpy
from .gsl_bspline import gsl_bs, predict_gsl_bs


class MultivariateBasis(NamedTuple):
    """Immutable result bundle produced by multivariate spline basis construction.

    Fields are positional in the order declared below, so instances can be
    unpacked like a plain tuple or accessed by attribute name.
    """

    #: Assembled spline basis (design) matrix.
    basis: np.ndarray
    #: Column count of the side-by-side marginal bases, before any tensor/GLP expansion.
    dim_no_tensor: int
    #: Per-variable spline specification matrix (degree in the first column).
    degree_matrix: np.ndarray
    #: Number of segments used for each variable.
    n_segments: np.ndarray
    #: Name of the basis construction that was requested.
    basis_type: str


def prodspline(
    x,
    K,
    z=None,
    indicator=None,
    xeval=None,
    zeval=None,
    knots="quantiles",
    basis="additive",
    x_min=None,
    x_max=None,
    deriv_index=1,
    deriv=0,
):
    r"""Create multivariate spline basis with B-spline components.

    Constructs additive, tensor product, or generalized linear product (GLP)
    basis functions for multivariate continuous and discrete predictors.

    Parameters
    ----------
    x : ndarray
        Continuous predictor matrix of shape (n, p).
    K : ndarray
        Matrix of shape (p, 2) containing spline specifications (values are
        rounded to integers):

        - Column 0: degree for each continuous variable; a variable with
          degree 0 contributes no basis columns
        - Column 1: number of segments - 1 for each variable
    z : ndarray, optional
        Discrete predictor matrix of shape (n, q).
    indicator : ndarray, optional
        Indicator vector of length q for discrete variables (1 to include).
        Required whenever ``z`` is given.
    xeval : ndarray, optional
        Evaluation points for continuous variables. If None, uses x.
    zeval : ndarray, optional
        Evaluation points for discrete variables. If None, uses z.
    knots : {"quantiles", "uniform"}, default="quantiles"
        Method for knot placement:

        - "quantiles": Knots at data quantiles
        - "uniform": Uniformly spaced knots (knot placement is left to
          ``gsl_bs``)
    basis : {"additive", "tensor", "glp"}, default="additive"
        Type of basis construction:

        - "additive": Sum of univariate bases
        - "tensor": Full tensor product of all bases
        - "glp": Generalized linear product (hierarchical interactions)
    x_min : ndarray, optional
        Minimum values for each continuous variable.
    x_max : ndarray, optional
        Maximum values for each continuous variable.
    deriv_index : int, default=1
        Index (1-based) of variable for derivative computation.
    deriv : int, default=0
        Order of derivative to compute.

    Returns
    -------
    MultivariateBasis
        NamedTuple containing:

        - **basis**: Complete basis matrix
        - **dim_no_tensor**: Number of columns before tensor product
        - **degree_matrix**: Copy of K matrix
        - **n_segments**: Number of segments for each variable
        - **basis_type**: Type of basis used

    Raises
    ------
    ValueError
        If ``x`` or ``K`` is missing, ``K`` is not a two-column ndarray, the
        dimensions of ``x``/``K``, ``z``/``indicator`` or ``x``/``xeval``
        disagree, ``indicator`` is missing while ``z`` is given, or
        ``deriv``/``deriv_index`` is out of range.

    Warns
    -----
    UserWarning
        If ``deriv`` exceeds the spline degree of the derivative variable.

    Notes
    -----
    Marginal B-spline bases include an intercept column only when ``basis``
    is neither "additive" nor "glp". Each included discrete variable is
    expanded into indicator columns for every level of ``z`` except the first
    one returned by ``np.unique``. When fewer than two marginal bases are
    produced, the single marginal basis (or a column of ones) is returned
    regardless of ``basis``.

    References
    ----------
    .. [1] Wood, S. N. (2017). Generalized Additive Models: An Introduction
       with R. Chapman and Hall/CRC.
    """
    xp = get_backend()

    if x is None or K is None:
        raise ValueError("Must provide x and K.")

    if not isinstance(K, np.ndarray) or K.ndim != 2 or K.shape[1] != 2:
        raise ValueError("K must be a two-column matrix.")

    x = xp.atleast_2d(xp.asarray(x))
    K = np.round(K).astype(int)

    num_x = x.shape[1]
    num_K = K.shape[0]

    if num_K != num_x:
        raise ValueError(f"Dimension of x and K incompatible ({num_x}, {num_K}).")

    if deriv < 0:
        raise ValueError("deriv is invalid.")

    if deriv_index < 1 or deriv_index > num_x:
        raise ValueError("deriv_index is invalid.")

    if deriv > K[deriv_index - 1, 0]:
        warnings.warn("deriv order too large, result will be zero.", UserWarning, stacklevel=2)

    num_z = 0
    if z is not None:
        z = np.atleast_2d(z)
        num_z = z.shape[1]
        if indicator is None:
            raise ValueError("Must provide indicator when z is specified.")
        indicator = np.asarray(indicator)
        num_indicator = len(indicator)
        if num_indicator != num_z:
            raise ValueError(f"Dimension of z and indicator incompatible ({num_z}, {num_indicator}).")

    if xeval is None:
        xeval = x.copy()
    else:
        xeval = xp.atleast_2d(xp.asarray(xeval))
        if xeval.shape[1] != num_x:
            raise ValueError("xeval must be of the same dimension as x.")

    # From here on, zeval is always set whenever z is.
    if z is not None:
        zeval = z.copy() if zeval is None else xp.atleast_2d(xp.asarray(zeval))

    gsl_intercept = basis not in ("additive", "glp")

    if np.any(K[:, 0] > 0) or (indicator is not None and np.any(indicator != 0)):
        # One marginal basis per continuous variable with positive degree,
        # followed by one dummy block per included discrete variable.
        tp = []

        for i in range(num_x):
            if K[i, 0] <= 0:
                continue

            if knots == "uniform":
                knots_vec = None
            else:
                probs = np.linspace(0, 1, K[i, 1] + 2)
                x_col_np = to_numpy(x[:, i])
                knots_vec = np.quantile(x_col_np, probs)
                # Add a tiny increasing offset so tied quantiles still give
                # strictly increasing knots.
                knots_vec = knots_vec + np.linspace(
                    0,
                    1e-10 * (np.max(x_col_np) - np.min(x_col_np)),
                    len(knots_vec),
                )

            bs_kwargs = {
                "x": x[:, i],
                "degree": K[i, 0],
                "nbreak": K[i, 1] + 2,
                "knots": knots_vec,
                "x_min": x_min[i] if x_min is not None else None,
                "x_max": x_max[i] if x_max is not None else None,
                "intercept": gsl_intercept,
            }
            # Only the derivative variable's marginal basis is differentiated.
            if i == deriv_index - 1 and deriv != 0:
                bs_kwargs["deriv"] = deriv

            basis_obj = gsl_bs(**bs_kwargs)
            tp.append(predict_gsl_bs(basis_obj, xeval[:, i]))

        if z is not None:
            for i in range(num_z):
                if indicator[i] != 1:
                    continue
                # Levels come from the fitting data z; the dummies are
                # evaluated on zeval. The first level is the omitted baseline.
                unique_vals = np.unique(z[:, i])
                if len(unique_vals) > 1:
                    dummies = np.column_stack([(zeval[:, i] == val).astype(float) for val in unique_vals[1:]])
                    tp.append(dummies)

        if len(tp) > 1:
            P = xp.hstack(tp)
            dim_P_no_tensor = P.shape[1]

            if basis == "tensor":
                P = tensor_prod_model_matrix(tp)
            elif basis == "glp":
                P = glp_model_matrix(tp)
                if deriv != 0:
                    # Push a one-row marker through the same GLP expansion:
                    # NaN for the derivative variable's basis, zero elsewhere.
                    # A GLP column ends up NaN exactly when it involves the
                    # derivative variable; every other column has zero
                    # derivative and is zeroed out. The mask is laid out by
                    # glp_model_matrix, so it only lines up with a GLP basis.
                    marker_rows = [np.zeros((1, b.shape[1])) for b in tp]

                    # deriv_index is 1-based over x, whereas tp skips
                    # variables with degree 0, so the position in tp is the
                    # number of earlier variables with positive degree.
                    if K[deriv_index - 1, 0] > 0:
                        tp_idx = int(np.count_nonzero(K[: deriv_index - 1, 0] > 0))
                        marker_rows[tp_idx] = np.full((1, tp[tp_idx].shape[1]), np.nan)

                    mask_basis = glp_model_matrix(marker_rows)
                    mask = np.isnan(mask_basis.flatten())
                    P[:, ~mask] = 0
        else:
            P = tp[0] if tp else np.ones((xeval.shape[0], 1))
            dim_P_no_tensor = P.shape[1]
    else:
        # No spline terms and no discrete terms: constant-only basis.
        dim_P_no_tensor = 0
        P = xp.ones((xeval.shape[0], 1))

    return MultivariateBasis(
        basis=P,
        dim_no_tensor=dim_P_no_tensor,
        degree_matrix=K.copy(),
        n_segments=K[:, 1] + 1 if K.size > 0 else np.array([]),
        basis_type=basis,
    )


def _validate_bases(bases):
    """Validate a list of marginal basis matrices and return their row count.

    Parameters
    ----------
    bases : list of ndarray
        Candidate marginal basis matrices.

    Returns
    -------
    int
        Number of rows shared by every matrix in ``bases``.

    Raises
    ------
    ValueError
        If ``bases`` is empty, an entry is not 2-dimensional, or the row
        counts differ.
    TypeError
        If an entry has no ``ndim`` attribute (i.e. is not array-like).
    """
    if not bases:
        raise ValueError("bases cannot be empty")

    for i, basis in enumerate(bases):
        if not hasattr(basis, "ndim"):
            raise TypeError(f"bases[{i}] must be an array, got {type(basis)}")
        if basis.ndim != 2:
            raise ValueError(f"bases[{i}] must be 2-dimensional")

    n_obs = bases[0].shape[0]
    for i, basis in enumerate(bases[1:], 1):
        if basis.shape[0] != n_obs:
            raise ValueError(
                f"All matrices must have same number of rows. bases[0] has {n_obs}, bases[{i}] has {basis.shape[0]}"
            )

    return n_obs


def _rowwise_kron(bases, xp):
    """Return the row-wise Kronecker product of ``bases`` as a new float64 array.

    Row ``r`` of the result is ``kron(bases[0][r], bases[1][r], ...)``, so the
    column index of the last matrix varies fastest.

    Parameters
    ----------
    bases : list of ndarray
        Non-empty list of 2-D matrices with the same number of rows.
    xp : module
        Array backend (as returned by ``get_backend``).

    Returns
    -------
    ndarray
        Matrix of shape (n, prod(dims)).
    """
    n_obs = bases[0].shape[0]
    # xp.array copies, so the result never aliases the caller's input.
    out = xp.array(bases[0], dtype=np.float64)
    for basis in bases[1:]:
        # The explicit column count keeps the reshape valid when n_obs == 0.
        n_cols = out.shape[1] * basis.shape[1]
        out = (out[:, :, None] * basis[:, None, :]).reshape(n_obs, n_cols)
    return out


def tensor_prod_model_matrix(bases):
    r"""Construct tensor product of marginal basis model matrices.

    Produces model matrices for tensor product smooths from marginal basis
    model matrices. The tensor product is computed row-wise using Kronecker
    products.

    Parameters
    ----------
    bases : list of ndarray
        List of model matrices for marginal bases. Each matrix must have the
        same number of rows (observations).

    Returns
    -------
    ndarray
        Tensor product model matrix of shape (n, prod(dims)) where n is the
        number of observations and dims are the dimensions of input matrices.

    Raises
    ------
    ValueError
        If ``bases`` is empty, a matrix is not 2-dimensional, or the row
        counts differ.
    TypeError
        If an entry of ``bases`` is not an array.

    References
    ----------
    .. [1] Wood, S. N. (2006). Low-rank scale-invariant tensor product smooths
       for generalized additive mixed models. Biometrics, 62(4), 1025-1036.
    """
    xp = get_backend()
    _validate_bases(bases)
    return _rowwise_kron(bases, xp)


def glp_model_matrix(bases):
    r"""Construct generalized linear product (GLP) model matrix.

    Produces model matrices for generalized polynomial smooths from marginal
    basis model matrices. The GLP creates a hierarchical polynomial structure
    where terms of different orders can be included, providing a more
    parsimonious alternative to full tensor products while retaining good
    approximation capabilities.

    The columns are laid out as: every marginal basis in input order, then
    the interaction blocks of order 2, 3, ..., ``len(bases)``, each order
    enumerated in ``itertools.combinations`` order.

    Parameters
    ----------
    bases : list of ndarray
        List of model matrices for marginal bases. Each matrix must have the
        same number of rows (observations).

    Returns
    -------
    ndarray
        GLP model matrix with hierarchical polynomial structure. An empty
        (0, 0) array is returned when the inputs have no rows.

    Raises
    ------
    ValueError
        If ``bases`` is empty, a matrix is not 2-dimensional, or the row
        counts differ.
    TypeError
        If an entry of ``bases`` is not an array.

    References
    ----------
    .. [1] Hall, P., & Racine, J. S. (2015). Infinite order cross-validated
       local polynomial regression. Journal of Econometrics, 185(2), 510-525.
    """
    xp = get_backend()
    n_obs = _validate_bases(bases)

    if n_obs == 0:
        return xp.empty((0, 0))

    # Main effects first, then interactions of increasing order.
    blocks = list(bases)
    for order in range(2, len(bases) + 1):
        for combo in combinations(bases, order):
            blocks.append(_compute_basis_interaction(list(combo)))

    return xp.hstack(blocks)


def _compute_basis_interaction(bases):
    """Compute interaction terms between basis functions.

    Parameters
    ----------
    bases : list of ndarray
        List of basis matrices to interact.

    Returns
    -------
    ndarray
        Matrix of interaction terms: one column per combination of one column
        from each input matrix, with the last matrix's column index varying
        fastest. A single input matrix is returned unchanged.
    """
    if len(bases) == 1:
        return bases[0]
    return _rowwise_kron(bases, get_backend())