# Source code for moderndid.npiv.prodspline

"""Multivariate spline construction for nonparametric estimation."""

import warnings
from itertools import combinations, product
from typing import NamedTuple

import numpy as np

from ..cupy.backend import get_backend, to_numpy
from .gsl_bspline import gsl_bs, predict_gsl_bs


class MultivariateBasis(NamedTuple):
    """Immutable record of the outputs of a multivariate spline basis build.

    Fields are positional, in the order: ``basis``, ``dim_no_tensor``,
    ``degree_matrix``, ``n_segments``, ``basis_type``.
    """

    #: The assembled spline basis matrix.
    basis: np.ndarray
    #: Column count of the basis before any tensor/GLP expansion.
    dim_no_tensor: int
    #: Per-variable degree specification matrix.
    degree_matrix: np.ndarray
    #: Per-variable number of segments.
    n_segments: np.ndarray
    #: Name of the basis construction that was used.
    basis_type: str


def prodspline(
    x,
    K,
    z=None,
    indicator=None,
    xeval=None,
    zeval=None,
    knots="quantiles",
    basis="additive",
    x_min=None,
    x_max=None,
    deriv_index=1,
    deriv=0,
):
    r"""Create multivariate spline basis with B-spline components.

    Constructs additive, tensor product, or generalized linear product (GLP)
    basis functions for multivariate continuous and discrete predictors.

    Parameters
    ----------
    x : ndarray
        Continuous predictor matrix of shape (n, p).
    K : ndarray
        Matrix of shape (p, 2) containing spline specifications:

        - Column 0: degree for each continuous variable
        - Column 1: number of segments - 1 for each variable
    z : ndarray, optional
        Discrete predictor matrix of shape (n, q).
    indicator : ndarray, optional
        Indicator vector of length q for discrete variables (1 to include).
    xeval : ndarray, optional
        Evaluation points for continuous variables. If None, uses x.
    zeval : ndarray, optional
        Evaluation points for discrete variables. If None, uses z.
    knots : {"quantiles", "uniform"}, default="quantiles"
        Method for knot placement:

        - "quantiles": Knots at data quantiles
        - "uniform": Uniformly spaced knots
    basis : {"additive", "tensor", "glp"}, default="additive"
        Type of basis construction:

        - "additive": Sum of univariate bases
        - "tensor": Full tensor product of all bases
        - "glp": Generalized linear product (hierarchical interactions)
    x_min : ndarray, optional
        Minimum values for each continuous variable.
    x_max : ndarray, optional
        Maximum values for each continuous variable.
    deriv_index : int, default=1
        Index (1-based) of variable for derivative computation.
    deriv : int, default=0
        Order of derivative to compute.

    Returns
    -------
    MultivariateBasis
        NamedTuple containing:

        - **basis**: Complete basis matrix
        - **dim_no_tensor**: Number of columns before tensor product
        - **degree_matrix**: Copy of K matrix
        - **n_segments**: Number of segments for each variable
        - **basis_type**: Type of basis used

    Raises
    ------
    ValueError
        If ``x`` or ``K`` is missing, ``K`` is not a two-column ndarray, the
        dimensions of ``x``/``K``, ``z``/``indicator`` or ``x``/``xeval``
        disagree, ``z`` is given without ``indicator``, or ``deriv`` /
        ``deriv_index`` is out of range.

    Warns
    -----
    UserWarning
        If ``deriv`` exceeds the spline degree of the derivative variable.

    References
    ----------

    .. [1] Wood, S. N. (2017). Generalized Additive Models: An Introduction
           with R. Chapman and Hall/CRC.
    """
    xp = get_backend()

    if x is None or K is None:
        raise ValueError("Must provide x and K.")

    if not isinstance(K, np.ndarray) or K.ndim != 2 or K.shape[1] != 2:
        raise ValueError("K must be a two-column matrix.")

    x = xp.atleast_2d(xp.asarray(x))
    K = np.round(K).astype(int)

    num_x = x.shape[1]
    num_K = K.shape[0]

    if num_K != num_x:
        raise ValueError(f"Dimension of x and K incompatible ({num_x}, {num_K}).")

    if deriv < 0:
        raise ValueError("deriv is invalid.")
    if deriv_index < 1 or deriv_index > num_x:
        raise ValueError("deriv_index is invalid.")
    if deriv > K[deriv_index - 1, 0]:
        warnings.warn("deriv order too large, result will be zero.", UserWarning, stacklevel=2)

    num_z = 0
    if z is not None:
        z = np.atleast_2d(z)
        num_z = z.shape[1]
        if indicator is None:
            raise ValueError("Must provide indicator when z is specified.")
        indicator = np.asarray(indicator)
        num_indicator = len(indicator)
        if num_indicator != num_z:
            raise ValueError(f"Dimension of z and indicator incompatible ({num_z}, {num_indicator}).")

    if xeval is None:
        xeval = x.copy()
    else:
        xeval = xp.atleast_2d(xp.asarray(xeval))
        if xeval.shape[1] != num_x:
            raise ValueError("xeval must be of the same dimension as x.")

    # From here on, zeval is always populated whenever z is supplied.
    if z is not None:
        zeval = z.copy() if zeval is None else xp.atleast_2d(xp.asarray(zeval))

    # Only the tensor basis is built from marginal bases that include an intercept.
    gsl_intercept = basis not in ("additive", "glp")

    has_spline = np.any(K[:, 0] > 0)
    has_factor = indicator is not None and np.any(indicator != 0)

    if has_spline or has_factor:
        tp = []

        # Marginal B-spline bases for continuous variables with positive degree.
        for i in range(num_x):
            if K[i, 0] <= 0:
                continue

            if knots == "uniform":
                knots_vec = None
            else:
                probs = np.linspace(0, 1, K[i, 1] + 2)
                x_col_np = to_numpy(x[:, i])
                knots_vec = np.quantile(x_col_np, probs)
                # A tiny increasing offset scaled by the data range is added to
                # the quantile knots, which keeps tied quantiles distinct.
                knots_vec = knots_vec + np.linspace(
                    0,
                    1e-10 * (np.max(x_col_np) - np.min(x_col_np)),
                    len(knots_vec),
                )

            spline_kwargs = {
                "x": x[:, i],
                "degree": K[i, 0],
                "nbreak": K[i, 1] + 2,
                "knots": knots_vec,
                "x_min": x_min[i] if x_min is not None else None,
                "x_max": x_max[i] if x_max is not None else None,
                "intercept": gsl_intercept,
            }
            # The derivative order is only passed for the requested variable.
            if i == deriv_index - 1 and deriv != 0:
                spline_kwargs["deriv"] = deriv

            basis_obj = gsl_bs(**spline_kwargs)
            tp.append(predict_gsl_bs(basis_obj, xeval[:, i]))

        # Dummy columns for included discrete variables; levels come from the
        # original z, the first level is dropped, and the dummies are evaluated
        # at zeval.
        if z is not None:
            for i in range(num_z):
                if indicator[i] != 1:
                    continue
                unique_vals = np.unique(z[:, i])
                if len(unique_vals) > 1:
                    dummies = np.column_stack([(zeval[:, i] == val).astype(float) for val in unique_vals[1:]])
                    tp.append(dummies)

        if len(tp) > 1:
            P = xp.hstack(tp)
            dim_P_no_tensor = P.shape[1]

            if basis == "tensor":
                P = tensor_prod_model_matrix(tp)
            elif basis == "glp":
                P = glp_model_matrix(tp)

                if deriv != 0:
                    # Build a one-row template with NaN in the derivative
                    # variable's block and zeros elsewhere. Pushing it through
                    # glp_model_matrix marks (as NaN) exactly the GLP columns
                    # that involve the derivative variable.
                    p_deriv_list = [np.zeros((1, b.shape[1])) for b in tp]

                    # `deriv_index` is 1-based over x, but `tp` only holds
                    # entries for variables with K[i, 0] > 0, so count the
                    # spline terms that precede the derivative variable.
                    tp_idx = -1
                    if K[deriv_index - 1, 0] > 0:
                        tp_idx = int(np.sum(K[: deriv_index - 1, 0] > 0))

                    if tp_idx != -1 and tp_idx < len(p_deriv_list):
                        p_deriv_list[tp_idx] = np.full((1, tp[tp_idx].shape[1]), np.nan)

                    mask_basis = glp_model_matrix(p_deriv_list)
                    mask = np.isnan(mask_basis.flatten())
                    # Columns that do not involve the derivative variable are
                    # zeroed out.
                    P[:, ~mask] = 0
        else:
            P = tp[0] if tp else np.ones((xeval.shape[0], 1))
            dim_P_no_tensor = P.shape[1]
    else:
        # No spline or discrete terms requested: a single column of ones.
        dim_P_no_tensor = 0
        P = xp.ones((xeval.shape[0], 1))

    return MultivariateBasis(
        basis=P,
        dim_no_tensor=dim_P_no_tensor,
        degree_matrix=K.copy(),
        n_segments=K[:, 1] + 1 if K.size > 0 else np.array([]),
        basis_type=basis,
    )
