From f0854900e8fac81072780a3403c535c8886e6e91 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 1 Mar 2026 19:23:38 -0500 Subject: [PATCH 1/7] Fix TROP nuclear norm solver: correct proximal gradient step size + FISTA acceleration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The soft-thresholding threshold for the L matrix was λ_nn/max(δ) when the correct value is λ_nn/(2·max(δ)), derived from the Lipschitz constant L_f = 2·max(δ) of the quadratic gradient. This over-shrinking caused singular values to be reduced too aggressively. Fix applied to all four code paths: Python joint, Python twostep, Rust joint, Rust twostep. Also adds FISTA/Nesterov acceleration to the twostep inner solver for faster L convergence (O(1/k²) vs O(1/k)). Co-Authored-By: Claude Opus 4.6 (1M context) --- CHANGELOG.md | 13 ++++++++ diff_diff/trop.py | 27 ++++++++++----- docs/methodology/REGISTRY.md | 10 ++++-- rust/src/trop.rs | 65 +++++++++++++++++++++++++++--------- tests/test_trop.py | 60 +++++++++++++++++++++++++++++++++ 5 files changed, 147 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ae5788..b4a7dc0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Changed +- **BREAKING: TROP nuclear norm solver step size fix** — The proximal gradient + threshold for the L matrix (both `method="global"` and `method="twostep"` with + finite `lambda_nn`) was over-shrinking singular values by a factor of 2. The + soft-thresholding threshold was λ_nn/max(δ) when the correct value is + λ_nn/(2·max(δ)), derived from the Lipschitz constant L_f=2·max(δ) of the + quadratic gradient. This fix produces higher-rank L matrices and closer + agreement with exact convex optimization solutions. Users with finite + `lambda_nn` will observe different ATT estimates. Added FISTA/Nesterov + acceleration to the twostep inner solver for faster L convergence. + ## [2.6.1] - 2026-03-08 ### Added diff --git a/diff_diff/trop.py b/diff_diff/trop.py index 41f02d0..4b1cf28 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -889,8 +889,8 @@ def _solve_joint_with_lowrank( gradient_step = L + delta_norm * (R - L) # Soft-threshold singular values - # Use eta * lambda_nn for proper proximal step size (matches Rust) - eta = 1.0 / delta_max if delta_max > 0 else 1.0 + # Lipschitz constant of ∇f is L_f = 2·max(δ), step size η = 1/L_f + eta = 1.0 / (2.0 * delta_max) if delta_max > 0 else 0.5 L = self._soft_threshold_svd(gradient_step, eta * lambda_nn) # Check convergence @@ -2047,20 +2047,29 @@ def _weighted_nuclear_norm_solve( # Initialize L L = L_init.copy() + L_prev = L.copy() + t_fista = 1.0 - # Proximal gradient iteration with weighted soft-impute + # Proximal gradient iteration with FISTA/Nesterov acceleration # This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_* - # Using: L_{k+1} = prox_{λ/η}(L_k + W ⊙ (R - L_k)) - # where η is the step size (we use η = 1 with normalized weights) + # Lipschitz constant L_f = 2·max(W_norm) = 2, so η = 1/2 + # Threshold = η·λ_nn = λ_nn/2 for _ in range(max_inner_iter): L_old = L.copy() - # Gradient step: L_k + W ⊙ (R - L_k) - # For W=0 observations, this keeps L_k unchanged - gradient_step = L + W_norm * (R_masked - L) + # FISTA momentum + t_fista_new = (1.0 + np.sqrt(1.0 + 4.0 * t_fista**2)) / 2.0 + momentum = (t_fista - 1.0) / t_fista_new + L_momentum = L + momentum * (L - L_prev) + + # Gradient step from momentum point: L_m + W ⊙ (R - L_m) + # For W=0 observations, this keeps L_m unchanged + gradient_step = L_momentum + W_norm * (R_masked - L_momentum) # Proximal step: soft-threshold singular values - L = self._soft_threshold_svd(gradient_step, lambda_nn) + L_prev = L.copy() + L = self._soft_threshold_svd(gradient_step, lambda_nn / 2.0) + t_fista = t_fista_new # Check convergence if np.max(np.abs(L - L_old)) < self.tol: diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index bb7550c..38e41e5 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1200,13 +1200,17 @@ where: 1. **Without low-rank (λ_nn = ∞)**: Standard weighted least squares - Build design matrix with unit/time dummies + treatment indicator - - Solve via iterative coordinate descent for (μ, α, β, τ) + - Solve via np.linalg.lstsq for (μ, α, β, τ) 2. **With low-rank (finite λ_nn)**: Alternating minimization - Alternate between: - Fix L, solve weighted LS for (μ, α, β, τ) - - Fix (μ, α, β, τ), soft-threshold SVD for L (proximal step) - - Continue until convergence + - Fix (μ, α, β, τ), proximal gradient for L: + - Lipschitz constant of ∇f is L_f = 2·max(δ) + - Step size η = 1/L_f = 1/(2·max(δ)) + - Proximal operator: soft_threshold(gradient_step, η·λ_nn) + - Twostep inner solver uses FISTA/Nesterov acceleration (O(1/k²)) + - Continue until max(|L_new - L_old|) < tol **LOOCV parameter selection** (unified with twostep, Equation 5): Following paper's Equation 5 and footnote 2: diff --git a/rust/src/trop.rs b/rust/src/trop.rs index 9e275d5..fb5936c 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -660,9 +660,9 @@ fn estimate_model( } }); - // Compute step size for proximal gradient: η ≤ 1/max(W) + // Lipschitz constant of ∇f is L_f = 2·max(W), so prox threshold = λ/(2·max(W)) let w_max = w_masked.iter().cloned().fold(0.0_f64, f64::max); - let eta = if w_max > 0.0 { 1.0 / w_max } else { 1.0 }; + let prox_threshold = if w_max > 0.0 { lambda_nn / (2.0 * w_max) } else { lambda_nn / 2.0 }; // Weight sums per unit and time let weight_sum_per_unit: Array1 = w_masked.sum_axis(Axis(0)); @@ -722,9 +722,9 @@ fn estimate_model( } // Step 2: Update L with WEIGHTED nuclear norm penalty - // Paper alignment: Use proximal gradient instead of direct soft-thresholding - // L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L))) - // where R = Y - α - β + // Inner FISTA-accelerated proximal gradient loop (α, β fixed) + // L ← prox_{threshold·||·||_*}(L + W_norm ⊙ (R - L)) + // where R = Y - α - β, W_norm = W/max(W) // Compute target residual R = Y - α - β let mut r_target = Array2::::zeros((n_periods, n_units)); @@ -734,18 +734,50 @@ fn estimate_model( } } - // Weighted proximal gradient step: - // gradient_step = L + η * W ⊙ (R - L) - // For W=0 cells (treated obs), this keeps L unchanged - let mut gradient_step = Array2::::zeros((n_periods, n_units)); - for t in 0..n_periods { - for i in 0..n_units { - gradient_step[[t, i]] = l[[t, i]] + eta * w_masked[[t, i]] * (r_target[[t, i]] - l[[t, i]]); + // For W=0 cells, use current L instead of R (prevent absorbing treatment) + let r_masked = Array2::from_shape_fn((n_periods, n_units), |(t, i)| { + if w_masked[[t, i]] > 0.0 { r_target[[t, i]] } else { l[[t, i]] } + }); + + // Normalize weights: W_norm = W / W_max (max becomes 1) + let w_norm = Array2::from_shape_fn((n_periods, n_units), |(t, i)| { + if w_max > 0.0 { w_masked[[t, i]] / w_max } else { w_masked[[t, i]] } + }); + + // FISTA inner loop for L update + let mut l_prev = l.clone(); + let mut t_fista = 1.0_f64; + let max_inner_iter = 10; + + for _ in 0..max_inner_iter { + let l_inner_old = l.clone(); + + // FISTA momentum + let t_fista_new = (1.0 + (1.0 + 4.0 * t_fista * t_fista).sqrt()) / 2.0; + let momentum = (t_fista - 1.0) / t_fista_new; + let l_momentum = Array2::from_shape_fn((n_periods, n_units), |(t, i)| { + l[[t, i]] + momentum * (l[[t, i]] - l_prev[[t, i]]) + }); + + // Gradient step from momentum point + let mut gradient_step = Array2::::zeros((n_periods, n_units)); + for t in 0..n_periods { + for i in 0..n_units { + gradient_step[[t, i]] = l_momentum[[t, i]] + w_norm[[t, i]] * (r_masked[[t, i]] - l_momentum[[t, i]]); + } } - } - // Proximal step: soft-threshold singular values with scaled lambda - l = soft_threshold_svd(&gradient_step, eta * lambda_nn)?; + // Proximal step: soft-threshold with corrected threshold + l_prev = l.clone(); + l = soft_threshold_svd(&gradient_step, prox_threshold)?; + t_fista = t_fista_new; + + // Check inner convergence + let l_inner_diff = max_abs_diff_2d(&l, &l_inner_old); + if l_inner_diff < tol { + break; + } + } // Check convergence let alpha_diff = max_abs_diff(&alpha, &alpha_old); @@ -1327,8 +1359,9 @@ fn solve_joint_with_lowrank( } // Weighted proximal step for L (soft-threshold SVD) + // Lipschitz constant of ∇f is L_f = 2·max(δ), step size η = 1/L_f let delta_max = delta.iter().cloned().fold(0.0_f64, f64::max); - let eta = if delta_max > 0.0 { 1.0 / delta_max } else { 1.0 }; + let eta = if delta_max > 0.0 { 1.0 / (2.0 * delta_max) } else { 0.5 }; // gradient_step = L + eta * delta * (R - L) // NaN outcomes get zero weight so they don't affect gradient diff --git a/tests/test_trop.py b/tests/test_trop.py index 41a137f..7ac5d4b 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -2592,6 +2592,66 @@ def test_n_post_periods_counts_observed_treatment(self): ) +class TestTROPNuclearNormSolver: + """Tests for proximal gradient step size correctness and objective monotonicity.""" + + def test_proximal_step_size_correctness(self): + """Verify L converges to prox_{λ/2}(R) for uniform weights.""" + trop_est = TROP(method="joint", n_bootstrap=2) + + # Small problem with known solution + rng = np.random.default_rng(42) + R = rng.normal(0, 1, (4, 3)) + delta = np.ones((4, 3)) + lambda_nn = 0.5 + + # Run solver (many iterations to ensure convergence) + L = np.zeros_like(R) + for _ in range(500): + delta_max = np.max(delta) + delta_norm = delta / delta_max + gradient_step = L + delta_norm * (R - L) + eta = 1.0 / (2.0 * delta_max) + L = trop_est._soft_threshold_svd(gradient_step, eta * lambda_nn) + + # Analytical solution for uniform weights: prox_{λ/2}(R) + L_exact = trop_est._soft_threshold_svd(R, lambda_nn / 2.0) + + np.testing.assert_array_almost_equal(L, L_exact, decimal=4) + + def test_lowrank_objective_decreases(self): + """Verify objective f(L) + λ||L||_* is non-increasing across iterations.""" + # Generate small problem + rng = np.random.default_rng(42) + R = rng.normal(0, 1, (6, 4)) + delta = rng.uniform(0.5, 2.0, (6, 4)) + lambda_nn = 0.3 + + trop_est = TROP(method="joint", n_bootstrap=2) + L = np.zeros_like(R) + objectives = [] + + for _ in range(50): + # Compute objective + f_val = np.sum(delta * (R - L) ** 2) + _, s, _ = np.linalg.svd(L, full_matrices=False) + obj = f_val + lambda_nn * np.sum(s) + objectives.append(obj) + + # Proximal gradient step + delta_max = np.max(delta) + delta_norm = delta / delta_max + gradient_step = L + delta_norm * (R - L) + eta = 1.0 / (2.0 * delta_max) + L = trop_est._soft_threshold_svd(gradient_step, eta * lambda_nn) + + # Objective should be non-increasing (within numerical tolerance) + for k in range(1, len(objectives)): + assert objectives[k] <= objectives[k - 1] + 1e-10, ( + f"Objective increased at step {k}: {objectives[k]} > {objectives[k-1]}" + ) + + class TestTROPJointMethod: """Tests for TROP method='joint'. From 6d0a8becf4a162d3c53e40b860b6a188752a1273 Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 7 Mar 2026 16:59:33 -0500 Subject: [PATCH 2/7] Fix twostep nuclear norm solver threshold scaling for non-uniform weights MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The proximal threshold was hardcoded as λ/(2) which is only correct when W_max=1. Changed to λ/(2·W_max) to match the joint solver and Rust backend. Added test with non-uniform weights and updated REGISTRY.md algorithm docs. Co-Authored-By: Claude Opus 4.6 --- diff_diff/trop.py | 6 +++--- docs/methodology/REGISTRY.md | 9 ++++++-- tests/test_trop.py | 42 ++++++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/diff_diff/trop.py b/diff_diff/trop.py index 4b1cf28..61ee34a 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -2052,8 +2052,8 @@ def _weighted_nuclear_norm_solve( # Proximal gradient iteration with FISTA/Nesterov acceleration # This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_* - # Lipschitz constant L_f = 2·max(W_norm) = 2, so η = 1/2 - # Threshold = η·λ_nn = λ_nn/2 + # Lipschitz constant L_f = 2·max(W), so η = 1/(2·max(W)) + # Threshold = η·λ_nn = λ_nn/(2·max(W)) for _ in range(max_inner_iter): L_old = L.copy() @@ -2068,7 +2068,7 @@ def _weighted_nuclear_norm_solve( # Proximal step: soft-threshold singular values L_prev = L.copy() - L = self._soft_threshold_svd(gradient_step, lambda_nn / 2.0) + L = self._soft_threshold_svd(gradient_step, lambda_nn / (2.0 * W_max)) t_fista = t_fista_new # Check convergence diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 38e41e5..2046286 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1081,10 +1081,15 @@ Optimization (Equation 2): ``` (α̂, β̂, L̂) = argmin_{α,β,L} Σ_j Σ_s θ_s^{i,t} ω_j^{i,t} (1-W_js)(Y_js - α_j - β_s - L_js)² + λ_nn ||L||_* ``` -Solved via alternating minimization with soft-thresholding of singular values for L: +Solved via alternating minimization. For α, β (or μ, α, β, τ in joint): weighted least +squares (closed form). For L: proximal gradient with step size η = 1/(2·max(W)): ``` -L̂ = U × soft_threshold(Σ, λ_nn) × V' +Gradient step: G = L + (W/max(W)) ⊙ (R - L) +Proximal step: L = U × soft_threshold(Σ, η·λ_nn) × V' (SVD of G = UΣV') ``` +where R is the residual after removing fixed effects (and τ·D in joint mode). +The twostep solver's inner L update uses FISTA/Nesterov acceleration (O(1/k²) convergence); +the Python joint solver uses a single proximal gradient step per outer alternating iteration. Per-observation weights (Equation 3): ``` diff --git a/tests/test_trop.py b/tests/test_trop.py index 7ac5d4b..abf21e3 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -2651,6 +2651,48 @@ def test_lowrank_objective_decreases(self): f"Objective increased at step {k}: {objectives[k]} > {objectives[k-1]}" ) + def test_twostep_nonuniform_weights_objective(self): + """Verify objective decreases with non-uniform weights (W_max < 1).""" + rng = np.random.default_rng(123) + R = rng.normal(0, 1, (6, 4)) + W = rng.uniform(0.1, 0.8, (6, 4)) + lambda_nn = 0.3 + + trop_est = TROP(method="twostep", n_bootstrap=2) + + # Initial objective with L=0 + L_init = np.zeros_like(R) + f_init = np.sum(W * (R - L_init) ** 2) + _, s_init, _ = np.linalg.svd(L_init, full_matrices=False) + obj_init = f_init + lambda_nn * np.sum(s_init) + + # Solve + L_final = trop_est._weighted_nuclear_norm_solve( + Y=R, + W=W, + L_init=L_init, + alpha=np.zeros(R.shape[1]), + beta=np.zeros(R.shape[0]), + lambda_nn=lambda_nn, + max_inner_iter=20, + ) + + # Final objective + f_final = np.sum(W * (R - L_final) ** 2) + _, s_final, _ = np.linalg.svd(L_final, full_matrices=False) + obj_final = f_final + lambda_nn * np.sum(s_final) + + assert obj_final <= obj_init + 1e-10, ( + f"Objective did not decrease: {obj_final} > {obj_init}" + ) + + # Soft-thresholding should reduce nuclear norm vs residual + nuclear_norm_R = np.sum(np.linalg.svd(R, compute_uv=False)) + nuclear_norm_L = np.sum(s_final) + assert nuclear_norm_L < nuclear_norm_R, ( + f"Nuclear norm not reduced: {nuclear_norm_L} >= {nuclear_norm_R}" + ) + class TestTROPJointMethod: """Tests for TROP method='joint'. From 44124ca15f9c71f145851955b9b780f3ececf176 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 8 Mar 2026 08:27:51 -0400 Subject: [PATCH 3/7] Guard W_max==0 division in twostep nuclear norm solver + update docstrings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add conditional threshold when W_max==0 to prevent ZeroDivisionError, matching Rust backend behavior (trop.rs:665) - Update Python and Rust docstrings to reflect correct FISTA/Nesterov acceleration formulas (L_f = 2·max(W), η = 1/(2·max(W))) - Add regression test for all-zero weights edge case Co-Authored-By: Claude Opus 4.6 --- diff_diff/trop.py | 12 +++++++----- rust/src/trop.rs | 7 ++++--- tests/test_trop.py | 20 ++++++++++++++++++++ 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/diff_diff/trop.py b/diff_diff/trop.py index 61ee34a..652c940 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -1990,10 +1990,11 @@ def _weighted_nuclear_norm_solve( paper's Equation 2 (page 7). The full objective is: min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_* - This uses a proximal gradient / soft-impute approach (Mazumder et al. 2010): - L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k)) - - where W ⊙ denotes element-wise multiplication with normalized weights. + This uses proximal gradient descent (Mazumder et al. 2010) with + FISTA/Nesterov acceleration. Lipschitz constant L_f = 2·max(W), + step size η = 1/(2·max(W)), proximal threshold η·λ_nn: + G_k = L_k + (W/max(W)) ⊙ (R - L_k) + L_{k+1} = prox_{η·λ_nn·||·||_*}(G_k) IMPORTANT: For observations with W=0 (treated observations), we keep L values from the previous iteration rather than setting L = R, which @@ -2068,7 +2069,8 @@ def _weighted_nuclear_norm_solve( # Proximal step: soft-threshold singular values L_prev = L.copy() - L = self._soft_threshold_svd(gradient_step, lambda_nn / (2.0 * W_max)) + threshold = lambda_nn / (2.0 * W_max) if W_max > 0 else lambda_nn / 2.0 + L = self._soft_threshold_svd(gradient_step, threshold) t_fista = t_fista_new # Check convergence diff --git a/rust/src/trop.rs b/rust/src/trop.rs index fb5936c..b503f4c 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -620,9 +620,10 @@ fn compute_weight_matrix( /// /// Minimizes: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_* /// -/// Paper alignment: Uses weighted proximal gradient for L update: -/// L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L))) -/// where η ≤ 1/max(W) for convergence. +/// Paper alignment: Uses weighted proximal gradient for L update with +/// Lipschitz constant L_f = 2·max(W), step size η = 1/(2·max(W)): +/// G = L + (W/max(W)) ⊙ (R - L) +/// L ← prox_{η·λ_nn·||·||_*}(G) /// /// Returns None if estimation fails due to numerical issues. #[allow(clippy::too_many_arguments)] diff --git a/tests/test_trop.py b/tests/test_trop.py index abf21e3..726c2e6 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -2693,6 +2693,26 @@ def test_twostep_nonuniform_weights_objective(self): f"Nuclear norm not reduced: {nuclear_norm_L} >= {nuclear_norm_R}" ) + def test_zero_weights_no_division_error(self): + """Verify solver handles all-zero weights without ZeroDivisionError.""" + rng = np.random.default_rng(99) + Y = rng.normal(0, 1, (6, 4)) + W = np.zeros((6, 4)) + L_init = rng.normal(0, 1, (6, 4)) + + trop_est = TROP(method="twostep", n_bootstrap=2) + result = trop_est._weighted_nuclear_norm_solve( + Y=Y, + W=W, + L_init=L_init, + alpha=np.zeros(4), + beta=np.zeros(6), + lambda_nn=0.3, + ) + + assert np.isfinite(result).all(), "Result contains NaN or Inf" + assert result.shape == (6, 4), f"Expected (6, 4), got {result.shape}" + class TestTROPJointMethod: """Tests for TROP method='joint'. From bb6bf5c54ce7c2898fb6497167fc4c98bf92a525 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 8 Mar 2026 18:31:35 -0400 Subject: [PATCH 4/7] =?UTF-8?q?Add=20(1-W)=20weight=20masking=20to=20TROP?= =?UTF-8?q?=20global=20method=20+=20rename=20joint=E2=86=92global?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Align the TROP global method with the paper's Eq. 2 by adding (1-W) masking so the model is fit on control data only, then extracting treatment effects post-hoc as residuals (tau_it = Y - mu - alpha - beta - L). Key changes: - Apply (1-W) masking in _compute_joint_weights, zeroing treated cells - Remove tau from the joint solvers (no longer identifiable under masking) - Extract per-observation treatment effects post-hoc; ATT = mean(tau_it) - Add FISTA/Nesterov acceleration to the nuclear norm solver (O(1/k²)) - Rename method='joint' to method='global' with FutureWarning deprecation - Extract _solve_joint_model and _extract_posthoc_tau helpers to reduce duplication - Mirror all changes in Rust backend Monte Carlo validation (20 reps × 5 configs) shows: - No-lowrank configs: exact match with CVXPY reference (|Δτ| = 0) - Low-rank configs: mean |Δτ| = 0.0004 (λ_nn=0.1), 0.026 (λ_nn=0.01) - 100% of comparisons within 0.10 of reference Co-Authored-By: Claude Opus 4.6 (1M context) --- .gitignore | 3 + diff_diff/trop.py | 277 ++++++++++++++++++++++------------- docs/api/trop.rst | 42 +++--- docs/methodology/REGISTRY.md | 48 +++--- docs/tutorials/10_trop.ipynb | 6 +- rust/src/trop.rs | 213 +++++++++++++++------------ tests/test_trop.py | 199 ++++++++++++++++++++++++- 7 files changed, 546 insertions(+), 242 deletions(-) diff --git a/.gitignore b/.gitignore index c1aa8ce..1e39833 100644 --- a/.gitignore +++ b/.gitignore @@ -87,3 +87,6 @@ trop_avg_ref/ # Academic papers (local only, not for distribution) papers/ + +# Local analysis notebooks (not committed) +analysis/ diff --git a/diff_diff/trop.py b/diff_diff/trop.py index 652c940..656daa8 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -71,11 +71,16 @@ class TROP: a model for each treated observation, averaging the individual treatment effects. More flexible but computationally intensive. - - 'joint': Joint weighted least squares optimization. Estimates a - single scalar treatment effect τ along with fixed effects and - optional low-rank factor adjustment. Faster but assumes homogeneous - treatment effects. Uses alternating minimization when nuclear norm - penalty is finite. + - 'global': Global weighted least squares with post-hoc treatment + effect extraction. Fits a single model on control observations + using (1-W) masked weights per paper Eq. 2, then computes + per-observation treatment effects as residuals: + tau_it = Y_it - mu - alpha_i - beta_t - L_it for treated cells. + ATT is the mean of these effects. Faster than twostep but uses + global weights instead of per-observation weights. + + - 'joint': Deprecated alias for 'global'. Will be removed in a + future version. lambda_time_grid : list, optional Grid of time weight decay parameters. 0.0 = uniform weights (disabled). @@ -144,11 +149,20 @@ def __init__( seed: Optional[int] = None, ): # Validate method parameter - valid_methods = ("twostep", "joint") + # 'global' is the preferred name; 'joint' is a deprecated alias + valid_methods = ("twostep", "joint", "global") if method not in valid_methods: raise ValueError( f"method must be one of {valid_methods}, got '{method}'" ) + if method == "joint": + warnings.warn( + "method='joint' is deprecated and will be removed in a future " + "version. Use method='global' instead.", + FutureWarning, + stacklevel=2, + ) + method = "global" self.method = method # Default grids from paper @@ -635,8 +649,75 @@ def _compute_joint_weights( # Outer product: (n_periods x n_units) delta = np.outer(delta_time, delta_unit) + # (1-W) masking: zero out treated observations per paper Eq. 2 + # Model is fit on control data only; tau extracted post-hoc + delta = delta * (1 - D) + return delta + def _solve_joint_model( + self, + Y: np.ndarray, + delta: np.ndarray, + lambda_nn: float, + ) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]: + """ + Dispatch to no-lowrank or with-lowrank solver based on lambda_nn. + + Returns (mu, alpha, beta, L) in all cases. + """ + n_periods, n_units = Y.shape + if lambda_nn >= 1e10: + mu, alpha, beta = self._solve_joint_no_lowrank(Y, delta) + L = np.zeros((n_periods, n_units)) + else: + mu, alpha, beta, L = self._solve_joint_with_lowrank( + Y, delta, lambda_nn, self.max_iter, self.tol + ) + return mu, alpha, beta, L + + @staticmethod + def _extract_posthoc_tau( + Y: np.ndarray, + D: np.ndarray, + mu: float, + alpha: np.ndarray, + beta: np.ndarray, + L: np.ndarray, + idx_to_unit: Optional[Dict] = None, + idx_to_period: Optional[Dict] = None, + ) -> Tuple[float, Dict, List[float]]: + """ + Extract post-hoc treatment effects: tau_it = Y - mu - alpha - beta - L. + + Returns (att, treatment_effects_dict, tau_values_list). + When idx_to_unit/idx_to_period are None, treatment_effects uses raw indices. + """ + counterfactual = mu + alpha[np.newaxis, :] + beta[:, np.newaxis] + L + tau_matrix = Y - counterfactual + + treated_mask = D == 1 + finite_mask = np.isfinite(Y) + valid_treated = treated_mask & finite_mask + + tau_values = tau_matrix[valid_treated].tolist() + att = float(np.mean(tau_values)) if tau_values else np.nan + + # Build treatment effects dict + treatment_effects: Dict = {} + n_periods, n_units = D.shape + for t in range(n_periods): + for i in range(n_units): + if D[t, i] == 1: + uid = idx_to_unit[i] if idx_to_unit is not None else i + tid = idx_to_period[t] if idx_to_period is not None else t + if finite_mask[t, i]: + treatment_effects[(uid, tid)] = tau_matrix[t, i] + else: + treatment_effects[(uid, tid)] = np.nan + + return att, treatment_effects, tau_values + def _loocv_score_joint( self, Y: np.ndarray, @@ -698,14 +779,7 @@ def _loocv_score_joint( delta_ex[t_ex, i_ex] = 0.0 try: - # Fit joint model excluding this observation - if lambda_nn >= 1e10: - mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta_ex) - L = np.zeros((n_periods, n_units)) - else: - mu, alpha, beta, L, tau = self._solve_joint_with_lowrank( - Y, D, delta_ex, lambda_nn, self.max_iter, self.tol - ) + mu, alpha, beta, L = self._solve_joint_model(Y, delta_ex, lambda_nn) # Pseudo treatment effect: τ = Y - μ - α - β - L if np.isfinite(Y[t_ex, i_ex]): @@ -725,33 +799,32 @@ def _loocv_score_joint( def _solve_joint_no_lowrank( self, Y: np.ndarray, - D: np.ndarray, delta: np.ndarray, - ) -> Tuple[float, np.ndarray, np.ndarray, float]: + ) -> Tuple[float, np.ndarray, np.ndarray]: """ - Solve joint TWFE + treatment via weighted least squares (no low-rank). + Solve TWFE via weighted least squares on control data (no low-rank). + + Solves: min Σ (1-W)*δ_{it}(Y_{it} - μ - α_i - β_t)² - Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - τ*W_{it})² + The (1-W) masking is already applied to delta by _compute_joint_weights, + so treated observations have zero weight and do not affect the fit. Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). - D : np.ndarray - Treatment indicator matrix (n_periods x n_units). delta : np.ndarray - Weight matrix (n_periods x n_units). + Weight matrix (n_periods x n_units), already (1-W) masked. Returns ------- - Tuple[float, np.ndarray, np.ndarray, float] - (mu, alpha, beta, tau) estimated parameters. + Tuple[float, np.ndarray, np.ndarray] + (mu, alpha, beta) estimated parameters. """ n_periods, n_units = Y.shape # Flatten matrices for regression y = Y.flatten() # length n_periods * n_units - w = D.flatten() weights = delta.flatten() # Handle NaN values: zero weight for NaN outcomes/weights, impute with 0 @@ -769,12 +842,10 @@ def _solve_joint_no_lowrank( if sum_w < 1e-10: raise ValueError("All weights are zero - cannot estimate") - # Build design matrix: [intercept, unit_dummies, time_dummies, treatment] - # Total columns: 1 + n_units + n_periods + 1 - # But we need to drop one unit and one time dummy for identification - # Drop first unit (unit 0) and first time (time 0) + # Build design matrix: [intercept, unit_dummies, time_dummies] + # Drop first unit (unit 0) and first time (time 0) for identification n_obs = n_periods * n_units - n_params = 1 + (n_units - 1) + (n_periods - 1) + 1 + n_params = 1 + (n_units - 1) + (n_periods - 1) X = np.zeros((n_obs, n_params)) X[:, 0] = 1.0 # intercept @@ -789,9 +860,6 @@ def _solve_joint_no_lowrank( for i in range(n_units): X[t * n_units + i, (n_units - 1) + t] = 1.0 - # Treatment indicator - X[:, -1] = w - # Apply weights X_weighted = X * sqrt_weights[:, np.newaxis] y_weighted = y * sqrt_weights @@ -809,32 +877,31 @@ def _solve_joint_no_lowrank( alpha[1:] = coeffs[1:n_units] beta = np.zeros(n_periods) beta[1:] = coeffs[n_units:(n_units + n_periods - 1)] - tau = coeffs[-1] - return float(mu), alpha, beta, float(tau) + return float(mu), alpha, beta def _solve_joint_with_lowrank( self, Y: np.ndarray, - D: np.ndarray, delta: np.ndarray, lambda_nn: float, max_iter: int = 100, tol: float = 1e-6, - ) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, float]: + ) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]: """ - Solve joint TWFE + treatment + low-rank via alternating minimization. + Solve TWFE + low-rank on control data via alternating minimization. - Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - L_{it} - τ*W_{it})² + λ_nn||L||_* + Solves: min Σ (1-W)*δ_{it}(Y_{it} - μ - α_i - β_t - L_{it})² + λ_nn||L||_* + + The (1-W) masking is already applied to delta by _compute_joint_weights, + so treated observations have zero weight and do not affect the fit. Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). - D : np.ndarray - Treatment indicator matrix (n_periods x n_units). delta : np.ndarray - Weight matrix (n_periods x n_units). + Weight matrix (n_periods x n_units), already (1-W) masked. lambda_nn : float Nuclear norm regularization parameter. max_iter : int, default=100 @@ -844,8 +911,8 @@ def _solve_joint_with_lowrank( Returns ------- - Tuple[float, np.ndarray, np.ndarray, np.ndarray, float] - (mu, alpha, beta, L, tau) estimated parameters. + Tuple[float, np.ndarray, np.ndarray, np.ndarray] + (mu, alpha, beta, L) estimated parameters. """ n_periods, n_units = Y.shape @@ -859,45 +926,60 @@ def _solve_joint_with_lowrank( delta_masked = delta.copy() delta_masked[nan_mask] = 0.0 + # Precompute normalized weights and threshold (constant across iterations) + delta_max = np.max(delta_masked) + if delta_max > 0: + delta_norm = delta_masked / delta_max + else: + delta_norm = delta_masked + threshold = lambda_nn / (2.0 * delta_max) if delta_max > 0 else lambda_nn / 2.0 + # Initialize L = 0 L = np.zeros((n_periods, n_units)) for iteration in range(max_iter): L_old = L.copy() - # Step 1: Fix L, solve for (mu, alpha, beta, tau) - # Adjusted outcome: Y - L (using NaN-safe Y) - # Pass masked delta to exclude NaN observations from WLS + # Step 1: Fix L, solve for (mu, alpha, beta) Y_adj = Y_safe - L - mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y_adj, D, delta_masked) + mu, alpha, beta = self._solve_joint_no_lowrank(Y_adj, delta_masked) - # Step 2: Fix (mu, alpha, beta, tau), update L - # Residual: R = Y - mu - alpha - beta - tau*D (using NaN-safe Y) - R = Y_safe - mu - alpha[np.newaxis, :] - beta[:, np.newaxis] - tau * D + # Step 2: Fix (mu, alpha, beta), update L with FISTA acceleration + R = Y_safe - mu - alpha[np.newaxis, :] - beta[:, np.newaxis] - # Weighted proximal step for L (soft-threshold SVD) - # Normalize weights (using masked delta to exclude NaN observations) - delta_max = np.max(delta_masked) - if delta_max > 0: - delta_norm = delta_masked / delta_max - else: - delta_norm = delta_masked + # For delta=0 observations (treated/NaN), keep L rather than R + R_masked = np.where(delta_masked > 0, R, L) - # Weighted average between current L and target R - # L_next = L + delta_norm * (R - L), then soft-threshold - # NaN observations have delta_norm=0, so they don't influence L update - gradient_step = L + delta_norm * (R - L) + # Inner FISTA loop for L update + L_inner = L.copy() + L_inner_prev = L_inner # share reference initially (no copy needed) + t_fista = 1.0 - # Soft-threshold singular values - # Lipschitz constant of ∇f is L_f = 2·max(δ), step size η = 1/L_f - eta = 1.0 / (2.0 * delta_max) if delta_max > 0 else 0.5 - L = self._soft_threshold_svd(gradient_step, eta * lambda_nn) + for _ in range(20): + # FISTA momentum + t_fista_new = (1.0 + np.sqrt(1.0 + 4.0 * t_fista**2)) / 2.0 + momentum = (t_fista - 1.0) / t_fista_new + L_momentum = L_inner + momentum * (L_inner - L_inner_prev) - # Check convergence + # Gradient step from momentum point + gradient_step = L_momentum + delta_norm * (R_masked - L_momentum) + + # Proximal step: soft-threshold singular values + L_inner_prev = L_inner + L_inner = self._soft_threshold_svd(gradient_step, threshold) + t_fista = t_fista_new + + # Convergence check (L_inner_prev holds the pre-SVD value) + if np.max(np.abs(L_inner - L_inner_prev)) < tol: + break + + L = L_inner + + # Outer convergence check if np.max(np.abs(L - L_old)) < tol: break - return mu, alpha, beta, L, tau + return mu, alpha, beta, L def _fit_joint( self, @@ -908,10 +990,11 @@ def _fit_joint( time: str, ) -> TROPResults: """ - Fit TROP using joint weighted least squares method. + Fit TROP using global weighted least squares method. - This method estimates a single scalar treatment effect τ along with - fixed effects and optional low-rank factor adjustment. + Fits a single model on control observations using (1-W) masked weights, + then extracts per-observation treatment effects as post-hoc residuals. + ATT is the mean of these heterogeneous effects. Parameters ---------- @@ -1026,7 +1109,7 @@ def _fit_joint( unique_starts = sorted(set(first_treat_by_unit)) if len(unique_starts) > 1: raise ValueError( - f"method='joint' requires simultaneous treatment adoption, but your data " + f"method='global' requires simultaneous treatment adoption, but your data " f"shows staggered adoption (units first treated at periods {unique_starts}). " f"Use method='twostep' which properly handles staggered adoption designs." ) @@ -1140,25 +1223,12 @@ def _fit_joint( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) - if lambda_nn >= 1e10: - mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta) - L = np.zeros((n_periods, n_units)) - else: - mu, alpha, beta, L, tau = self._solve_joint_with_lowrank( - Y, D, delta, lambda_nn, self.max_iter, self.tol - ) + mu, alpha, beta, L = self._solve_joint_model(Y, delta, lambda_nn) - # ATT is the scalar treatment effect - att = tau - - # Compute individual treatment effects for reporting (same τ for all) - treatment_effects = {} - for t in range(n_periods): - for i in range(n_units): - if D[t, i] == 1: - unit_id = idx_to_unit[i] - time_id = idx_to_period[t] - treatment_effects[(unit_id, time_id)] = tau + # Post-hoc tau extraction (per paper Eq. 2) + att, treatment_effects, tau_values = self._extract_posthoc_tau( + Y, D, mu, alpha, beta, L, idx_to_unit, idx_to_period + ) # Compute effective rank of L _, s, _ = np.linalg.svd(L, full_matrices=False) @@ -1363,9 +1433,9 @@ def _fit_joint_with_fixed_lambda( treated_periods: int, ) -> float: """ - Fit joint model with fixed tuning parameters. + Fit global model with fixed tuning parameters. - Returns only the treatment effect τ. + Returns the ATT (mean of post-hoc per-observation treatment effects). """ lambda_time, lambda_unit, lambda_nn = fixed_lambda @@ -1388,20 +1458,15 @@ def _fit_joint_with_fixed_lambda( .values ) - # Compute weights + # Compute weights (includes (1-W) masking) delta = self._compute_joint_weights( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) - # Fit model - if lambda_nn >= 1e10: - _, _, _, tau = self._solve_joint_no_lowrank(Y, D, delta) - else: - _, _, _, _, tau = self._solve_joint_with_lowrank( - Y, D, delta, lambda_nn, self.max_iter, self.tol - ) - - return tau + # Fit model on control data and extract post-hoc tau + mu, alpha, beta, L = self._solve_joint_model(Y, delta, lambda_nn) + att, _, _ = self._extract_posthoc_tau(Y, D, mu, alpha, beta, L) + return att def fit( self, @@ -1456,7 +1521,7 @@ def fit( raise ValueError(f"Missing columns: {missing}") # Dispatch based on estimation method - if self.method == "joint": + if self.method == "global": return self._fit_joint(data, outcome, treatment, unit, time) # Below is the twostep method (default) @@ -2555,6 +2620,14 @@ def get_params(self) -> Dict[str, Any]: def set_params(self, **params) -> "TROP": """Set estimator parameters.""" for key, value in params.items(): + if key == "method" and value == "joint": + warnings.warn( + "method='joint' is deprecated and will be removed in a " + "future version. Use method='global' instead.", + FutureWarning, + stacklevel=2, + ) + value = "global" if hasattr(self, key): setattr(self, key, value) else: diff --git a/docs/api/trop.rst b/docs/api/trop.rst index e359b41..65b6d9f 100644 --- a/docs/api/trop.rst +++ b/docs/api/trop.rst @@ -119,26 +119,32 @@ This provides the **triple robustness** property (Theorem 5.1): the estimator is consistent if any one of the three components (unit weights, time weights, factor model) is correctly specified. -**Joint Method** (``method='joint'``) +**Global Method** (``method='global'``) -An alternative approach that estimates a single scalar treatment effect: +An alternative approach that fits a single model on control data and extracts +treatment effects as post-hoc residuals: 1. **Compute weights**: Distance-based unit and time weights computed once - (distance to center of treated block, RMSE to average treated trajectory) + (distance to center of treated block, RMSE to average treated trajectory), + with ``(1-W)`` masking to zero out treated observations (per paper Eq. 2). -2. **Joint optimization**: Solve weighted least squares problem +2. **Fit control model**: Solve weighted least squares on control data only .. math:: - \min_{\mu, \alpha, \beta, L, \tau} \sum_{i,t} \delta_{it} (Y_{it} - \mu - \alpha_i - \beta_t - L_{it} - W_{it} \tau)^2 + \lambda_{nn} \|L\|_* + \min_{\mu, \alpha, \beta, L} \sum_{i,t} (1 - W_{it}) \delta_{it} (Y_{it} - \mu - \alpha_i - \beta_t - L_{it})^2 + \lambda_{nn} \|L\|_* - where τ is a **single scalar** (homogeneous treatment effect). +3. **Post-hoc treatment effects**: For each treated observation: -3. **With low-rank** (finite λ_nn): Uses alternating minimization between - weighted LS for (μ, α, β, τ) and soft-threshold SVD for L. + .. math:: + + \hat{\tau}_{it} = Y_{it} - \hat{\mu} - \hat{\alpha}_i - \hat{\beta}_t - \hat{L}_{it}, \quad \text{ATT} = \text{mean}(\hat{\tau}_{it}) + +The global method is **faster** (single optimization vs N_treated optimizations). +Treatment effects are **heterogeneous** per-observation residuals; ATT is their mean. -The joint method is **faster** (single optimization vs N_treated optimizations) -but assumes **homogeneous treatment effects** across all treated observations. +``method='joint'`` is a deprecated alias for ``method='global'`` and will be +removed in a future version. .. list-table:: :header-rows: 1 @@ -146,13 +152,13 @@ but assumes **homogeneous treatment effects** across all treated observations. * - Feature - Two-Step (default) - - Joint + - Global * - Treatment effect - - Per-observation τ_{it} - - Single scalar τ - * - Flexibility - - Heterogeneous effects - - Homogeneous assumption + - Per-observation τ_{it} (per-obs models) + - Per-observation τ_{it} (single model) + * - Fitting + - N_treated models with tailored weights + - One model with global weights * - Speed - Slower (N_treated fits) - Faster (single fit) @@ -160,8 +166,8 @@ but assumes **homogeneous treatment effects** across all treated observations. - Observation-specific - Global (center of treated block) -Use ``method='twostep'`` when treatment effects may vary across observations. -Use ``method='joint'`` for faster estimation when effects are expected to be homogeneous. +Use ``method='twostep'`` for observation-specific weight optimization. +Use ``method='global'`` for faster estimation with global weights. Example Usage ------------- diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 2046286..48cca33 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1175,48 +1175,58 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² - [x] D matrix semantics documented (absorbing state, not event indicator) - [x] Unbalanced panels supported (missing observations don't trigger false violations) -### TROP Joint Optimization Method +### TROP Global Estimation Method -**Method**: `method="joint"` in TROP estimator +**Method**: `method="global"` in TROP estimator (`method="joint"` is a deprecated alias) -**Approach**: Joint weighted least squares with optional nuclear norm penalty. -Estimates fixed effects, factor matrix, and scalar treatment effect simultaneously. +**Approach**: Global weighted least squares on control data with (1-W) masking, +followed by post-hoc treatment effect extraction. Per paper Eq. 2. -**Objective function** (Equation J1): +**Objective function** (Equation G1): ``` -min_{μ, α, β, L, τ} Σ_{i,t} δ_{it} × (Y_{it} - μ - α_i - β_t - L_{it} - W_{it}×τ)² + λ_nn×||L||_* +min_{μ, α, β, L} Σ_{i,t} (1-W_{it}) × δ_{it} × (Y_{it} - μ - α_i - β_t - L_{it})² + λ_nn×||L||_* ``` where: +- (1-W_{it}) masks out treated observations — model is fit on control data only - δ_{it} = δ_time(t) × δ_unit(i) are observation weights (product of time and unit weights) - μ is the intercept - α_i are unit fixed effects - β_t are time fixed effects - L_{it} is the low-rank factor component -- τ is a **single scalar** (homogeneous treatment effect assumption) -- W_{it} is the treatment indicator + +**Post-hoc treatment effect extraction**: +``` +τ̂_{it} = Y_{it} - μ̂ - α̂_i - β̂_t - L̂_{it} for all (i,t) where W_{it} = 1 +ATT = mean(τ̂_{it}) over all treated observations +``` + +Treatment effects are **heterogeneous** per-observation values. ATT is their mean. **Weight computation** (differs from twostep): - Time weights: δ_time(t) = exp(-λ_time × |t - center|) where center = T - treated_periods/2 - Unit weights: δ_unit(i) = exp(-λ_unit × RMSE(i, treated_avg)) where RMSE is computed over pre-treatment periods comparing to average treated trajectory +- (1-W) masking applied after outer product: δ_{it} = 0 for all treated cells **Implementation approach** (without CVXPY): 1. **Without low-rank (λ_nn = ∞)**: Standard weighted least squares - - Build design matrix with unit/time dummies + treatment indicator - - Solve via np.linalg.lstsq for (μ, α, β, τ) + - Build design matrix with unit/time dummies (no treatment indicator) + - Solve via np.linalg.lstsq for (μ, α, β) using (1-W)-masked weights 2. **With low-rank (finite λ_nn)**: Alternating minimization - Alternate between: - - Fix L, solve weighted LS for (μ, α, β, τ) - - Fix (μ, α, β, τ), proximal gradient for L: + - Fix L, solve weighted LS for (μ, α, β) + - Fix (μ, α, β), proximal gradient for L: - Lipschitz constant of ∇f is L_f = 2·max(δ) - Step size η = 1/L_f = 1/(2·max(δ)) - Proximal operator: soft_threshold(gradient_step, η·λ_nn) - Twostep inner solver uses FISTA/Nesterov acceleration (O(1/k²)) - Continue until max(|L_new - L_old|) < tol +3. **Post-hoc**: Extract τ̂_{it} = Y_{it} - μ̂ - α̂_i - β̂_t - L̂_{it} for treated cells + **LOOCV parameter selection** (unified with twostep, Equation 5): Following paper's Equation 5 and footnote 2: ``` @@ -1225,10 +1235,10 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² where τ̂_js^loocv is the pseudo-treatment effect at control observation (j,s) with that observation excluded from fitting. -For joint method, LOOCV works as follows: +For global method, LOOCV works as follows: 1. For each control observation (t, i): - Zero out weight δ_{ti} = 0 (exclude from weighted objective) - - Fit joint model on remaining data → obtain (μ̂, α̂, β̂, L̂) + - Fit global model on remaining data → obtain (μ̂, α̂, β̂, L̂) - Compute pseudo-treatment: τ̂_{ti} = Y_{ti} - μ̂ - α̂_i - β̂_t - L̂_{ti} 2. Score = Σ τ̂_{ti}² (sum of squared pseudo-treatment effects) 3. Select λ combination that minimizes Q(λ) @@ -1238,13 +1248,15 @@ For joint method, LOOCV works as follows: - `bootstrap_trop_variance_joint()` - Parallel bootstrap variance estimation **Key differences from twostep method**: -- Treatment effect τ is a single scalar (homogeneous assumption) vs. per-observation τ_{it} - Global weights (distance to treated block center) vs. per-observation weights - Single model fit per λ combination vs. N_treated fits +- Treatment effects are post-hoc residuals from a single global model (global) + vs. post-hoc residuals from per-observation models (twostep) +- Both use (1-W) masking (control-only fitting) - Faster computation for large panels **Assumptions**: -- **Simultaneous adoption (enforced)**: The joint method requires all treated units +- **Simultaneous adoption (enforced)**: The global method requires all treated units to receive treatment at the same time. A `ValueError` is raised if staggered adoption is detected (units first treated at different periods). Treatment timing is inferred once and held constant for bootstrap variance estimation. @@ -1255,9 +1267,9 @@ For joint method, LOOCV works as follows: **Requirements checklist:** - [x] Same LOOCV framework as twostep (Equation 5) - [x] Global weight computation using treated block center -- [x] Weighted least squares with treatment indicator +- [x] (1-W) masking for control-only fitting (per paper Eq. 2) - [x] Alternating minimization for nuclear norm penalty -- [x] Returns scalar τ (homogeneous treatment effect) +- [x] Returns ATT = mean of per-observation post-hoc τ̂_{it} - [x] Rust acceleration for LOOCV and bootstrap --- diff --git a/docs/tutorials/10_trop.ipynb b/docs/tutorials/10_trop.ipynb index ffa52b8..8844660 100644 --- a/docs/tutorials/10_trop.ipynb +++ b/docs/tutorials/10_trop.ipynb @@ -598,14 +598,14 @@ }, { "cell_type": "code", - "source": "# Compare estimation methods\nprint(\"Estimation method comparison:\")\nprint(\"=\"*60)\n\nimport time\n\n# Two-step method (default)\nstart = time.time()\ntrop_twostep = TROP(\n method='twostep',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_twostep = trop_twostep.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\ntwostep_time = time.time() - start\n\n# Joint method\nstart = time.time()\ntrop_joint = TROP(\n method='joint',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_joint = trop_joint.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\njoint_time = time.time() - start\n\nprint(f\"\\n{'Method':<15} {'ATT':>10} {'SE':>10} {'Time (s)':>12}\")\nprint(\"-\"*60)\nprint(f\"{'Two-step':<15} {results_twostep.att:>10.4f} {results_twostep.se:>10.4f} {twostep_time:>12.2f}\")\nprint(f\"{'Joint':<15} {results_joint.att:>10.4f} {results_joint.se:>10.4f} {joint_time:>12.2f}\")\nprint(f\"\\nTrue ATT: {true_att}\")\nprint(f\"Two-step bias: {results_twostep.att - true_att:.4f}\")\nprint(f\"Joint bias: {results_joint.att - true_att:.4f}\")", + "source": "# Compare estimation methods\nprint(\"Estimation method comparison:\")\nprint(\"=\"*60)\n\nimport time\n\n# Two-step method (default)\nstart = time.time()\ntrop_twostep = TROP(\n method='twostep',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_twostep = trop_twostep.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\ntwostep_time = time.time() - start\n\n# Global method\nstart = time.time()\ntrop_global = TROP(\n method='global',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_global = trop_global.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\nglobal_time = time.time() - start\n\nprint(f\"\\n{'Method':<15} {'ATT':>10} {'SE':>10} {'Time (s)':>12}\")\nprint(\"-\"*60)\nprint(f\"{'Two-step':<15} {results_twostep.att:>10.4f} {results_twostep.se:>10.4f} {twostep_time:>12.2f}\")\nprint(f\"{'Global':<15} {results_global.att:>10.4f} {results_global.se:>10.4f} {global_time:>12.2f}\")\nprint(f\"\\nTrue ATT: {true_att}\")\nprint(f\"Two-step bias: {results_twostep.att - true_att:.4f}\")\nprint(f\"Global bias: {results_global.att - true_att:.4f}\")", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", - "source": "## 10. Estimation Methods: Two-Step vs Joint\n\nTROP supports two estimation methods via the `method` parameter:\n\n**Two-Step Method** (`method='twostep'`, default):\n- Follows Algorithm 2 from the paper\n- Computes observation-specific weights for each treated observation\n- Fits a model per treated observation, then averages the individual effects\n- More flexible, allows for heterogeneous treatment effects\n- Computationally intensive (N_treated optimizations)\n\n**Joint Method** (`method='joint'`):\n- Weighted least squares with a single scalar treatment effect τ\n- Weights computed once (distance to center of treated block)\n- With low-rank: uses alternating minimization between weighted LS and soft-threshold SVD\n- Faster but assumes homogeneous treatment effects", + "source": "## 10. Estimation Methods: Two-Step vs Global\n\nTROP supports two estimation methods via the `method` parameter:\n\n**Two-Step Method** (`method='twostep'`, default):\n- Follows Algorithm 2 from the paper\n- Computes observation-specific weights for each treated observation\n- Fits a model per treated observation, then averages the individual effects\n- More flexible, allows for heterogeneous treatment effects\n- Computationally intensive (N_treated optimizations)\n\n**Global Method** (`method='global'`):\n- Fits a single model on control data using (1-W) masked weights (per paper Eq. 2)\n- Extracts per-observation treatment effects as post-hoc residuals: τ_it = Y_it - μ - α_i - β_t - L_it\n- ATT = mean(τ_it) over treated observations\n- Faster (single optimization) with global weights\n\nNote: `method='joint'` is a deprecated alias for `method='global'`.", "metadata": {} }, { @@ -638,7 +638,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "## Summary\n\nKey takeaways for TROP:\n\n1. **Best use cases**: Factor confounding, unobserved time-varying confounders with interactive effects\n2. **Factor estimation**: Nuclear norm regularization with LOOCV for tuning\n3. **Three tuning parameters**: λ_time, λ_unit, λ_nn selected automatically via LOOCV\n4. **Unit weights**: Exponential distance-based weighting of control units, where distance is computed as RMS outcome difference on control periods excluding the target period\n5. **Time weights**: Exponential decay weighting of pre-treatment periods\n6. **Weights**: Importance weights controlling relative contribution of observations (higher = more relevant)\n7. **Estimation methods**:\n - `method='twostep'` (default): Per-observation estimation, allows heterogeneous effects\n - `method='joint'`: Single scalar treatment effect, faster but assumes homogeneity\n\n**When to use TROP vs SDID**:\n- Use **SDID** when parallel trends is plausible and factors are not a concern\n- Use **TROP** when you suspect factor confounding (regional shocks, economic cycles, latent factors)\n- Running both provides a useful robustness check\n\n**When to use twostep vs joint method**:\n- Use **twostep** (default) for maximum flexibility and heterogeneous treatment effects\n- Use **joint** for faster estimation when effects are expected to be homogeneous\n\n**Reference**:\n- Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536" + "source": "## Summary\n\nKey takeaways for TROP:\n\n1. **Best use cases**: Factor confounding, unobserved time-varying confounders with interactive effects\n2. **Factor estimation**: Nuclear norm regularization with LOOCV for tuning\n3. **Three tuning parameters**: λ_time, λ_unit, λ_nn selected automatically via LOOCV\n4. **Unit weights**: Exponential distance-based weighting of control units, where distance is computed as RMS outcome difference on control periods excluding the target period\n5. **Time weights**: Exponential decay weighting of pre-treatment periods\n6. **Weights**: Importance weights controlling relative contribution of observations (higher = more relevant)\n7. **Estimation methods**:\n - `method='twostep'` (default): Per-observation estimation, allows heterogeneous effects\n - `method='global'`: Single model with (1-W) masking, post-hoc heterogeneous effects, faster\n\n**When to use TROP vs SDID**:\n- Use **SDID** when parallel trends is plausible and factors are not a concern\n- Use **TROP** when you suspect factor confounding (regional shocks, economic cycles, latent factors)\n- Running both provides a useful robustness check\n\n**When to use twostep vs global method**:\n- Use **twostep** (default) for maximum flexibility with per-observation weights\n- Use **global** for faster estimation with global weights\n\n**Reference**:\n- Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536" }, { "cell_type": "code", diff --git a/rust/src/trop.rs b/rust/src/trop.rs index b503f4c..19714cc 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -1186,27 +1186,35 @@ fn compute_joint_weights( } } + // (1-W) masking: zero out treated observations per paper Eq. 2 + for t in 0..n_periods { + for i in 0..n_units { + delta[[t, i]] *= 1.0 - d[[t, i]]; + } + } + delta } -/// Solve joint TWFE + treatment via weighted least squares (no low-rank). +/// Solve joint TWFE via weighted least squares (no low-rank, no tau). +/// +/// Minimizes: min Σ δ_{it}(Y_{it} - μ - α_i - β_t)² /// -/// Minimizes: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - τ*W_{it})² +/// tau is extracted post-hoc by the caller (ATT = mean residual over treated). /// /// # Returns -/// (mu, alpha, beta, tau) estimated parameters +/// (mu, alpha, beta) estimated parameters fn solve_joint_no_lowrank( y: &ArrayView2, - d: &ArrayView2, delta: &ArrayView2, -) -> Option<(f64, Array1, Array1, f64)> { +) -> Option<(f64, Array1, Array1)> { let n_periods = y.nrows(); let n_units = y.ncols(); // We solve using normal equations with the design matrix structure // Rather than build full X matrix, use block structure for efficiency // - // The model: Y_it = μ + α_i + β_t + τ*D_it + ε_it + // The model: Y_it = μ + α_i + β_t + ε_it // With identification: α_0 = β_0 = 0 // Compute weighted sums needed for normal equations @@ -1239,18 +1247,16 @@ fn solve_joint_no_lowrank( return None; } - // Use iterative approach: alternate between (alpha, beta, tau) and mu + // Use iterative approach: alternate between (alpha, beta) and mu // until convergence (simpler than full normal equations) let mut mu = sum_wy / sum_w; let mut alpha = Array1::::zeros(n_units); let mut beta = Array1::::zeros(n_periods); - let mut tau = 0.0; for _ in 0..50 { let mu_old = mu; - let tau_old = tau; - // Update alpha (fixing beta, tau, mu) + // Update alpha (fixing beta, mu) for i in 1..n_units { // α_0 = 0 for identification if sum_w_by_unit[i] > 1e-10 { let mut num = 0.0; @@ -1258,13 +1264,13 @@ fn solve_joint_no_lowrank( // NaN outcomes get zero weight let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 }; let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 }; - num += w * (y_ti - mu - beta[t] - tau * d[[t, i]]); + num += w * (y_ti - mu - beta[t]); } alpha[i] = num / sum_w_by_unit[i]; } } - // Update beta (fixing alpha, tau, mu) + // Update beta (fixing alpha, mu) for t in 1..n_periods { // β_0 = 0 for identification if sum_w_by_period[t] > 1e-10 { let mut num = 0.0; @@ -1272,133 +1278,138 @@ fn solve_joint_no_lowrank( // NaN outcomes get zero weight let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 }; let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 }; - num += w * (y_ti - mu - alpha[i] - tau * d[[t, i]]); + num += w * (y_ti - mu - alpha[i]); } beta[t] = num / sum_w_by_period[t]; } } - // Update tau (fixing alpha, beta, mu) - let mut num_tau = 0.0; - let mut denom_tau = 0.0; - for t in 0..n_periods { - for i in 0..n_units { - // NaN outcomes get zero weight - let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 }; - let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 }; - let d_ti = d[[t, i]]; - if d_ti > 0.5 { // Only treated observations contribute - num_tau += w * d_ti * (y_ti - mu - alpha[i] - beta[t]); - denom_tau += w * d_ti * d_ti; - } - } - } - if denom_tau > 1e-10 { - tau = num_tau / denom_tau; - } - - // Update mu (fixing alpha, beta, tau) + // Update mu (fixing alpha, beta) let mut num_mu = 0.0; for t in 0..n_periods { for i in 0..n_units { // NaN outcomes get zero weight let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 }; let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 }; - num_mu += w * (y_ti - alpha[i] - beta[t] - tau * d[[t, i]]); + num_mu += w * (y_ti - alpha[i] - beta[t]); } } mu = num_mu / sum_w; // Check convergence - if (mu - mu_old).abs() < 1e-8 && (tau - tau_old).abs() < 1e-8 { + if (mu - mu_old).abs() < 1e-8 { break; } } - Some((mu, alpha, beta, tau)) + Some((mu, alpha, beta)) } -/// Solve joint TWFE + treatment + low-rank via alternating minimization. +/// Solve joint TWFE + low-rank via alternating minimization (no tau). +/// +/// Minimizes: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - L_{it})² + λ_nn||L||_* /// -/// Minimizes: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - L_{it} - τ*W_{it})² + λ_nn||L||_* +/// tau is extracted post-hoc by the caller (ATT = mean residual over treated). /// /// # Returns -/// (mu, alpha, beta, L, tau) estimated parameters +/// (mu, alpha, beta, L) estimated parameters fn solve_joint_with_lowrank( y: &ArrayView2, - d: &ArrayView2, delta: &ArrayView2, lambda_nn: f64, max_iter: usize, tol: f64, -) -> Option<(f64, Array1, Array1, Array2, f64)> { +) -> Option<(f64, Array1, Array1, Array2)> { let n_periods = y.nrows(); let n_units = y.ncols(); + // Precompute normalized weights and threshold (constant across iterations) + let delta_max = delta.iter().cloned().fold(0.0_f64, f64::max); + let threshold = if delta_max > 0.0 { lambda_nn / (2.0 * delta_max) } else { lambda_nn / 2.0 }; + + // Precompute delta_norm (masked for NaN outcomes) + let mut delta_norm = Array2::::zeros((n_periods, n_units)); + for t in 0..n_periods { + for i in 0..n_units { + let d_ti = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 }; + delta_norm[[t, i]] = if delta_max > 0.0 { d_ti / delta_max } else { d_ti }; + } + } + // Initialize L = 0 let mut l = Array2::::zeros((n_periods, n_units)); for _ in 0..max_iter { let l_old = l.clone(); - // Step 1: Fix L, solve for (mu, alpha, beta, tau) - // Adjusted outcome: Y - L (preserve NaN so solve_joint_no_lowrank masks weights) + // Step 1: Fix L, solve for (mu, alpha, beta) let y_adj = Array2::from_shape_fn((n_periods, n_units), |(t, i)| { y[[t, i]] - l[[t, i]] // NaN - finite = NaN (preserves NaN info) }); + let (mu, alpha, beta) = solve_joint_no_lowrank(&y_adj.view(), delta)?; - let (mu, alpha, beta, tau) = solve_joint_no_lowrank(&y_adj.view(), d, delta)?; - - // Step 2: Fix (mu, alpha, beta, tau), update L - // Residual: R = Y - mu - alpha - beta - tau*D (preserve NaN) - let mut r = Array2::::zeros((n_periods, n_units)); + // Step 2: Fix (mu, alpha, beta), update L with FISTA acceleration + // Residual: R = Y - mu - alpha - beta + // For delta=0 observations (treated/NaN), keep L rather than R + let mut r_masked = Array2::::zeros((n_periods, n_units)); for t in 0..n_periods { for i in 0..n_units { - // NaN - finite = NaN (will be masked in gradient step) - r[[t, i]] = y[[t, i]] - mu - alpha[i] - beta[t] - tau * d[[t, i]]; + if delta_norm[[t, i]] > 0.0 && y[[t, i]].is_finite() { + r_masked[[t, i]] = y[[t, i]] - mu - alpha[i] - beta[t]; + } else { + r_masked[[t, i]] = l[[t, i]]; + } } } - // Weighted proximal step for L (soft-threshold SVD) - // Lipschitz constant of ∇f is L_f = 2·max(δ), step size η = 1/L_f - let delta_max = delta.iter().cloned().fold(0.0_f64, f64::max); - let eta = if delta_max > 0.0 { 1.0 / (2.0 * delta_max) } else { 0.5 }; + // Inner FISTA loop for L update + let mut l_inner = l.clone(); + let mut l_inner_prev = l_inner.clone(); + let mut t_fista = 1.0_f64; - // gradient_step = L + eta * delta * (R - L) - // NaN outcomes get zero weight so they don't affect gradient - let mut gradient_step = Array2::::zeros((n_periods, n_units)); - for t in 0..n_periods { - for i in 0..n_units { - // Mask delta for NaN outcomes - let delta_ti = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 }; - let delta_norm = if delta_max > 0.0 { - delta_ti / delta_max - } else { - delta_ti - }; - // r[[t,i]] may be NaN, but delta_norm=0 for NaN obs, so contribution=0 - let r_contrib = if r[[t, i]].is_finite() { r[[t, i]] } else { 0.0 }; - gradient_step[[t, i]] = l[[t, i]] + delta_norm * (r_contrib - l[[t, i]]); + for _ in 0..20 { + // FISTA momentum + let t_fista_new = (1.0 + (1.0 + 4.0 * t_fista * t_fista).sqrt()) / 2.0; + let momentum = (t_fista - 1.0) / t_fista_new; + + // Gradient step from momentum point + let mut gradient_step = Array2::::zeros((n_periods, n_units)); + for t in 0..n_periods { + for i in 0..n_units { + let l_mom = l_inner[[t, i]] + momentum * (l_inner[[t, i]] - l_inner_prev[[t, i]]); + gradient_step[[t, i]] = l_mom + delta_norm[[t, i]] * (r_masked[[t, i]] - l_mom); + } + } + + // Proximal step: soft-threshold singular values + // l_inner_prev holds pre-SVD value for both momentum and convergence check + l_inner_prev = l_inner; + l_inner = soft_threshold_svd(&gradient_step, threshold)?; + t_fista = t_fista_new; + + // Convergence check + let inner_diff = max_abs_diff_2d(&l_inner, &l_inner_prev); + if inner_diff < tol { + break; } } - // Soft-threshold singular values - l = soft_threshold_svd(&gradient_step, eta * lambda_nn)?; + l = l_inner; - // Check convergence + // Outer convergence check let l_diff = max_abs_diff_2d(&l, &l_old); if l_diff < tol { break; } } - // Final solve with converged L (preserve NaN so solve_joint_no_lowrank masks weights) + // Final solve with converged L let y_adj = Array2::from_shape_fn((n_periods, n_units), |(t, i)| { - y[[t, i]] - l[[t, i]] // NaN - finite = NaN (preserves NaN info) + y[[t, i]] - l[[t, i]] }); - let (mu, alpha, beta, tau) = solve_joint_no_lowrank(&y_adj.view(), d, delta)?; + let (mu, alpha, beta) = solve_joint_no_lowrank(&y_adj.view(), delta)?; - Some((mu, alpha, beta, l, tau)) + Some((mu, alpha, beta, l)) } /// Compute LOOCV score for joint method with specific parameter combination. @@ -1442,17 +1453,17 @@ fn loocv_score_joint( delta_ex[[t_ex, i_ex]] = 0.0; let result = if lambda_nn >= 1e10 { - solve_joint_no_lowrank(y, d, &delta_ex.view()) - .map(|(mu, alpha, beta, tau)| { + solve_joint_no_lowrank(y, &delta_ex.view()) + .map(|(mu, alpha, beta)| { let l = Array2::::zeros((n_periods, n_units)); - (mu, alpha, beta, l, tau) + (mu, alpha, beta, l) }) } else { - solve_joint_with_lowrank(y, d, &delta_ex.view(), lambda_nn, max_iter, tol) + solve_joint_with_lowrank(y, &delta_ex.view(), lambda_nn, max_iter, tol) }; match result { - Some((mu, alpha, beta, l, _tau)) => { + Some((mu, alpha, beta, l)) => { if y[[t_ex, i_ex]].is_finite() { let tau_loocv = y[[t_ex, i_ex]] - mu - alpha[i_ex] - beta[t_ex] - l[[t_ex, i_ex]]; (sum + tau_loocv * tau_loocv, valid + 1, first_fail) @@ -1570,16 +1581,14 @@ pub fn loocv_grid_search_joint<'py>( .into_par_iter() .map(|(lt, lu, ln)| { // Convert λ_nn=∞ → 1e10 (factor model disabled) - let lt_eff = lt; - let lu_eff = lu; let ln_eff = if ln.is_infinite() { 1e10 } else { ln }; let (score, n_valid, first_failed) = loocv_score_joint( &y_arr, &d_arr, &control_obs, - lt_eff, - lu_eff, + lt, + lu, ln_eff, treated_periods, max_iter, @@ -1677,8 +1686,6 @@ pub fn bootstrap_trop_variance_joint<'py>( let treated_periods = n_periods.saturating_sub(first_treat_period); // Convert λ_nn=∞ → 1e10 (factor model disabled) - let lt_eff = lambda_time; - let lu_eff = lambda_unit; let ln_eff = if lambda_nn.is_infinite() { 1e10 } else { lambda_nn }; // Run bootstrap iterations in parallel @@ -1724,27 +1731,45 @@ pub fn bootstrap_trop_variance_joint<'py>( let delta = compute_joint_weights( &y_boot.view(), &d_boot.view(), - lt_eff, - lu_eff, + lambda_time, + lambda_unit, treated_periods, ); let result = if ln_eff >= 1e10 { - solve_joint_no_lowrank(&y_boot.view(), &d_boot.view(), &delta.view()) - .map(|(_, _, _, tau)| tau) + solve_joint_no_lowrank(&y_boot.view(), &delta.view()) + .map(|(mu, alpha, beta)| { + let l = Array2::::zeros((n_periods, n_units)); + (mu, alpha, beta, l) + }) } else { solve_joint_with_lowrank( &y_boot.view(), - &d_boot.view(), &delta.view(), ln_eff, max_iter, tol, ) - .map(|(_, _, _, _, tau)| tau) }; - result + // Post-hoc tau extraction: ATT = mean(Y - mu - alpha - beta - L) over treated + result.and_then(|(mu, alpha, beta, l)| { + let mut tau_sum = 0.0; + let mut tau_count = 0; + for t in 0..n_periods { + for i in 0..n_units { + if d_boot[[t, i]] == 1.0 && y_boot[[t, i]].is_finite() { + tau_sum += y_boot[[t, i]] - mu - alpha[i] - beta[t] - l[[t, i]]; + tau_count += 1; + } + } + } + if tau_count > 0 { + Some(tau_sum / tau_count as f64) + } else { + None + } + }) }) .collect(); diff --git a/tests/test_trop.py b/tests/test_trop.py index 726c2e6..a5fc726 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -2846,18 +2846,32 @@ def test_method_parameter_validation(self): def test_method_in_get_params(self): """method parameter appears in get_params().""" - trop_est = TROP(method="joint") + trop_est = TROP(method="global") params = trop_est.get_params() assert "method" in params - assert params["method"] == "joint" + assert params["method"] == "global" + + def test_method_in_get_params_joint_deprecated(self): + """'joint' alias maps to 'global' in get_params().""" + with pytest.warns(FutureWarning, match="deprecated"): + trop_est = TROP(method="joint") + params = trop_est.get_params() + assert params["method"] == "global" def test_method_in_set_params(self): """method parameter can be set via set_params().""" trop_est = TROP(method="twostep") assert trop_est.method == "twostep" - trop_est.set_params(method="joint") - assert trop_est.method == "joint" + trop_est.set_params(method="global") + assert trop_est.method == "global" + + def test_method_set_params_joint_deprecated(self): + """'joint' alias maps to 'global' via set_params().""" + trop_est = TROP(method="twostep") + with pytest.warns(FutureWarning, match="deprecated"): + trop_est.set_params(method="joint") + assert trop_est.method == "global" def test_joint_bootstrap_variance(self, simple_panel_data, ci_params): """Joint method bootstrap variance estimation works.""" @@ -3212,9 +3226,9 @@ def test_joint_treated_pre_nan_handling(self, simple_panel_data): assert np.isfinite(results.se), f"SE should be finite, got {results.se}" def test_joint_rejects_staggered_adoption(self): - """Joint method raises ValueError for staggered adoption data. + """Global method raises ValueError for staggered adoption data. - The joint method assumes all treated units receive treatment at the + The global method assumes all treated units receive treatment at the same time. With staggered adoption (units first treated at different periods), the method's weights and variance estimation are invalid. """ @@ -3235,7 +3249,178 @@ def test_joint_rejects_staggered_adoption(self): }) df = pd.DataFrame(data) - trop = TROP(method="joint") + trop = TROP(method="global") with pytest.raises(ValueError, match="staggered adoption"): trop.fit(df, 'outcome', 'treated', 'unit', 'time') + def test_global_method_alias(self, simple_panel_data): + """method='global' works and produces same results as deprecated 'joint'.""" + trop_est = TROP( + method="global", + lambda_time_grid=[0.0, 1.0], + lambda_unit_grid=[0.0, 1.0], + lambda_nn_grid=[0.0, 0.1], + n_bootstrap=10, + seed=42, + ) + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + assert isinstance(results, TROPResults) + assert results.att > 0 + + def test_global_uses_control_only_weights(self, simple_panel_data): + """Verify delta[t,i] == 0 for all D[t,i] == 1 (control-only weights).""" + trop_est = TROP( + method="global", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.0], + seed=42, + ) + + # Setup data matrices + all_units = sorted(simple_panel_data['unit'].unique()) + all_periods = sorted(simple_panel_data['period'].unique()) + n_units = len(all_units) + n_periods = len(all_periods) + + Y = ( + simple_panel_data.pivot(index='period', columns='unit', values='outcome') + .reindex(index=all_periods, columns=all_units) + .values + ) + D = ( + simple_panel_data.pivot(index='period', columns='unit', values='treated') + .reindex(index=all_periods, columns=all_units) + .fillna(0) + .astype(int) + .values + ) + + treated_periods = np.sum(np.any(D == 1, axis=1)) + + delta = trop_est._compute_joint_weights( + Y, D, 1.0, 1.0, int(treated_periods), n_units, n_periods + ) + + # All treated cells should have zero weight + assert np.all(delta[D == 1] == 0.0), ( + "Treated observations should have zero weight after (1-W) masking" + ) + # Some control cells should have non-zero weight + assert np.any(delta[D == 0] > 0.0), ( + "Some control observations should have positive weight" + ) + + def test_global_tau_is_posthoc_residual(self, simple_panel_data): + """Verify ATT == mean(Y - mu - alpha - beta - L) over treated cells.""" + trop_est = TROP( + method="global", + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[0.1], + n_bootstrap=10, + seed=42, + ) + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # Reconstruct tau from treatment_effects + tau_values = [v for v in results.treatment_effects.values() if np.isfinite(v)] + assert len(tau_values) > 0, "Should have treatment effects" + reconstructed_att = np.mean(tau_values) + assert np.isclose(results.att, reconstructed_att, atol=1e-10), ( + f"ATT ({results.att}) should equal mean of treatment effects ({reconstructed_att})" + ) + + def test_global_heterogeneous_treatment_effects(self, simple_panel_data): + """Treatment effects are heterogeneous (not all identical) with global method.""" + trop_est = TROP( + method="global", + lambda_time_grid=[0.0], + lambda_unit_grid=[0.0], + lambda_nn_grid=[float('inf')], + n_bootstrap=10, + seed=42, + ) + results = trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + te_values = list(results.treatment_effects.values()) + # With post-hoc extraction, effects should vary across observations + assert len(set(te_values)) > 1, ( + "Treatment effects should be heterogeneous with post-hoc extraction" + ) + + def test_global_treated_outcome_does_not_affect_fit(self, simple_panel_data): + """Perturbing treated outcomes should not change (mu, alpha, beta, L).""" + all_units = sorted(simple_panel_data['unit'].unique()) + all_periods = sorted(simple_panel_data['period'].unique()) + n_units = len(all_units) + n_periods = len(all_periods) + + Y = ( + simple_panel_data.pivot(index='period', columns='unit', values='outcome') + .reindex(index=all_periods, columns=all_units) + .values + ) + D = ( + simple_panel_data.pivot(index='period', columns='unit', values='treated') + .reindex(index=all_periods, columns=all_units) + .fillna(0) + .astype(int) + .values + ) + + treated_periods = int(np.sum(np.any(D == 1, axis=1))) + + trop_est = TROP( + method="global", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.1], + seed=42, + ) + + # Compute weights and fit with original data + delta = trop_est._compute_joint_weights( + Y, D, 1.0, 1.0, treated_periods, n_units, n_periods + ) + mu1, alpha1, beta1, L1 = trop_est._solve_joint_with_lowrank( + Y, delta, 0.1, 100, 1e-6 + ) + + # Perturb treated outcomes by large amount + Y_perturbed = Y.copy() + Y_perturbed[D == 1] += 1000.0 + + # Recompute (same weights since (1-W) zeroes treated cells) + delta2 = trop_est._compute_joint_weights( + Y_perturbed, D, 1.0, 1.0, treated_periods, n_units, n_periods + ) + mu2, alpha2, beta2, L2 = trop_est._solve_joint_with_lowrank( + Y_perturbed, delta2, 0.1, 100, 1e-6 + ) + + # Model parameters should be identical + assert np.isclose(mu1, mu2, atol=1e-8), f"mu changed: {mu1} vs {mu2}" + assert np.allclose(alpha1, alpha2, atol=1e-8), "alpha changed" + assert np.allclose(beta1, beta2, atol=1e-8), "beta changed" + assert np.allclose(L1, L2, atol=1e-8), "L changed" + From ab913f1ba58543bb341d783b01895e92c0826d1c Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 8 Mar 2026 19:11:01 -0400 Subject: [PATCH 5/7] Fix stale coefficients in global low-rank solver and NaN bootstrap poisoning - Add final re-solve after outer loop convergence in _solve_joint_with_lowrank to ensure mu/alpha/beta are consistent with converged L (matches Rust) - Filter NaN ATT draws in bootstrap fallback with np.isfinite check - Clarify global method docs as adaptation of Eq. 2 masking principle, not paper-exact Algorithm 2 - Update FISTA documentation to reflect both solvers now use acceleration Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/trop.py | 19 ++++++++++++------- docs/api/trop.rst | 8 +++++--- docs/methodology/REGISTRY.md | 11 +++++++---- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/diff_diff/trop.py b/diff_diff/trop.py index 656daa8..b20815f 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -71,13 +71,13 @@ class TROP: a model for each treated observation, averaging the individual treatment effects. More flexible but computationally intensive. - - 'global': Global weighted least squares with post-hoc treatment - effect extraction. Fits a single model on control observations - using (1-W) masked weights per paper Eq. 2, then computes - per-observation treatment effects as residuals: + - 'global': Computationally efficient adaptation using the (1-W) + masking principle from Eq. 2. Fits a single model on control + observations with global weights, then computes per-observation + treatment effects as residuals: tau_it = Y_it - mu - alpha_i - beta_t - L_it for treated cells. - ATT is the mean of these effects. Faster than twostep but uses - global weights instead of per-observation weights. + ATT is the mean of these effects. For the paper's full + per-treated-cell estimator, use ``method='twostep'``. - 'joint': Deprecated alias for 'global'. Will be removed in a future version. @@ -979,6 +979,10 @@ def _solve_joint_with_lowrank( if np.max(np.abs(L - L_old)) < tol: break + # Final re-solve with converged L (match Rust behavior) + Y_adj = Y_safe - L + mu, alpha, beta = self._solve_joint_no_lowrank(Y_adj, delta_masked) + return mu, alpha, beta, L def _fit_joint( @@ -1405,7 +1409,8 @@ def _bootstrap_variance_joint( boot_data, outcome, treatment, unit, time, optimal_lambda, treated_periods ) - bootstrap_estimates_list.append(tau) + if np.isfinite(tau): + bootstrap_estimates_list.append(tau) except (ValueError, np.linalg.LinAlgError, KeyError): continue diff --git a/docs/api/trop.rst b/docs/api/trop.rst index 65b6d9f..fe0dd1e 100644 --- a/docs/api/trop.rst +++ b/docs/api/trop.rst @@ -121,12 +121,14 @@ the estimator is consistent if any one of the three components **Global Method** (``method='global'``) -An alternative approach that fits a single model on control data and extracts -treatment effects as post-hoc residuals: +A computationally efficient adaptation using the ``(1-W)`` masking principle +from Eq. 2. Fits a single global model rather than per-treated-cell models. +For the paper's full per-treated-cell estimator (Algorithm 2), use +``method='twostep'``. 1. **Compute weights**: Distance-based unit and time weights computed once (distance to center of treated block, RMSE to average treated trajectory), - with ``(1-W)`` masking to zero out treated observations (per paper Eq. 2). + with ``(1-W)`` masking to zero out treated observations. 2. **Fit control model**: Solve weighted least squares on control data only diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 48cca33..fab3271 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1088,8 +1088,9 @@ Gradient step: G = L + (W/max(W)) ⊙ (R - L) Proximal step: L = U × soft_threshold(Σ, η·λ_nn) × V' (SVD of G = UΣV') ``` where R is the residual after removing fixed effects (and τ·D in joint mode). -The twostep solver's inner L update uses FISTA/Nesterov acceleration (O(1/k²) convergence); -the Python joint solver uses a single proximal gradient step per outer alternating iteration. +Both the twostep and global solvers use FISTA/Nesterov acceleration for the +inner L update (O(1/k²) convergence rate, up to 20 inner iterations per +outer alternating step). Per-observation weights (Equation 3): ``` @@ -1179,8 +1180,10 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² **Method**: `method="global"` in TROP estimator (`method="joint"` is a deprecated alias) -**Approach**: Global weighted least squares on control data with (1-W) masking, -followed by post-hoc treatment effect extraction. Per paper Eq. 2. +**Approach**: Computationally efficient adaptation using the (1-W) masking +principle from Eq. 2. Fits a single global model on control data, then +extracts treatment effects as post-hoc residuals. For the paper's full +per-treated-cell estimator (Algorithm 2), use `method='twostep'`. **Objective function** (Equation G1): ``` From 9d4c80caca3f1ef9489a2f6614b67f41af00c88f Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 14 Mar 2026 16:33:10 -0400 Subject: [PATCH 6/7] Fix Rust convergence criterion, n_valid_treated consistency, and NaN bootstrap SE MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rust solve_joint_no_lowrank: check convergence across all params (mu, alpha, beta), not just mu — fixes premature termination and Rust/Python divergence - Global + twostep: use n_valid_treated (finite outcomes only) for df_trop and results.n_treated_obs; skip NaN Y in twostep loop to prevent NaN poisoning - Return np.nan (not 0.0) SE when <2 bootstrap draws succeed (all 3 paths) - Update API docs: method='joint' example → method='global' - Fix stale FISTA reference in REGISTRY.md global section - Add edge case docs for partial/all-NaN treated outcomes Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/trop.py | 55 +++++++-- docs/api/trop.rst | 16 +-- docs/methodology/REGISTRY.md | 15 ++- rust/src/trop.rs | 14 ++- tests/test_rust_backend.py | 135 ++++++++++++++++++++++ tests/test_trop.py | 212 +++++++++++++++++++++++++++++++++++ 6 files changed, 426 insertions(+), 21 deletions(-) diff --git a/diff_diff/trop.py b/diff_diff/trop.py index b20815f..abb21f9 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -1234,6 +1234,20 @@ def _fit_joint( Y, D, mu, alpha, beta, L, idx_to_unit, idx_to_period ) + # Use count of valid (finite) treated outcomes for df and metadata + n_valid_treated = len(tau_values) + if n_valid_treated == 0: + warnings.warn( + "All treated outcomes are NaN/missing. Cannot estimate ATT.", + UserWarning, + ) + elif n_valid_treated < n_treated_obs: + warnings.warn( + f"Only {n_valid_treated} of {n_treated_obs} treated outcomes are finite. " + "df and n_treated_obs reflect valid observations only.", + UserWarning, + ) + # Compute effective rank of L _, s, _ = np.linalg.svd(L, full_matrices=False) if s[0] > 0: @@ -1250,7 +1264,7 @@ def _fit_joint( ) # Compute test statistics - df_trop = max(1, n_treated_obs - 1) + df_trop = max(1, n_valid_treated - 1) t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_trop) # Create results dictionaries @@ -1266,7 +1280,7 @@ def _fit_joint( n_obs=len(data), n_treated=len(treated_unit_idx), n_control=len(control_unit_idx), - n_treated_obs=int(n_treated_obs), + n_treated_obs=int(n_valid_treated), unit_effects=unit_effects_dict, time_effects=time_effects_dict, treatment_effects=treatment_effects, @@ -1358,7 +1372,7 @@ def _bootstrap_variance_joint( UserWarning ) if len(bootstrap_estimates) == 0: - return 0.0, np.array([]) + return np.nan, np.array([]) return float(se), np.array(bootstrap_estimates) @@ -1422,7 +1436,7 @@ def _bootstrap_variance_joint( UserWarning ) if len(bootstrap_estimates) == 0: - return 0.0, np.array([]) + return np.nan, np.array([]) se = np.std(bootstrap_estimates, ddof=1) return float(se), bootstrap_estimates @@ -1771,6 +1785,15 @@ def fit( treated_observations = self._precomputed["treated_observations"] for t, i in treated_observations: + unit_id = idx_to_unit[i] + time_id = idx_to_period[t] + + # Skip observations where outcome is missing — record NaN but + # don't fit the model or include in tau_values (avoids NaN poisoning) + if not np.isfinite(Y[t, i]): + treatment_effects[(unit_id, time_id)] = np.nan + continue + # Compute observation-specific weights for this (i, t) weight_matrix = self._compute_observation_weights( Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, @@ -1786,8 +1809,6 @@ def fit( # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it} tau_it = Y[t, i] - alpha_hat[i] - beta_hat[t] - L_hat[t, i] - unit_id = idx_to_unit[i] - time_id = idx_to_period[t] treatment_effects[(unit_id, time_id)] = tau_it tau_values.append(tau_it) @@ -1796,8 +1817,22 @@ def fit( beta_estimates.append(beta_hat) L_estimates.append(L_hat) + # Count valid treated observations + n_valid_treated = len(tau_values) + if n_valid_treated == 0: + warnings.warn( + "All treated outcomes are NaN/missing. Cannot estimate ATT.", + UserWarning, + ) + elif n_valid_treated < n_treated_obs: + warnings.warn( + f"Only {n_valid_treated} of {n_treated_obs} treated outcomes are finite. " + "df and n_treated_obs reflect valid observations only.", + UserWarning, + ) + # Average ATT - att = np.mean(tau_values) + att = np.mean(tau_values) if tau_values else np.nan # Average parameter estimates for output (representative) alpha_hat = np.mean(alpha_estimates, axis=0) if alpha_estimates else np.zeros(n_units) @@ -1820,7 +1855,7 @@ def fit( ) # Compute test statistics - df_trop = max(1, n_treated_obs - 1) + df_trop = max(1, n_valid_treated - 1) t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_trop) # Create results dictionaries @@ -1837,7 +1872,7 @@ def fit( n_obs=len(data), n_treated=len(treated_unit_idx), n_control=len(control_unit_idx), - n_treated_obs=n_treated_obs, + n_treated_obs=int(n_valid_treated), unit_effects=unit_effects_dict, time_effects=time_effects_dict, treatment_effects=treatment_effects, @@ -2528,7 +2563,7 @@ def _bootstrap_variance( UserWarning ) if len(bootstrap_estimates) == 0: - return 0.0, np.array([]) + return np.nan, np.array([]) se = np.std(bootstrap_estimates, ddof=1) return float(se), bootstrap_estimates diff --git a/docs/api/trop.rst b/docs/api/trop.rst index fe0dd1e..cfcdc8c 100644 --- a/docs/api/trop.rst +++ b/docs/api/trop.rst @@ -211,27 +211,27 @@ Quick estimation with convenience function:: n_bootstrap=200 ) -Using the joint method for faster estimation:: +Using the global method for faster estimation:: from diff_diff import TROP - # Joint method: single scalar treatment effect via weighted LS - trop_joint = TROP( - method='joint', # Use joint weighted least squares + # Global method: computationally efficient adaptation using (1-W) masking + trop_global = TROP( + method='global', lambda_time_grid=[0.0, 0.5, 1.0, 2.0], lambda_unit_grid=[0.0, 0.5, 1.0, 2.0], lambda_nn_grid=[0.0, 0.1, 1.0], n_bootstrap=200, seed=42 ) - results_joint = trop_joint.fit(data, outcome='y', treatment='treated', - unit='unit_id', time='period') + results_global = trop_global.fit(data, outcome='y', treatment='treated', + unit='unit_id', time='period') # Compare methods - trop_twostep = TROP(method='twostep', ...) # Default + trop_twostep = TROP(method='twostep', ...) # Default (per-observation) results_twostep = trop_twostep.fit(data, ...) print(f"Two-step ATT: {results_twostep.att:.3f}") - print(f"Joint ATT: {results_joint.att:.3f}") + print(f"Global ATT: {results_global.att:.3f}") Examining factor structure:: diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index fab3271..944cca8 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1225,7 +1225,7 @@ Treatment effects are **heterogeneous** per-observation values. ATT is their mea - Lipschitz constant of ∇f is L_f = 2·max(δ) - Step size η = 1/L_f = 1/(2·max(δ)) - Proximal operator: soft_threshold(gradient_step, η·λ_nn) - - Twostep inner solver uses FISTA/Nesterov acceleration (O(1/k²)) + - Inner solver uses FISTA/Nesterov acceleration (O(1/k²)) - Continue until max(|L_new - L_old|) < tol 3. **Post-hoc**: Extract τ̂_{it} = Y_{it} - μ̂ - α̂_i - β̂_t - L̂_{it} for treated cells @@ -1267,6 +1267,19 @@ For global method, LOOCV works as follows: **Reference**: Adapted from reference implementation. See also Athey et al. (2025). +**Edge Cases (treated NaN outcomes):** +- **Partial NaN**: When some treated outcomes Y_{it} are NaN/missing: + - `_extract_posthoc_tau()` (global) skips these cells; only finite τ̂ values are averaged + - Twostep loop skips NaN outcomes entirely (no model fit, no tau appended) + - `n_treated_obs` in results reflects valid (finite) count, not total D==1 count + - `df_trop = max(1, n_valid_treated - 1)` uses valid count + - Warning issued when n_valid_treated < total treated count +- **All NaN**: When all treated outcomes are NaN: + - ATT = NaN, warning issued + - `n_treated_obs = 0` +- **Bootstrap SE with <2 draws**: Returns `se=NaN` (not 0.0) when zero bootstrap + iterations succeed. `safe_inference()` propagates NaN downstream. + **Requirements checklist:** - [x] Same LOOCV framework as twostep (Equation 5) - [x] Global weight computation using treated block center diff --git a/rust/src/trop.rs b/rust/src/trop.rs index 19714cc..997bf36 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -1255,6 +1255,8 @@ fn solve_joint_no_lowrank( for _ in 0..50 { let mu_old = mu; + let alpha_old = alpha.clone(); + let beta_old = beta.clone(); // Update alpha (fixing beta, mu) for i in 1..n_units { // α_0 = 0 for identification @@ -1296,8 +1298,16 @@ fn solve_joint_no_lowrank( } mu = num_mu / sum_w; - // Check convergence - if (mu - mu_old).abs() < 1e-8 { + // Check convergence across ALL parameters (not just mu) + let mu_diff = (mu - mu_old).abs(); + let alpha_diff = alpha.iter().zip(alpha_old.iter()) + .map(|(a, b)| (a - b).abs()) + .fold(0.0_f64, f64::max); + let beta_diff = beta.iter().zip(beta_old.iter()) + .map(|(a, b)| (a - b).abs()) + .fold(0.0_f64, f64::max); + let max_diff = mu_diff.max(alpha_diff).max(beta_diff); + if max_diff < 1e-8 { break; } } diff --git a/tests/test_rust_backend.py b/tests/test_rust_backend.py index 0a86aee..ae2df18 100644 --- a/tests/test_rust_backend.py +++ b/tests/test_rust_backend.py @@ -1682,6 +1682,141 @@ def test_trop_joint_treated_pre_nan_rust_python_parity(self): f"Rust ATT ({results_rust.att:.3f}) and Python ATT ({results_python.att:.3f}) " \ f"differ by {att_diff:.3f}, should be < 0.5" + def test_trop_joint_solver_parity_no_lowrank(self): + """Test Rust/Python solver parity for no-lowrank path (lambda_nn >= 1e10). + + Both backends should produce matching (mu, alpha, beta) at atol=1e-6. + This validates the convergence criterion fix (checking all params, not just mu). + """ + import pandas as pd + from diff_diff import TROP + from unittest.mock import patch + import sys + + np.random.seed(42) + n_units, n_periods = 15, 8 + n_treated = 4 + n_post = 3 + + data = [] + for i in range(n_units): + is_treated = i < n_treated + for t in range(n_periods): + post = t >= (n_periods - n_post) + y = 5.0 + i * 0.5 + t * 0.4 + np.random.randn() * 0.2 + treatment_indicator = 1 if (is_treated and post) else 0 + if treatment_indicator: + y += 2.0 + data.append({ + 'unit': i, 'time': t, + 'outcome': y, 'treated': treatment_indicator, + }) + df = pd.DataFrame(data) + + # Fixed lambda with lambda_nn=inf (no low-rank) + trop_params = dict( + method="global", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[np.inf], + n_bootstrap=2, + seed=42, + ) + + # Rust backend + trop_rust = TROP(**trop_params) + results_rust = trop_rust.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') + + # Python-only backend + trop_module = sys.modules['diff_diff.trop'] + with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ + patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + trop_python = TROP(**trop_params) + results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') + + # ATT should match closely + assert abs(results_rust.att - results_python.att) < 1e-6, \ + f"No-lowrank ATT mismatch: Rust={results_rust.att:.8f}, Python={results_python.att:.8f}" + + # Unit and time effects should match + for key in results_rust.unit_effects: + r_val = results_rust.unit_effects[key] + p_val = results_python.unit_effects[key] + assert abs(r_val - p_val) < 1e-6, \ + f"Unit effect mismatch for {key}: Rust={r_val:.8f}, Python={p_val:.8f}" + + for key in results_rust.time_effects: + r_val = results_rust.time_effects[key] + p_val = results_python.time_effects[key] + assert abs(r_val - p_val) < 1e-6, \ + f"Time effect mismatch for {key}: Rust={r_val:.8f}, Python={p_val:.8f}" + + def test_trop_joint_solver_parity_with_lowrank(self): + """Test Rust/Python solver parity for with-lowrank path (finite lambda_nn). + + Both backends should produce matching (mu, alpha, beta) at atol=1e-6. + The with-lowrank solver calls no-lowrank as its inner step, so the + convergence fix cascades here too. + """ + import pandas as pd + from diff_diff import TROP + from unittest.mock import patch + import sys + + np.random.seed(42) + n_units, n_periods = 15, 8 + n_treated = 4 + n_post = 3 + + data = [] + for i in range(n_units): + is_treated = i < n_treated + for t in range(n_periods): + post = t >= (n_periods - n_post) + y = 5.0 + i * 0.5 + t * 0.4 + np.random.randn() * 0.2 + treatment_indicator = 1 if (is_treated and post) else 0 + if treatment_indicator: + y += 2.0 + data.append({ + 'unit': i, 'time': t, + 'outcome': y, 'treated': treatment_indicator, + }) + df = pd.DataFrame(data) + + # Fixed lambda with finite lambda_nn (low-rank enabled) + trop_params = dict( + method="global", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.1], + n_bootstrap=2, + seed=42, + ) + + # Rust backend + trop_rust = TROP(**trop_params) + results_rust = trop_rust.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') + + # Python-only backend + trop_module = sys.modules['diff_diff.trop'] + with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ + patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + trop_python = TROP(**trop_params) + results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') + + # ATT should match closely + assert abs(results_rust.att - results_python.att) < 1e-6, \ + f"With-lowrank ATT mismatch: Rust={results_rust.att:.8f}, Python={results_python.att:.8f}" + + # Unit and time effects should match + for key in results_rust.unit_effects: + r_val = results_rust.unit_effects[key] + p_val = results_python.unit_effects[key] + assert abs(r_val - p_val) < 1e-6, \ + f"Unit effect mismatch for {key}: Rust={r_val:.8f}, Python={p_val:.8f}" + @pytest.mark.skipif(not HAS_RUST_BACKEND, reason="Rust backend not available") class TestSDIDRustBackend: diff --git a/tests/test_trop.py b/tests/test_trop.py index a5fc726..0d82a5b 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -3424,3 +3424,215 @@ def test_global_treated_outcome_does_not_affect_fit(self, simple_panel_data): assert np.allclose(beta1, beta2, atol=1e-8), "beta changed" assert np.allclose(L1, L2, atol=1e-8), "L changed" + +class TestTROPNValidTreated: + """Tests for n_valid_treated consistency and NaN treated outcome handling.""" + + @staticmethod + def _make_panel(n_units=20, n_periods=8, n_treated=5, n_post=3, + effect=2.0, seed=42): + """Helper: generate a clean panel DataFrame.""" + rng = np.random.default_rng(seed) + rows = [] + for i in range(n_units): + is_treated = i < n_treated + for t in range(n_periods): + post = t >= (n_periods - n_post) + y = 5.0 + i * 0.3 + t * 0.2 + rng.normal() * 0.3 + d = 1 if (is_treated and post) else 0 + if d: + y += effect + rows.append({'unit': i, 'time': t, 'outcome': y, 'treated': d}) + return pd.DataFrame(rows) + + def test_global_n_treated_obs_partial_nan(self): + """Global method: n_treated_obs reflects only finite outcomes.""" + df = self._make_panel() + + # Inject NaN into some treated outcomes + treated_mask = (df['treated'] == 1) + treated_idx = df[treated_mask].index.tolist() + n_nan = 3 + for idx in treated_idx[:n_nan]: + df.loc[idx, 'outcome'] = np.nan + + total_treated = int(treated_mask.sum()) + + trop_est = TROP( + method="global", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[np.inf], + n_bootstrap=2, + seed=42, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + results = trop_est.fit(df, 'outcome', 'treated', 'unit', 'time') + + assert results.n_treated_obs == total_treated - n_nan, \ + f"Expected {total_treated - n_nan}, got {results.n_treated_obs}" + assert np.isfinite(results.att) + + def test_twostep_n_treated_obs_partial_nan(self): + """Twostep method: n_treated_obs reflects only finite outcomes.""" + df = self._make_panel() + + treated_mask = (df['treated'] == 1) + treated_idx = df[treated_mask].index.tolist() + n_nan = 3 + for idx in treated_idx[:n_nan]: + df.loc[idx, 'outcome'] = np.nan + + total_treated = int(treated_mask.sum()) + + trop_est = TROP( + method="twostep", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[np.inf], + n_bootstrap=2, + seed=42, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + results = trop_est.fit(df, 'outcome', 'treated', 'unit', 'time') + + assert results.n_treated_obs == total_treated - n_nan, \ + f"Expected {total_treated - n_nan}, got {results.n_treated_obs}" + assert np.isfinite(results.att) + + def test_twostep_nan_treated_not_poison_att(self): + """Twostep: NaN treated outcomes don't poison ATT via np.mean.""" + df = self._make_panel(effect=3.0) + + # Make ONE treated outcome NaN + treated_mask = (df['treated'] == 1) + first_treated_idx = df[treated_mask].index[0] + df.loc[first_treated_idx, 'outcome'] = np.nan + + trop_est = TROP( + method="twostep", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[np.inf], + n_bootstrap=2, + seed=42, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + results = trop_est.fit(df, 'outcome', 'treated', 'unit', 'time') + + # ATT must be finite (not NaN from NaN poisoning) + assert np.isfinite(results.att), f"ATT should be finite, got {results.att}" + # ATT should be in reasonable range + assert results.att > 1.0, f"ATT {results.att} should reflect treatment effect" + + def test_global_all_treated_nan_warns(self): + """Global method warns when all treated outcomes are NaN.""" + df = self._make_panel() + + # Set ALL treated outcomes to NaN + df.loc[df['treated'] == 1, 'outcome'] = np.nan + + trop_est = TROP( + method="global", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[np.inf], + n_bootstrap=2, + seed=42, + ) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = trop_est.fit(df, 'outcome', 'treated', 'unit', 'time') + + # Should warn about all NaN treated + nan_warnings = [x for x in w if "All treated outcomes are NaN" in str(x.message)] + assert len(nan_warnings) > 0, "Should warn about all-NaN treated outcomes" + assert results.n_treated_obs == 0 + assert np.isnan(results.att) + + def test_twostep_all_treated_nan_warns(self): + """Twostep method warns when all treated outcomes are NaN.""" + df = self._make_panel() + + df.loc[df['treated'] == 1, 'outcome'] = np.nan + + trop_est = TROP( + method="twostep", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[np.inf], + n_bootstrap=2, + seed=42, + ) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = trop_est.fit(df, 'outcome', 'treated', 'unit', 'time') + + nan_warnings = [x for x in w if "All treated outcomes are NaN" in str(x.message)] + assert len(nan_warnings) > 0, "Should warn about all-NaN treated outcomes" + assert results.n_treated_obs == 0 + assert np.isnan(results.att) + + +class TestTROPBootstrapNaNSE: + """Tests for NaN SE when bootstrap has <2 successful draws.""" + + def test_global_bootstrap_zero_draws_returns_nan_se(self): + """Global bootstrap with 0 successful draws returns NaN SE, not 0.0.""" + from unittest.mock import patch + + df = TestTROPNValidTreated._make_panel() + + trop_est = TROP( + method="global", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[np.inf], + n_bootstrap=5, + seed=42, + ) + + # Patch _fit_joint_with_fixed_lambda to always raise (all bootstrap iters fail) + with patch.object(TROP, '_fit_joint_with_fixed_lambda', + side_effect=ValueError("forced failure")): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + se, dist = trop_est._bootstrap_variance_joint( + df, 'outcome', 'treated', 'unit', 'time', + (1.0, 1.0, 1e10), 3, + ) + + assert np.isnan(se), f"SE should be NaN when 0 draws succeed, got {se}" + assert len(dist) == 0 + + def test_twostep_bootstrap_zero_draws_returns_nan_se(self): + """Twostep bootstrap with 0 successful draws returns NaN SE, not 0.0.""" + from unittest.mock import patch + + df = TestTROPNValidTreated._make_panel() + + trop_est = TROP( + method="twostep", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[np.inf], + n_bootstrap=5, + seed=42, + ) + + # Patch _fit_with_fixed_lambda to always raise + with patch.object(TROP, '_fit_with_fixed_lambda', + side_effect=ValueError("forced failure")): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + se, dist = trop_est._bootstrap_variance( + df, 'outcome', 'treated', 'unit', 'time', + (1.0, 1.0, 1e10), + ) + + assert np.isnan(se), f"SE should be NaN when 0 draws succeed, got {se}" + assert len(dist) == 0 + From 574585ff58f617225d9e45209e23dcea4628beef Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 14 Mar 2026 16:44:07 -0400 Subject: [PATCH 7/7] Fix global bootstrap NaN SE test to disable Rust backend The test patches _fit_joint_with_fixed_lambda to force failures, but on CI with Rust available, the Rust bootstrap path runs instead of the Python fallback. Disable Rust backend in the test to exercise the Python return path. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_trop.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_trop.py b/tests/test_trop.py index 0d82a5b..92b17b2 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -3583,6 +3583,7 @@ class TestTROPBootstrapNaNSE: def test_global_bootstrap_zero_draws_returns_nan_se(self): """Global bootstrap with 0 successful draws returns NaN SE, not 0.0.""" from unittest.mock import patch + import sys df = TestTROPNValidTreated._make_panel() @@ -3595,8 +3596,12 @@ def test_global_bootstrap_zero_draws_returns_nan_se(self): seed=42, ) - # Patch _fit_joint_with_fixed_lambda to always raise (all bootstrap iters fail) - with patch.object(TROP, '_fit_joint_with_fixed_lambda', + # Disable Rust backend so Python fallback path is tested, + # then patch _fit_joint_with_fixed_lambda to always raise + trop_module = sys.modules['diff_diff.trop'] + with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None), \ + patch.object(TROP, '_fit_joint_with_fixed_lambda', side_effect=ValueError("forced failure")): with warnings.catch_warnings(): warnings.simplefilter("ignore")