diff --git a/CHANGELOG.md b/CHANGELOG.md index b4778e0..be2c7ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Changed +- **BREAKING: TROP nuclear norm solver step size fix** — The proximal gradient + threshold for the L matrix (both `method="joint"` and `method="twostep"` with + finite `lambda_nn`) was over-shrinking singular values by a factor of 2. The + soft-thresholding threshold was λ_nn/max(δ) (unscaled λ_nn in the Python twostep + solver) when the correct value is λ_nn/(2·max(δ)), derived from the Lipschitz + constant L_f=2·max(δ) of the quadratic gradient. This fix produces higher-rank L + matrices and closer agreement with exact convex optimization solutions. Users + with finite `lambda_nn` will observe different ATT estimates. Added FISTA/Nesterov + acceleration to the twostep inner solver for faster L convergence. + ## [2.6.0] - 2026-02-22 ### Added diff --git a/diff_diff/trop.py b/diff_diff/trop.py index 41f02d0..652c940 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -889,8 +889,8 @@ def _solve_joint_with_lowrank( gradient_step = L + delta_norm * (R - L) # Soft-threshold singular values - # Use eta * lambda_nn for proper proximal step size (matches Rust) - eta = 1.0 / delta_max if delta_max > 0 else 1.0 + # Lipschitz constant of ∇f is L_f = 2·max(δ), step size η = 1/L_f + eta = 1.0 / (2.0 * delta_max) if delta_max > 0 else 0.5 L = self._soft_threshold_svd(gradient_step, eta * lambda_nn) # Check convergence @@ -1990,10 +1990,11 @@ def _weighted_nuclear_norm_solve( paper's Equation 2 (page 7). The full objective is: min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_* - This uses a proximal gradient / soft-impute approach (Mazumder et al. 
2010): - L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k)) - - where W ⊙ denotes element-wise multiplication with normalized weights. + This uses proximal gradient descent (Mazumder et al. 2010) with + FISTA/Nesterov acceleration. Lipschitz constant L_f = 2·max(W), + step size η = 1/(2·max(W)), proximal threshold η·λ_nn: + G_k = L_k + (W/max(W)) ⊙ (R - L_k) + L_{k+1} = prox_{η·λ_nn·||·||_*}(G_k) IMPORTANT: For observations with W=0 (treated observations), we keep L values from the previous iteration rather than setting L = R, which @@ -2047,20 +2048,30 @@ def _weighted_nuclear_norm_solve( # Initialize L L = L_init.copy() + L_prev = L.copy() + t_fista = 1.0 - # Proximal gradient iteration with weighted soft-impute + # Proximal gradient iteration with FISTA/Nesterov acceleration # This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_* - # Using: L_{k+1} = prox_{λ/η}(L_k + W ⊙ (R - L_k)) - # where η is the step size (we use η = 1 with normalized weights) + # Lipschitz constant L_f = 2·max(W), so η = 1/(2·max(W)) + # Threshold = η·λ_nn = λ_nn/(2·max(W)) for _ in range(max_inner_iter): L_old = L.copy() - # Gradient step: L_k + W ⊙ (R - L_k) - # For W=0 observations, this keeps L_k unchanged - gradient_step = L + W_norm * (R_masked - L) + # FISTA momentum + t_fista_new = (1.0 + np.sqrt(1.0 + 4.0 * t_fista**2)) / 2.0 + momentum = (t_fista - 1.0) / t_fista_new + L_momentum = L + momentum * (L - L_prev) + + # Gradient step from momentum point: L_m + W ⊙ (R - L_m) + # For W=0 observations, this keeps L_m unchanged + gradient_step = L_momentum + W_norm * (R_masked - L_momentum) # Proximal step: soft-threshold singular values - L = self._soft_threshold_svd(gradient_step, lambda_nn) + L_prev = L.copy() + threshold = lambda_nn / (2.0 * W_max) if W_max > 0 else lambda_nn / 2.0 + L = self._soft_threshold_svd(gradient_step, threshold) + t_fista = t_fista_new # Check convergence if np.max(np.abs(L - L_old)) < self.tol: diff --git a/docs/methodology/REGISTRY.md 
b/docs/methodology/REGISTRY.md index 3c88f1d..2570e0a 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1065,10 +1065,15 @@ Optimization (Equation 2): ``` (α̂, β̂, L̂) = argmin_{α,β,L} Σ_j Σ_s θ_s^{i,t} ω_j^{i,t} (1-W_js)(Y_js - α_j - β_s - L_js)² + λ_nn ||L||_* ``` -Solved via alternating minimization with soft-thresholding of singular values for L: +Solved via alternating minimization. For α, β (or μ, α, β, τ in joint): weighted least +squares (closed form). For L: proximal gradient with step size η = 1/(2·max(W)): ``` -L̂ = U × soft_threshold(Σ, λ_nn) × V' +Gradient step: G = L + (W/max(W)) ⊙ (R - L) +Proximal step: L = U × soft_threshold(Σ, η·λ_nn) × V' (SVD of G = UΣV') ``` +where R is the residual after removing fixed effects (and τ·D in joint mode). +The twostep solver's inner L update uses FISTA/Nesterov acceleration (O(1/k²) convergence); +the Python joint solver uses a single proximal gradient step per outer alternating iteration. Per-observation weights (Equation 3): ``` @@ -1184,13 +1189,17 @@ where: 1. **Without low-rank (λ_nn = ∞)**: Standard weighted least squares - Build design matrix with unit/time dummies + treatment indicator - - Solve via iterative coordinate descent for (μ, α, β, τ) + - Solve via np.linalg.lstsq for (μ, α, β, τ) 2. 
**With low-rank (finite λ_nn)**: Alternating minimization - Alternate between: - Fix L, solve weighted LS for (μ, α, β, τ) - - Fix (μ, α, β, τ), soft-threshold SVD for L (proximal step) - - Continue until convergence + - Fix (μ, α, β, τ), proximal gradient for L: + - Lipschitz constant of ∇f is L_f = 2·max(δ) + - Step size η = 1/L_f = 1/(2·max(δ)) + - Proximal operator: soft_threshold(gradient_step, η·λ_nn) + - Twostep inner solver uses FISTA/Nesterov acceleration (O(1/k²)) + - Continue until max(|L_new - L_old|) < tol **LOOCV parameter selection** (unified with twostep, Equation 5): Following paper's Equation 5 and footnote 2: diff --git a/rust/src/trop.rs b/rust/src/trop.rs index 9e275d5..b503f4c 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -620,9 +620,10 @@ fn compute_weight_matrix( /// /// Minimizes: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_* /// -/// Paper alignment: Uses weighted proximal gradient for L update: -/// L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L))) -/// where η ≤ 1/max(W) for convergence. +/// Paper alignment: Uses weighted proximal gradient for L update with +/// Lipschitz constant L_f = 2·max(W), step size η = 1/(2·max(W)): +/// G = L + (W/max(W)) ⊙ (R - L) +/// L ← prox_{η·λ_nn·||·||_*}(G) /// /// Returns None if estimation fails due to numerical issues. 
#[allow(clippy::too_many_arguments)] @@ -660,9 +661,9 @@ fn estimate_model( } }); - // Compute step size for proximal gradient: η ≤ 1/max(W) + // Lipschitz constant of ∇f is L_f = 2·max(W), so prox threshold = λ/(2·max(W)) let w_max = w_masked.iter().cloned().fold(0.0_f64, f64::max); - let eta = if w_max > 0.0 { 1.0 / w_max } else { 1.0 }; + let prox_threshold = if w_max > 0.0 { lambda_nn / (2.0 * w_max) } else { lambda_nn / 2.0 }; // Weight sums per unit and time let weight_sum_per_unit: Array1 = w_masked.sum_axis(Axis(0)); @@ -722,9 +723,9 @@ fn estimate_model( } // Step 2: Update L with WEIGHTED nuclear norm penalty - // Paper alignment: Use proximal gradient instead of direct soft-thresholding - // L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L))) - // where R = Y - α - β + // Inner FISTA-accelerated proximal gradient loop (α, β fixed) + // L ← prox_{threshold·||·||_*}(L + W_norm ⊙ (R - L)) + // where R = Y - α - β, W_norm = W/max(W) // Compute target residual R = Y - α - β let mut r_target = Array2::::zeros((n_periods, n_units)); @@ -734,18 +735,50 @@ fn estimate_model( } } - // Weighted proximal gradient step: - // gradient_step = L + η * W ⊙ (R - L) - // For W=0 cells (treated obs), this keeps L unchanged - let mut gradient_step = Array2::::zeros((n_periods, n_units)); - for t in 0..n_periods { - for i in 0..n_units { - gradient_step[[t, i]] = l[[t, i]] + eta * w_masked[[t, i]] * (r_target[[t, i]] - l[[t, i]]); + // For W=0 cells, use current L instead of R (prevent absorbing treatment) + let r_masked = Array2::from_shape_fn((n_periods, n_units), |(t, i)| { + if w_masked[[t, i]] > 0.0 { r_target[[t, i]] } else { l[[t, i]] } + }); + + // Normalize weights: W_norm = W / W_max (max becomes 1) + let w_norm = Array2::from_shape_fn((n_periods, n_units), |(t, i)| { + if w_max > 0.0 { w_masked[[t, i]] / w_max } else { w_masked[[t, i]] } + }); + + // FISTA inner loop for L update + let mut l_prev = l.clone(); + let mut t_fista = 1.0_f64; + let max_inner_iter = 10; + + 
for _ in 0..max_inner_iter { + let l_inner_old = l.clone(); + + // FISTA momentum + let t_fista_new = (1.0 + (1.0 + 4.0 * t_fista * t_fista).sqrt()) / 2.0; + let momentum = (t_fista - 1.0) / t_fista_new; + let l_momentum = Array2::from_shape_fn((n_periods, n_units), |(t, i)| { + l[[t, i]] + momentum * (l[[t, i]] - l_prev[[t, i]]) + }); + + // Gradient step from momentum point + let mut gradient_step = Array2::::zeros((n_periods, n_units)); + for t in 0..n_periods { + for i in 0..n_units { + gradient_step[[t, i]] = l_momentum[[t, i]] + w_norm[[t, i]] * (r_masked[[t, i]] - l_momentum[[t, i]]); + } } - } - // Proximal step: soft-threshold singular values with scaled lambda - l = soft_threshold_svd(&gradient_step, eta * lambda_nn)?; + // Proximal step: soft-threshold with corrected threshold + l_prev = l.clone(); + l = soft_threshold_svd(&gradient_step, prox_threshold)?; + t_fista = t_fista_new; + + // Check inner convergence + let l_inner_diff = max_abs_diff_2d(&l, &l_inner_old); + if l_inner_diff < tol { + break; + } + } // Check convergence let alpha_diff = max_abs_diff(&alpha, &alpha_old); @@ -1327,8 +1360,9 @@ fn solve_joint_with_lowrank( } // Weighted proximal step for L (soft-threshold SVD) + // Lipschitz constant of ∇f is L_f = 2·max(δ), step size η = 1/L_f let delta_max = delta.iter().cloned().fold(0.0_f64, f64::max); - let eta = if delta_max > 0.0 { 1.0 / delta_max } else { 1.0 }; + let eta = if delta_max > 0.0 { 1.0 / (2.0 * delta_max) } else { 0.5 }; // gradient_step = L + eta * delta * (R - L) // NaN outcomes get zero weight so they don't affect gradient diff --git a/tests/test_trop.py b/tests/test_trop.py index 41a137f..726c2e6 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -2592,6 +2592,128 @@ def test_n_post_periods_counts_observed_treatment(self): ) +class TestTROPNuclearNormSolver: + """Tests for proximal gradient step size correctness and objective monotonicity.""" + + def test_proximal_step_size_correctness(self): + """Verify L 
converges to prox_{λ/2}(R) for uniform weights.""" + trop_est = TROP(method="joint", n_bootstrap=2) + + # Small problem with known solution + rng = np.random.default_rng(42) + R = rng.normal(0, 1, (4, 3)) + delta = np.ones((4, 3)) + lambda_nn = 0.5 + + # Run solver (many iterations to ensure convergence) + L = np.zeros_like(R) + for _ in range(500): + delta_max = np.max(delta) + delta_norm = delta / delta_max + gradient_step = L + delta_norm * (R - L) + eta = 1.0 / (2.0 * delta_max) + L = trop_est._soft_threshold_svd(gradient_step, eta * lambda_nn) + + # Analytical solution for uniform weights: prox_{λ/2}(R) + L_exact = trop_est._soft_threshold_svd(R, lambda_nn / 2.0) + + np.testing.assert_array_almost_equal(L, L_exact, decimal=4) + + def test_lowrank_objective_decreases(self): + """Verify objective f(L) + λ||L||_* is non-increasing across iterations.""" + # Generate small problem + rng = np.random.default_rng(42) + R = rng.normal(0, 1, (6, 4)) + delta = rng.uniform(0.5, 2.0, (6, 4)) + lambda_nn = 0.3 + + trop_est = TROP(method="joint", n_bootstrap=2) + L = np.zeros_like(R) + objectives = [] + + for _ in range(50): + # Compute objective + f_val = np.sum(delta * (R - L) ** 2) + _, s, _ = np.linalg.svd(L, full_matrices=False) + obj = f_val + lambda_nn * np.sum(s) + objectives.append(obj) + + # Proximal gradient step + delta_max = np.max(delta) + delta_norm = delta / delta_max + gradient_step = L + delta_norm * (R - L) + eta = 1.0 / (2.0 * delta_max) + L = trop_est._soft_threshold_svd(gradient_step, eta * lambda_nn) + + # Objective should be non-increasing (within numerical tolerance) + for k in range(1, len(objectives)): + assert objectives[k] <= objectives[k - 1] + 1e-10, ( + f"Objective increased at step {k}: {objectives[k]} > {objectives[k-1]}" + ) + + def test_twostep_nonuniform_weights_objective(self): + """Verify objective decreases with non-uniform weights (W_max < 1).""" + rng = np.random.default_rng(123) + R = rng.normal(0, 1, (6, 4)) + W = rng.uniform(0.1, 
0.8, (6, 4)) + lambda_nn = 0.3 + + trop_est = TROP(method="twostep", n_bootstrap=2) + + # Initial objective with L=0 + L_init = np.zeros_like(R) + f_init = np.sum(W * (R - L_init) ** 2) + _, s_init, _ = np.linalg.svd(L_init, full_matrices=False) + obj_init = f_init + lambda_nn * np.sum(s_init) + + # Solve + L_final = trop_est._weighted_nuclear_norm_solve( + Y=R, + W=W, + L_init=L_init, + alpha=np.zeros(R.shape[1]), + beta=np.zeros(R.shape[0]), + lambda_nn=lambda_nn, + max_inner_iter=20, + ) + + # Final objective + f_final = np.sum(W * (R - L_final) ** 2) + _, s_final, _ = np.linalg.svd(L_final, full_matrices=False) + obj_final = f_final + lambda_nn * np.sum(s_final) + + assert obj_final <= obj_init + 1e-10, ( + f"Objective did not decrease: {obj_final} > {obj_init}" + ) + + # Soft-thresholding should reduce nuclear norm vs residual + nuclear_norm_R = np.sum(np.linalg.svd(R, compute_uv=False)) + nuclear_norm_L = np.sum(s_final) + assert nuclear_norm_L < nuclear_norm_R, ( + f"Nuclear norm not reduced: {nuclear_norm_L} >= {nuclear_norm_R}" + ) + + def test_zero_weights_no_division_error(self): + """Verify solver handles all-zero weights without ZeroDivisionError.""" + rng = np.random.default_rng(99) + Y = rng.normal(0, 1, (6, 4)) + W = np.zeros((6, 4)) + L_init = rng.normal(0, 1, (6, 4)) + + trop_est = TROP(method="twostep", n_bootstrap=2) + result = trop_est._weighted_nuclear_norm_solve( + Y=Y, + W=W, + L_init=L_init, + alpha=np.zeros(4), + beta=np.zeros(6), + lambda_nn=0.3, + ) + + assert np.isfinite(result).all(), "Result contains NaN or Inf" + assert result.shape == (6, 4), f"Expected (6, 4), got {result.shape}" + + class TestTROPJointMethod: """Tests for TROP method='joint'.