I think the line below is buggy. `psi` has a batch dimension `n`, but the einsum contracts over both `n` and `a`, so the norm is summed across all samples and comes out as a single scalar. I think the norm should be calculated for each sample separately, meaning the result should have shape (n,).

norm = tc.einsum('na,na->digit', psi, psi.conj())
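
For illustration, here is a minimal sketch of the per-sample version I have in mind. It uses NumPy as a stand-in because I am not certain which library `tc` refers to; the `psi` values are made up, and the names `norm_sq` and `total` are mine. Note that this contraction gives the squared norm of each sample, so a square root may still be needed depending on how `norm` is used afterwards.

```python
import numpy as np

# Hypothetical batch: n = 3 samples, each a complex vector of length a = 2.
psi = np.array([[1 + 0j, 0 + 0j],
                [1 + 1j, 1 - 1j],
                [0 + 0j, 0 + 2j]])

# Keeping the batch index n in the output gives one value per sample.
norm_sq = np.einsum('na,na->n', psi, psi.conj()).real  # shape (n,)

# Dropping n from the output sums over samples too, giving a single scalar.
total = np.einsum('na,na->', psi, psi.conj()).real
```

If `tc` supports the same einsum subscript syntax, changing the output subscript to `n` should be the only change needed on the quoted line.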