def _validate_bases(bases):
    """Check a list of marginal basis matrices and return their row count.

    Parameters
    ----------
    bases : list of ndarray
        Candidate marginal basis matrices.

    Returns
    -------
    int
        Number of rows shared by every matrix in ``bases``.

    Raises
    ------
    ValueError
        If ``bases`` is empty, an entry is not 2-dimensional, or the row
        counts differ.
    TypeError
        If an entry has no ``ndim`` attribute.
    """
    if not bases:
        raise ValueError("bases cannot be empty")

    for i, basis in enumerate(bases):
        if not hasattr(basis, "ndim"):
            raise TypeError(f"bases[{i}] must be an array, got {type(basis)}")
        if basis.ndim != 2:
            raise ValueError(f"bases[{i}] must be 2-dimensional")

    n_obs = bases[0].shape[0]
    for i, basis in enumerate(bases[1:], 1):
        if basis.shape[0] != n_obs:
            raise ValueError(
                f"All matrices must have same number of rows. bases[0] has {n_obs}, bases[{i}] has {basis.shape[0]}"
            )
    return n_obs


def tensor_prod_model_matrix(bases):
    r"""Construct tensor product of marginal basis model matrices.

    Produces model matrices for tensor product smooths from marginal basis
    model matrices. Each output row is the Kronecker product of the
    corresponding rows of the input matrices, taken in list order.

    Parameters
    ----------
    bases : list of ndarray
        List of model matrices for marginal bases. Each matrix must have
        the same number of rows (observations).

    Returns
    -------
    ndarray
        Tensor product model matrix of shape (n, prod(dims)) with float64
        entries, where n is the number of observations and dims are the
        column counts of the input matrices.

    Raises
    ------
    ValueError
        If ``bases`` is empty, an entry is not 2-dimensional, or the row
        counts differ.
    TypeError
        If an entry has no ``ndim`` attribute.

    References
    ----------

    .. [1] Wood, S. N. (2006). Low-rank scale-invariant tensor product
           smooths for generalized additive mixed models. Biometrics,
           62(4), 1025-1036.
    """
    xp = get_backend()
    n_obs = _validate_bases(bases)

    # Start from a float64 copy so the result never aliases the caller's data.
    result = xp.array(bases[0], dtype=np.float64)
    for basis in bases[1:]:
        factor = xp.asarray(basis, dtype=np.float64)
        # Broadcasting (n, p, 1) * (n, 1, q) and flattening the last two axes
        # gives the row-wise Kronecker product with the later factor varying
        # fastest. The target shape is spelled out (no -1) so that n == 0
        # reshapes cleanly.
        result = (result[:, :, None] * factor[:, None, :]).reshape(n_obs, result.shape[1] * factor.shape[1])
    return result


