Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Changed
- **BREAKING: TROP nuclear norm solver step size fix** — The proximal gradient
threshold for the L matrix (both `method="joint"` and `method="twostep"` with
finite `lambda_nn`) was over-shrinking singular values by a factor of 2. The
soft-thresholding threshold was λ_nn/max(δ), whereas the correct value is
λ_nn/(2·max(δ)), derived from the Lipschitz constant L_f=2·max(δ) of the
quadratic gradient. This fix produces higher-rank L matrices and closer
agreement with exact convex optimization solutions. Users with finite
`lambda_nn` will observe different ATT estimates. Added FISTA/Nesterov
acceleration to the twostep inner solver for faster L convergence.

## [2.6.0] - 2026-02-22

### Added
Expand Down
37 changes: 24 additions & 13 deletions diff_diff/trop.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,8 +889,8 @@ def _solve_joint_with_lowrank(
gradient_step = L + delta_norm * (R - L)

# Soft-threshold singular values
# Use eta * lambda_nn for proper proximal step size (matches Rust)
eta = 1.0 / delta_max if delta_max > 0 else 1.0
# Lipschitz constant of ∇f is L_f = 2·max(δ), step size η = 1/L_f
eta = 1.0 / (2.0 * delta_max) if delta_max > 0 else 0.5
L = self._soft_threshold_svd(gradient_step, eta * lambda_nn)

# Check convergence
Expand Down Expand Up @@ -1990,10 +1990,11 @@ def _weighted_nuclear_norm_solve(
paper's Equation 2 (page 7). The full objective is:
min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_*

This uses a proximal gradient / soft-impute approach (Mazumder et al. 2010):
L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k))

where W ⊙ denotes element-wise multiplication with normalized weights.
This uses proximal gradient descent (Mazumder et al. 2010) with
FISTA/Nesterov acceleration. Lipschitz constant L_f = 2·max(W),
step size η = 1/(2·max(W)), proximal threshold η·λ_nn:
G_k = L_k + (W/max(W)) ⊙ (R - L_k)
L_{k+1} = prox_{η·λ_nn·||·||_*}(G_k)

IMPORTANT: For observations with W=0 (treated observations), we keep
L values from the previous iteration rather than setting L = R, which
Expand Down Expand Up @@ -2047,20 +2048,30 @@ def _weighted_nuclear_norm_solve(

# Initialize L
L = L_init.copy()
L_prev = L.copy()
t_fista = 1.0

# Proximal gradient iteration with weighted soft-impute
# Proximal gradient iteration with FISTA/Nesterov acceleration
# This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_*
# Using: L_{k+1} = prox_{λ/η}(L_k + W ⊙ (R - L_k))
# where η is the step size (we use η = 1 with normalized weights)
# Lipschitz constant L_f = 2·max(W), so η = 1/(2·max(W))
# Threshold = η·λ_nn = λ_nn/(2·max(W))
for _ in range(max_inner_iter):
L_old = L.copy()

# Gradient step: L_k + W ⊙ (R - L_k)
# For W=0 observations, this keeps L_k unchanged
gradient_step = L + W_norm * (R_masked - L)
# FISTA momentum
t_fista_new = (1.0 + np.sqrt(1.0 + 4.0 * t_fista**2)) / 2.0
momentum = (t_fista - 1.0) / t_fista_new
L_momentum = L + momentum * (L - L_prev)

# Gradient step from momentum point: L_m + W ⊙ (R - L_m)
# For W=0 observations, this keeps L_m unchanged
gradient_step = L_momentum + W_norm * (R_masked - L_momentum)

# Proximal step: soft-threshold singular values
L = self._soft_threshold_svd(gradient_step, lambda_nn)
L_prev = L.copy()
threshold = lambda_nn / (2.0 * W_max) if W_max > 0 else lambda_nn / 2.0
L = self._soft_threshold_svd(gradient_step, threshold)
t_fista = t_fista_new

# Check convergence
if np.max(np.abs(L - L_old)) < self.tol:
Expand Down
19 changes: 14 additions & 5 deletions docs/methodology/REGISTRY.md
Original file line number Diff line number Diff line change
Expand Up @@ -1065,10 +1065,15 @@ Optimization (Equation 2):
```
(α̂, β̂, L̂) = argmin_{α,β,L} Σ_j Σ_s θ_s^{i,t} ω_j^{i,t} (1-W_js)(Y_js - α_j - β_s - L_js)² + λ_nn ||L||_*
```
Solved via alternating minimization with soft-thresholding of singular values for L:
Solved via alternating minimization. For α, β (or μ, α, β, τ in joint): weighted least
squares (closed form). For L: proximal gradient with step size η = 1/(2·max(W)):
```
L̂ = U × soft_threshold(Σ, λ_nn) × V'
Gradient step: G = L + (W/max(W)) ⊙ (R - L)
Proximal step: L = U × soft_threshold(Σ, η·λ_nn) × V' (SVD of G = UΣV')
```
where R is the residual after removing fixed effects (and τ·D in joint mode).
The twostep solver's inner L update uses FISTA/Nesterov acceleration (O(1/k²) convergence);
the Python joint solver uses a single proximal gradient step per outer alternating iteration.

Per-observation weights (Equation 3):
```
Expand Down Expand Up @@ -1184,13 +1189,17 @@ where:

1. **Without low-rank (λ_nn = ∞)**: Standard weighted least squares
- Build design matrix with unit/time dummies + treatment indicator
- Solve via iterative coordinate descent for (μ, α, β, τ)
- Solve via np.linalg.lstsq for (μ, α, β, τ)

2. **With low-rank (finite λ_nn)**: Alternating minimization
- Alternate between:
- Fix L, solve weighted LS for (μ, α, β, τ)
- Fix (μ, α, β, τ), soft-threshold SVD for L (proximal step)
- Continue until convergence
- Fix (μ, α, β, τ), proximal gradient for L:
- Lipschitz constant of ∇f is L_f = 2·max(δ)
- Step size η = 1/L_f = 1/(2·max(δ))
- Proximal operator: soft_threshold(gradient_step, η·λ_nn)
- Twostep inner solver uses FISTA/Nesterov acceleration (O(1/k²))
- Continue until max(|L_new - L_old|) < tol

**LOOCV parameter selection** (unified with twostep, Equation 5):
Following paper's Equation 5 and footnote 2:
Expand Down
72 changes: 53 additions & 19 deletions rust/src/trop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,10 @@ fn compute_weight_matrix(
///
/// Minimizes: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_*
///
/// Paper alignment: Uses weighted proximal gradient for L update:
/// L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L)))
/// where η ≤ 1/max(W) for convergence.
/// Paper alignment: Uses weighted proximal gradient for L update with
/// Lipschitz constant L_f = 2·max(W), step size η = 1/(2·max(W)):
/// G = L + (W/max(W)) ⊙ (R - L)
/// L ← prox_{η·λ_nn·||·||_*}(G)
///
/// Returns None if estimation fails due to numerical issues.
#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -660,9 +661,9 @@ fn estimate_model(
}
});

// Compute step size for proximal gradient: η ≤ 1/max(W)
// Lipschitz constant of ∇f is L_f = 2·max(W), so prox threshold = λ/(2·max(W))
let w_max = w_masked.iter().cloned().fold(0.0_f64, f64::max);
let eta = if w_max > 0.0 { 1.0 / w_max } else { 1.0 };
let prox_threshold = if w_max > 0.0 { lambda_nn / (2.0 * w_max) } else { lambda_nn / 2.0 };

// Weight sums per unit and time
let weight_sum_per_unit: Array1<f64> = w_masked.sum_axis(Axis(0));
Expand Down Expand Up @@ -722,9 +723,9 @@ fn estimate_model(
}

// Step 2: Update L with WEIGHTED nuclear norm penalty
// Paper alignment: Use proximal gradient instead of direct soft-thresholding
// L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L)))
// where R = Y - α - β
// Inner FISTA-accelerated proximal gradient loop (α, β fixed)
// L ← prox_{threshold·||·||_*}(L + W_norm ⊙ (R - L))
// where R = Y - α - β, W_norm = W/max(W)

// Compute target residual R = Y - α - β
let mut r_target = Array2::<f64>::zeros((n_periods, n_units));
Expand All @@ -734,18 +735,50 @@ fn estimate_model(
}
}

// Weighted proximal gradient step:
// gradient_step = L + η * W ⊙ (R - L)
// For W=0 cells (treated obs), this keeps L unchanged
let mut gradient_step = Array2::<f64>::zeros((n_periods, n_units));
for t in 0..n_periods {
for i in 0..n_units {
gradient_step[[t, i]] = l[[t, i]] + eta * w_masked[[t, i]] * (r_target[[t, i]] - l[[t, i]]);
// For W=0 cells, use current L instead of R (prevent absorbing treatment)
let r_masked = Array2::from_shape_fn((n_periods, n_units), |(t, i)| {
if w_masked[[t, i]] > 0.0 { r_target[[t, i]] } else { l[[t, i]] }
});

// Normalize weights: W_norm = W / W_max (max becomes 1)
let w_norm = Array2::from_shape_fn((n_periods, n_units), |(t, i)| {
if w_max > 0.0 { w_masked[[t, i]] / w_max } else { w_masked[[t, i]] }
});

// FISTA inner loop for L update
let mut l_prev = l.clone();
let mut t_fista = 1.0_f64;
let max_inner_iter = 10;

for _ in 0..max_inner_iter {
let l_inner_old = l.clone();

// FISTA momentum
let t_fista_new = (1.0 + (1.0 + 4.0 * t_fista * t_fista).sqrt()) / 2.0;
let momentum = (t_fista - 1.0) / t_fista_new;
let l_momentum = Array2::from_shape_fn((n_periods, n_units), |(t, i)| {
l[[t, i]] + momentum * (l[[t, i]] - l_prev[[t, i]])
});

// Gradient step from momentum point
let mut gradient_step = Array2::<f64>::zeros((n_periods, n_units));
for t in 0..n_periods {
for i in 0..n_units {
gradient_step[[t, i]] = l_momentum[[t, i]] + w_norm[[t, i]] * (r_masked[[t, i]] - l_momentum[[t, i]]);
}
}
}

// Proximal step: soft-threshold singular values with scaled lambda
l = soft_threshold_svd(&gradient_step, eta * lambda_nn)?;
// Proximal step: soft-threshold with corrected threshold
l_prev = l.clone();
l = soft_threshold_svd(&gradient_step, prox_threshold)?;
t_fista = t_fista_new;

// Check inner convergence
let l_inner_diff = max_abs_diff_2d(&l, &l_inner_old);
if l_inner_diff < tol {
break;
}
}

// Check convergence
let alpha_diff = max_abs_diff(&alpha, &alpha_old);
Expand Down Expand Up @@ -1327,8 +1360,9 @@ fn solve_joint_with_lowrank(
}

// Weighted proximal step for L (soft-threshold SVD)
// Lipschitz constant of ∇f is L_f = 2·max(δ), step size η = 1/L_f
let delta_max = delta.iter().cloned().fold(0.0_f64, f64::max);
let eta = if delta_max > 0.0 { 1.0 / delta_max } else { 1.0 };
let eta = if delta_max > 0.0 { 1.0 / (2.0 * delta_max) } else { 0.5 };

// gradient_step = L + eta * delta * (R - L)
// NaN outcomes get zero weight so they don't affect gradient
Expand Down
122 changes: 122 additions & 0 deletions tests/test_trop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2592,6 +2592,128 @@ def test_n_post_periods_counts_observed_treatment(self):
)


class TestTROPNuclearNormSolver:
    """Tests for proximal gradient step size correctness and objective monotonicity."""

    def test_proximal_step_size_correctness(self):
        """Verify L converges to prox_{λ/2}(R) for uniform weights.

        With uniform weights δ ≡ 1 the gradient step maps any L to R, so the
        fixed point of the iteration is soft-thresholding R at the proximal
        threshold η·λ_nn with η = 1/(2·max(δ)) = 1/2, i.e. prox_{λ/2}(R).
        """
        trop_est = TROP(method="joint", n_bootstrap=2)

        # Small problem with known solution
        rng = np.random.default_rng(42)
        R = rng.normal(0, 1, (4, 3))
        delta = np.ones((4, 3))
        lambda_nn = 0.5

        # Step-size quantities are loop invariants — compute once.
        delta_max = np.max(delta)
        delta_norm = delta / delta_max
        eta = 1.0 / (2.0 * delta_max)

        # Run solver (many iterations to ensure convergence)
        L = np.zeros_like(R)
        for _ in range(500):
            gradient_step = L + delta_norm * (R - L)
            L = trop_est._soft_threshold_svd(gradient_step, eta * lambda_nn)

        # Analytical solution for uniform weights: prox_{λ/2}(R)
        L_exact = trop_est._soft_threshold_svd(R, lambda_nn / 2.0)

        np.testing.assert_array_almost_equal(L, L_exact, decimal=4)

    def test_lowrank_objective_decreases(self):
        """Verify objective f(L) + λ||L||_* is non-increasing across iterations.

        With step size η = 1/L_f (L_f = 2·max(δ) is the Lipschitz constant of
        ∇f), proximal gradient descent is a descent method, so the composite
        objective must never increase between iterations.
        """
        # Generate small problem
        rng = np.random.default_rng(42)
        R = rng.normal(0, 1, (6, 4))
        delta = rng.uniform(0.5, 2.0, (6, 4))
        lambda_nn = 0.3

        trop_est = TROP(method="joint", n_bootstrap=2)
        L = np.zeros_like(R)
        objectives = []

        # Loop invariants: normalized weights and step size from L_f = 2·max(δ)
        delta_max = np.max(delta)
        delta_norm = delta / delta_max
        eta = 1.0 / (2.0 * delta_max)

        for _ in range(50):
            # Objective: weighted quadratic loss + nuclear norm penalty.
            # Only singular values are needed, so skip computing U and V.
            f_val = np.sum(delta * (R - L) ** 2)
            s = np.linalg.svd(L, compute_uv=False)
            objectives.append(f_val + lambda_nn * np.sum(s))

            # Proximal gradient step
            gradient_step = L + delta_norm * (R - L)
            L = trop_est._soft_threshold_svd(gradient_step, eta * lambda_nn)

        # Objective should be non-increasing (within numerical tolerance)
        for k in range(1, len(objectives)):
            assert objectives[k] <= objectives[k - 1] + 1e-10, (
                f"Objective increased at step {k}: {objectives[k]} > {objectives[k-1]}"
            )

    def test_twostep_nonuniform_weights_objective(self):
        """Verify objective decreases with non-uniform weights (W_max < 1)."""
        rng = np.random.default_rng(123)
        R = rng.normal(0, 1, (6, 4))
        W = rng.uniform(0.1, 0.8, (6, 4))
        lambda_nn = 0.3

        trop_est = TROP(method="twostep", n_bootstrap=2)

        # Initial objective with L=0 (nuclear norm term is identically zero)
        L_init = np.zeros_like(R)
        f_init = np.sum(W * (R - L_init) ** 2)
        s_init = np.linalg.svd(L_init, compute_uv=False)
        obj_init = f_init + lambda_nn * np.sum(s_init)

        # Solve
        L_final = trop_est._weighted_nuclear_norm_solve(
            Y=R,
            W=W,
            L_init=L_init,
            alpha=np.zeros(R.shape[1]),
            beta=np.zeros(R.shape[0]),
            lambda_nn=lambda_nn,
            max_inner_iter=20,
        )

        # Final objective
        f_final = np.sum(W * (R - L_final) ** 2)
        s_final = np.linalg.svd(L_final, compute_uv=False)
        obj_final = f_final + lambda_nn * np.sum(s_final)

        assert obj_final <= obj_init + 1e-10, (
            f"Objective did not decrease: {obj_final} > {obj_init}"
        )

        # Soft-thresholding should reduce nuclear norm vs residual
        nuclear_norm_R = np.sum(np.linalg.svd(R, compute_uv=False))
        nuclear_norm_L = np.sum(s_final)
        assert nuclear_norm_L < nuclear_norm_R, (
            f"Nuclear norm not reduced: {nuclear_norm_L} >= {nuclear_norm_R}"
        )

    def test_zero_weights_no_division_error(self):
        """Verify solver handles all-zero weights without ZeroDivisionError.

        With W ≡ 0 the step-size denominator max(W) is zero; the solver must
        fall back to a finite threshold instead of dividing by zero.
        """
        rng = np.random.default_rng(99)
        Y = rng.normal(0, 1, (6, 4))
        W = np.zeros((6, 4))
        L_init = rng.normal(0, 1, (6, 4))

        trop_est = TROP(method="twostep", n_bootstrap=2)
        result = trop_est._weighted_nuclear_norm_solve(
            Y=Y,
            W=W,
            L_init=L_init,
            alpha=np.zeros(4),
            beta=np.zeros(6),
            lambda_nn=0.3,
        )

        assert np.isfinite(result).all(), "Result contains NaN or Inf"
        assert result.shape == (6, 4), f"Expected (6, 4), got {result.shape}"


class TestTROPJointMethod:
"""Tests for TROP method='joint'.

Expand Down