def glp_model_matrix(bases):
    r"""Construct generalized linear product (GLP) model matrix.

    Produces model matrices for generalized polynomial smooths from marginal
    basis model matrices. The output stacks, column-wise, the marginal bases
    themselves followed by the interaction blocks of every combination of
    two or more bases, in increasing interaction order.

    Parameters
    ----------
    bases : list of ndarray
        List of model matrices for marginal bases. Each matrix must have
        the same number of rows (observations).

    Returns
    -------
    ndarray
        GLP model matrix with hierarchical polynomial structure. An empty
        (0, 0) array is returned when the inputs have no rows.

    Raises
    ------
    ValueError
        If ``bases`` is empty, an entry is not 2-dimensional, or the row
        counts differ.
    TypeError
        If an entry has no ``ndim`` attribute.

    References
    ----------

    .. [1] Hall, P., & Racine, J. S. (2015). Infinite order cross-validated
           local polynomial regression. Journal of Econometrics, 185(2),
           510-525.
    """
    xp = get_backend()
    n_obs = _validate_bases(bases)

    if n_obs == 0:
        return xp.empty((0, 0))

    # Main effects first, then interactions of order 2, 3, ..., len(bases).
    blocks = list(bases)
    for order in range(2, len(bases) + 1):
        for members in combinations(bases, order):
            blocks.append(_compute_basis_interaction(list(members)))

    return xp.hstack(blocks)


def _compute_basis_interaction(bases):
    """Compute interaction terms between basis functions.

    Parameters
    ----------
    bases : list of ndarray
        List of basis matrices to interact.

    Returns
    -------
    ndarray
        Matrix of interaction terms: one column per combination of one
        column from each input matrix, with the last matrix's column index
        varying fastest. A single input matrix is returned unchanged.
    """
    xp = get_backend()

    if len(bases) == 1:
        return bases[0]

    n_obs = bases[0].shape[0]
    widths = [basis.shape[1] for basis in bases]
    out = xp.empty((n_obs, int(np.prod(widths))))

    column_choices = product(*(range(width) for width in widths))
    for out_col, picked in enumerate(column_choices):
        term = xp.ones(n_obs)
        for basis, col in zip(bases, picked):
            term *= basis[:, col]
        out[:, out_col] = term

    return out