diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..9b55eac --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +default_language_version: + python: python3 + +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.10 + hooks: + # Run the linter + - id: ruff + args: [ --fix, --config, pyproject.toml ] + # Run the formatter + - id: ruff-format + - repo: https://github.com/codespell-project/codespell + rev: v2.3.0 + hooks: + - id: codespell + additional_dependencies: + - tomli \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 1973fa6..2fbc674 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -49,9 +49,9 @@ cuslines/ cu_propagate_seeds.py # SeedBatchPropagator: chunked seed processing cu_direction_getters.py # Direction getter ABC + Boot/Prob/PTT implementations cutils.py # REAL_DTYPE, REAL3_DTYPE, checkCudaErrors(), ModelType enum - _globals.py # AUTO-GENERATED from globals.h (never edit manually) + _globals.py # Global constants useful for all languages cuda_c/ # CUDA kernel source - globals.h # Source-of-truth for constants (REAL_SIZE, thread config) + globals.h # CUDA-specific global constants generate_streamlines_cuda.cu, boot.cu, ptt.cu, tracking_helpers.cu, utils.cu cudamacro.h, cuwsort.cuh, ptt.cuh, disc.h metal/ # Metal backend (mirrors cuda_python/) @@ -82,7 +82,6 @@ Each has `from_dipy_*()` class methods for initialization from DIPY models. ## Critical Conventions -- **`_globals.py` is auto-generated** from `cuslines/cuda_c/globals.h` during `setup.py` build via `defines_to_python()`. Never edit it manually; change `globals.h` and rebuild. - **GPU arrays must be C-contiguous** — always use `np.ascontiguousarray()` and project scalar types (`REAL_DTYPE`, `REAL_SIZE` from `cutils.py` or `mutils.py`). - **All CUDA API calls must be wrapped** with `checkCudaErrors()`. - **Angle units**: CLI accepts degrees, internals convert to radians before the GPU layer. 
diff --git a/cuslines/__init__.py b/cuslines/__init__.py index e4085c5..2f9f9f7 100644 --- a/cuslines/__init__.py +++ b/cuslines/__init__.py @@ -34,25 +34,37 @@ def _detect_backend(): BACKEND = _detect_backend() if BACKEND == "metal": + from cuslines.metal import ( + MetalBootDirectionGetter as BootDirectionGetter, + ) from cuslines.metal import ( MetalGPUTracker as GPUTracker, + ) + from cuslines.metal import ( MetalProbDirectionGetter as ProbDirectionGetter, + ) + from cuslines.metal import ( MetalPttDirectionGetter as PttDirectionGetter, - MetalBootDirectionGetter as BootDirectionGetter, ) elif BACKEND == "cuda": from cuslines.cuda_python import ( + BootDirectionGetter, GPUTracker, ProbDirectionGetter, PttDirectionGetter, - BootDirectionGetter, ) elif BACKEND == "webgpu": from cuslines.webgpu import ( - WebGPUTracker as GPUTracker, + WebGPUBootDirectionGetter as BootDirectionGetter, + ) + from cuslines.webgpu import ( WebGPUProbDirectionGetter as ProbDirectionGetter, + ) + from cuslines.webgpu import ( WebGPUPttDirectionGetter as PttDirectionGetter, - WebGPUBootDirectionGetter as BootDirectionGetter, + ) + from cuslines.webgpu import ( + WebGPUTracker as GPUTracker, ) else: raise ImportError( diff --git a/cuslines/boot_utils.py b/cuslines/boot_utils.py index 50abd7b..054f619 100644 --- a/cuslines/boot_utils.py +++ b/cuslines/boot_utils.py @@ -8,19 +8,25 @@ from dipy.reconst import shm -def prepare_opdt(gtab, sphere, sh_order_max=6, full_basis=False, - sh_lambda=0.006, min_signal=1): +def prepare_opdt( + gtab, sphere, sh_order_max=6, full_basis=False, sh_lambda=0.006, min_signal=1 +): """Build bootstrap matrices for the OPDT model. Returns dict with keys: model_type, min_signal, H, R, delta_b, delta_q, sampling_matrix, b0s_mask. 
""" sampling_matrix, _, _ = shm.real_sh_descoteaux( - sh_order_max, sphere.theta, sphere.phi, - full_basis=full_basis, legacy=True, + sh_order_max, + sphere.theta, + sphere.phi, + full_basis=full_basis, + legacy=True, ) model = shm.OpdtModel( - gtab, sh_order_max=sh_order_max, smooth=sh_lambda, + gtab, + sh_order_max=sh_order_max, + smooth=sh_lambda, min_signal=min_signal, ) delta_b, delta_q = model._fit_matrix @@ -28,25 +34,36 @@ def prepare_opdt(gtab, sphere, sh_order_max=6, full_basis=False, H, R = _hat_and_lcr(gtab, model, sh_order_max) return dict( - model_type="OPDT", min_signal=min_signal, - H=H, R=R, delta_b=delta_b, delta_q=delta_q, - sampling_matrix=sampling_matrix, b0s_mask=gtab.b0s_mask, + model_type="OPDT", + min_signal=min_signal, + H=H, + R=R, + delta_b=delta_b, + delta_q=delta_q, + sampling_matrix=sampling_matrix, + b0s_mask=gtab.b0s_mask, ) -def prepare_csa(gtab, sphere, sh_order_max=6, full_basis=False, - sh_lambda=0.006, min_signal=1): +def prepare_csa( + gtab, sphere, sh_order_max=6, full_basis=False, sh_lambda=0.006, min_signal=1 +): """Build bootstrap matrices for the CSA model. Returns dict with keys: model_type, min_signal, H, R, delta_b, delta_q, sampling_matrix, b0s_mask. 
""" sampling_matrix, _, _ = shm.real_sh_descoteaux( - sh_order_max, sphere.theta, sphere.phi, - full_basis=full_basis, legacy=True, + sh_order_max, + sphere.theta, + sphere.phi, + full_basis=full_basis, + legacy=True, ) model = shm.CsaOdfModel( - gtab, sh_order_max=sh_order_max, smooth=sh_lambda, + gtab, + sh_order_max=sh_order_max, + smooth=sh_lambda, min_signal=min_signal, ) delta_b = model._fit_matrix @@ -55,9 +72,14 @@ def prepare_csa(gtab, sphere, sh_order_max=6, full_basis=False, H, R = _hat_and_lcr(gtab, model, sh_order_max) return dict( - model_type="CSA", min_signal=min_signal, - H=H, R=R, delta_b=delta_b, delta_q=delta_q, - sampling_matrix=sampling_matrix, b0s_mask=gtab.b0s_mask, + model_type="CSA", + min_signal=min_signal, + H=H, + R=R, + delta_b=delta_b, + delta_q=delta_q, + sampling_matrix=sampling_matrix, + b0s_mask=gtab.b0s_mask, ) diff --git a/cuslines/cuda_c/boot.cu b/cuslines/cuda_c/boot.cu index 133c43d..978d158 100644 --- a/cuslines/cuda_c/boot.cu +++ b/cuslines/cuda_c/boot.cu @@ -116,8 +116,7 @@ template -__device__ int closest_peak_d(const REAL_T max_angle, - const REAL3_T direction, //dir +__device__ int closest_peak_d(const REAL3_T direction, //dir const int npeaks, const REAL3_T *__restrict__ peaks, REAL3_T *__restrict__ peak) {// dirs, @@ -127,8 +126,7 @@ __device__ int closest_peak_d(const REAL_T max_angle, const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32; const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1))); - //const REAL_T cos_similarity = COS(MAX_ANGLE_P); - const REAL_T cos_similarity = COS(max_angle); + const REAL_T cos_similarity = COS(MAX_ANGLE); #if 0 if (!threadIdx.y && !tidx) { printf("direction: (%f, %f, %f)\n", @@ -400,34 +398,19 @@ template __device__ int get_direction_boot_d( curandStatePhilox4_32_10_t *st, - const REAL_T max_angle, - const REAL_T min_signal, - const REAL_T relative_peak_thres, - const REAL_T min_separation_angle, + const REAL_T min_signal, REAL3_T dir, - const int dimx, - const 
int dimy, - const int dimz, - const int dimt, const REAL_T *__restrict__ dataf, - const int *__restrict__ b0s_mask, // not using this (and its opposite, dwi_mask) - // but not clear if it will never be needed so - // we'll keep it here for now... + const int *__restrict__ b0s_mask, const REAL3_T point, const REAL_T *__restrict__ H, const REAL_T *__restrict__ R, - // model unused - // max_angle, pmf_threshold from global defines - // b0s_mask already passed - // min_signal from global defines const int delta_nr, const REAL_T *__restrict__ delta_b, const REAL_T *__restrict__ delta_q, // fit_matrix - const int samplm_nr, const REAL_T *__restrict__ sampling_matrix, const REAL3_T *__restrict__ sphere_vertices, const int2 *__restrict__ sphere_edges, - const int num_edges, REAL3_T *__restrict__ dirs) { const int tidx = threadIdx.x; @@ -436,25 +419,23 @@ __device__ int get_direction_boot_d( const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32; const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1))); - const int n32dimt = ((dimt+31)/32)*32; - extern REAL_T __shared__ __sh[]; REAL_T *__vox_data_sh = reinterpret_cast(__sh); - REAL_T *__msk_data_sh = __vox_data_sh + BDIM_Y*n32dimt; + REAL_T *__msk_data_sh = __vox_data_sh + BDIM_Y*N32DIMT; - REAL_T *__r_sh = __msk_data_sh + BDIM_Y*n32dimt; - REAL_T *__h_sh = __r_sh + BDIM_Y*MAX(n32dimt, samplm_nr); + REAL_T *__r_sh = __msk_data_sh + BDIM_Y*N32DIMT; + REAL_T *__h_sh = __r_sh + BDIM_Y*MAX(N32DIMT, SAMPLM_NR); - __vox_data_sh += tidy*n32dimt; - __msk_data_sh += tidy*n32dimt; + __vox_data_sh += tidy*N32DIMT; + __msk_data_sh += tidy*N32DIMT; + + __r_sh += tidy*MAX(N32DIMT, SAMPLM_NR); + __h_sh += tidy*MAX(N32DIMT, SAMPLM_NR); - __r_sh += tidy*MAX(n32dimt, samplm_nr); - __h_sh += tidy*MAX(n32dimt, samplm_nr); - // compute hr_side (may be passed from python) int hr_side = 0; - for(int j = tidx; j < dimt; j += BDIM_X) { + for(int j = tidx; j < DIMT; j += BDIM_X) { hr_side += !b0s_mask[j] ? 
1 : 0; } #pragma unroll @@ -465,15 +446,15 @@ __device__ int get_direction_boot_d( #pragma unroll for(int i = 0; i < NATTEMPTS; i++) { - const int rv = trilinear_interp_d(dimx, dimy, dimz, dimt, -1, dataf, point, __vox_data_sh); + const int rv = trilinear_interp_d(dataf, point, __vox_data_sh); - const int nmsk = maskGet(dimt, b0s_mask, __vox_data_sh, __msk_data_sh); + const int nmsk = maskGet(DIMT, b0s_mask, __vox_data_sh, __msk_data_sh); //if (!tidx && !threadIdx.y && !blockIdx.x) { // // printf("interp of %f, %f, %f\n", point.x, point.y, point.z); // printf("hr_side: %d\n", hr_side); - // printArray("vox_data", 6, dimt, __vox_data_sh[tidy]); + // printArray("vox_data", 6, DIMT, __vox_data_sh[tidy]); // printArray("msk_data", 6, nmsk, __msk_data_sh[tidy]); //} //break; @@ -513,21 +494,21 @@ __device__ int get_direction_boot_d( //__syncwarp(); // vox_data[dwi_mask] = masked_data - maskPut(dimt, b0s_mask, __h_sh, __vox_data_sh); + maskPut(DIMT, b0s_mask, __h_sh, __vox_data_sh); __syncwarp(WMASK); - //printArray("vox_data[dwi_mask]:", 6, dimt, __vox_data_sh[tidy]); + //printArray("vox_data[dwi_mask]:", 6, DIMT, __vox_data_sh[tidy]); //__syncwarp(); - for(int j = tidx; j < dimt; j += BDIM_X) { + for(int j = tidx; j < DIMT; j += BDIM_X) { //__vox_data_sh[j] = MAX(MIN_SIGNAL_P, __vox_data_sh[j]); __vox_data_sh[j] = MAX(min_signal, __vox_data_sh[j]); } __syncwarp(WMASK); - const REAL_T denom = avgMask(dimt, b0s_mask, __vox_data_sh); + const REAL_T denom = avgMask(DIMT, b0s_mask, __vox_data_sh); - for(int j = tidx; j < dimt; j += BDIM_X) { + for(int j = tidx; j < DIMT; j += BDIM_X) { __vox_data_sh[j] /= denom; } __syncwarp(); @@ -539,23 +520,23 @@ __device__ int get_direction_boot_d( //if (!tidx && !threadIdx.y && !blockIdx.x) { // // printf("__vox_data_sh:\n"); - // printArray("vox_data", 6, dimt, __vox_data_sh[tidy]); + // printArray("vox_data", 6, DIMT, __vox_data_sh[tidy]); //} //break; - maskGet(dimt, b0s_mask, __vox_data_sh, __msk_data_sh); + maskGet(DIMT, b0s_mask, 
__vox_data_sh, __msk_data_sh); __syncwarp(WMASK); fit_model_coef(delta_nr, hr_side, delta_q, delta_b, __msk_data_sh, __h_sh, __r_sh); // __r_sh[tidy] <- python 'coef' - ndotp_d(samplm_nr, delta_nr, __r_sh, sampling_matrix, __h_sh); + ndotp_d(SAMPLM_NR, delta_nr, __r_sh, sampling_matrix, __h_sh); // __h_sh[tidy] <- python 'pmf' } else { #pragma unroll - for(int j = tidx; j < samplm_nr; j += BDIM_X) { + for(int j = tidx; j < SAMPLM_NR; j += BDIM_X) { __h_sh[j] = 0; } // __h_sh[tidy] <- python 'pmf' @@ -563,17 +544,17 @@ __device__ int get_direction_boot_d( __syncwarp(WMASK); #if 0 if (!threadIdx.y && threadIdx.x == 0) { - for(int j = 0; j < samplm_nr; j++) { + for(int j = 0; j < SAMPLM_NR; j++) { printf("pmf[%d]: %f\n", j, __h_sh[tidy][j]); } } //return; #endif - const REAL_T abs_pmf_thr = PMF_THRESHOLD_P*max_d(samplm_nr, __h_sh, REAL_MIN); + const REAL_T abs_pmf_thr = PMF_THRESHOLD_P*max_d(SAMPLM_NR, __h_sh, REAL_MIN); __syncwarp(WMASK); #pragma unroll - for(int j = tidx; j < samplm_nr; j += BDIM_X) { + for(int j = tidx; j < SAMPLM_NR; j += BDIM_X) { const REAL_T __v = __h_sh[j]; if (__v < abs_pmf_thr) { __h_sh[j] = 0; @@ -583,7 +564,7 @@ __device__ int get_direction_boot_d( #if 0 if (!threadIdx.y && threadIdx.x == 0) { printf("abs_pmf_thr: %f\n", abs_pmf_thr); - for(int j = 0; j < samplm_nr; j++) { + for(int j = 0; j < SAMPLM_NR; j++) { printf("pmfNORM[%d]: %f\n", j, __h_sh[tidy][j]); } } @@ -602,11 +583,7 @@ __device__ int get_direction_boot_d( BDIM_Y>(__h_sh, dirs, sphere_vertices, sphere_edges, - num_edges, - samplm_nr, - reinterpret_cast(__r_sh), // reuse __r_sh as shInd in func which is large enough - relative_peak_thres, - min_separation_angle); + reinterpret_cast(__r_sh)); // reuse __r_sh as shInd in func which is large enough if (NATTEMPTS == 1) { // init=True... return ndir; // and dirs; } else { // init=False... 
@@ -617,7 +594,7 @@ __device__ int get_direction_boot_d( } */ REAL3_T peak; - const int foundPeak = closest_peak_d(max_angle, dir, ndir, dirs, &peak); + const int foundPeak = closest_peak_d(dir, ndir, dirs, &peak); __syncwarp(WMASK); if (foundPeak) { if (tidx == 0) { @@ -637,17 +614,9 @@ template __global__ void getNumStreamlinesBoot_k( const ModelType model_type, - const REAL_T max_angle, const REAL_T min_signal, - const REAL_T relative_peak_thres, - const REAL_T min_separation_angle, - const long long rndSeed, const int nseed, const REAL3_T *__restrict__ seeds, - const int dimx, - const int dimy, - const int dimz, - const int dimt, const REAL_T *__restrict__ dataf, const REAL_T *__restrict__ H, const REAL_T *__restrict__ R, @@ -655,11 +624,9 @@ __global__ void getNumStreamlinesBoot_k( const REAL_T *__restrict__ delta_b, const REAL_T *__restrict__ delta_q, const int *__restrict__ b0s_mask, // change to int - const int samplm_nr, const REAL_T *__restrict__ sampling_matrix, const REAL3_T *__restrict__ sphere_vertices, const int2 *__restrict__ sphere_edges, - const int num_edges, REAL3_T *__restrict__ shDir0, int *slineOutOff) { @@ -674,13 +641,13 @@ __global__ void getNumStreamlinesBoot_k( REAL3_T seed = seeds[slid]; // seed = lin_mat*seed + offset - REAL3_T *__restrict__ __shDir = shDir0+slid*samplm_nr; + REAL3_T *__restrict__ __shDir = shDir0+slid*SAMPLM_NR; - // const int hr_side = dimt-1; + // const int hr_side = DIMT-1; curandStatePhilox4_32_10_t st; - //curand_init(rndSeed, slid + rndOffset, DIV_UP(hr_side, BDIM_X)*tidx, &st); // each thread uses DIV_UP(hr_side/BDIM_X) - curand_init(rndSeed, gid, 0, &st); // each thread uses DIV_UP(hr_side/BDIM_X) + //curand_init(RNG_SEED, slid + rndOffset, DIV_UP(hr_side, BDIM_X)*tidx, &st); // each thread uses DIV_UP(hr_side/BDIM_X) + curand_init(RNG_SEED, gid, 0, &st); // each thread uses DIV_UP(hr_side/BDIM_X) // elements of the same sequence // python: //directions = get_direction(None, dataf, dwi_mask, sphere, s, H, R, 
model, max_angle, @@ -699,12 +666,9 @@ __global__ void getNumStreamlinesBoot_k( 1, OPDT>( &st, - max_angle, min_signal, - relative_peak_thres, - min_separation_angle, MAKE_REAL3(0,0,0), - dimx, dimy, dimz, dimt, dataf, + dataf, b0s_mask /* !dwi_mask */, seed, H, R, @@ -714,11 +678,9 @@ __global__ void getNumStreamlinesBoot_k( // min_signal from global defines delta_nr, delta_b, delta_q, // fit_matrix - samplm_nr, sampling_matrix, sphere_vertices, sphere_edges, - num_edges, __shDir); break; case CSA: @@ -727,12 +689,9 @@ __global__ void getNumStreamlinesBoot_k( 1, CSA>( &st, - max_angle, min_signal, - relative_peak_thres, - min_separation_angle, MAKE_REAL3(0,0,0), - dimx, dimy, dimz, dimt, dataf, + dataf, b0s_mask /* !dwi_mask */, seed, H, R, @@ -742,11 +701,9 @@ __global__ void getNumStreamlinesBoot_k( // min_signal from global defines delta_nr, delta_b, delta_q, // fit_matrix - samplm_nr, sampling_matrix, sphere_vertices, sphere_edges, - num_edges, __shDir); break; default: @@ -768,24 +725,13 @@ template __device__ int tracker_boot_d( curandStatePhilox4_32_10_t *st, - const REAL_T max_angle, - const REAL_T tc_threshold, - const REAL_T step_size, - const REAL_T relative_peak_thres, - const REAL_T min_separation_angle, - REAL3_T seed, - REAL3_T first_step, - REAL3_T voxel_size, - const int dimx, - const int dimy, - const int dimz, - const int dimt, - const REAL_T *__restrict__ dataf, - const REAL_T *__restrict__ metric_map, - const int samplm_nr, - const REAL3_T *__restrict__ sphere_vertices, - const int2 *__restrict__ sphere_edges, - const int num_edges, + REAL3_T seed, + REAL3_T first_step, + REAL3_T voxel_size, + const REAL_T *__restrict__ dataf, + const cudaTextureObject_t *__restrict__ metric_map, + const REAL3_T *__restrict__ sphere_vertices, + const int2 *__restrict__ sphere_edges, /*BOOT specific params*/ const REAL_T min_signal, const int delta_nr, @@ -830,22 +776,17 @@ __device__ int tracker_boot_d( 5, MODEL_T>( st, - max_angle, min_signal, - 
relative_peak_thres, - min_separation_angle, direction, - dimx, dimy, dimz, dimt, dataf, + dataf, b0s_mask /* !dwi_mask */, point, H, R, delta_nr, delta_b, delta_q, // fit_matrix - samplm_nr, sampling_matrix, sphere_vertices, sphere_edges, - num_edges, __sh_new_dir + tidy); __syncwarp(WMASK); direction = __sh_new_dir[tidy]; @@ -855,16 +796,16 @@ __device__ int tracker_boot_d( break; } - point.x += (direction.x / voxel_size.x) * (step_size / step_frac); - point.y += (direction.y / voxel_size.y) * (step_size / step_frac); - point.z += (direction.z / voxel_size.z) * (step_size / step_frac); + point.x += (direction.x / voxel_size.x) * (STEP_SIZE / step_frac); + point.y += (direction.y / voxel_size.y) * (STEP_SIZE / step_frac); + point.z += (direction.z / voxel_size.z) * (STEP_SIZE / step_frac); if ((tidx == 0) && ((i % step_frac) == 0)){ streamline[i/step_frac] = point; } __syncwarp(WMASK); - tissue_class = check_point_d(tc_threshold, point, dimx, dimy, dimz, metric_map); + tissue_class = check_point_d(point, metric_map); if (tissue_class == ENDPOINT || tissue_class == INVALIDPOINT || @@ -889,25 +830,13 @@ template __global__ void genStreamlinesMergeBoot_k( - const REAL_T max_angle, - const REAL_T tc_threshold, - const REAL_T step_size, - const REAL_T relative_peak_thres, - const REAL_T min_separation_angle, - const long long rndSeed, const int rndOffset, const int nseed, const REAL3_T *__restrict__ seeds, - const int dimx, - const int dimy, - const int dimz, - const int dimt, const REAL_T *__restrict__ dataf, - const REAL_T *__restrict__ metric_map, - const int samplm_nr, + const cudaTextureObject_t *__restrict__ metric_map, const REAL3_T *__restrict__ sphere_vertices, const int2 *__restrict__ sphere_edges, - const int num_edges, /*BOOT specific params*/ const REAL_T min_signal, const int delta_nr, @@ -935,8 +864,8 @@ __global__ void genStreamlinesMergeBoot_k( curandStatePhilox4_32_10_t st; // const int gbid = blockIdx.y*gridDim.x + blockIdx.x; const size_t gid = 
blockIdx.x * blockDim.y * blockDim.x + blockDim.x * threadIdx.y + threadIdx.x; - //curand_init(rndSeed, slid+rndOffset, DIV_UP(hr_side, BDIM_X)*tidx, &st); // each thread uses DIV_UP(HR_SIDE/BDIM_X) - curand_init(rndSeed, gid+1, 0, &st); // each thread uses DIV_UP(hr_side/BDIM_X) + //curand_init(RNG_SEED, slid+rndOffset, DIV_UP(hr_side, BDIM_X)*tidx, &st); // each thread uses DIV_UP(HR_SIDE/BDIM_X) + curand_init(RNG_SEED, gid+1, 0, &st); // each thread uses DIV_UP(hr_side/BDIM_X) // elements of the same sequence if (slid >= nseed) { return; @@ -959,7 +888,7 @@ __global__ void genStreamlinesMergeBoot_k( int slineOff = slineOutOff[slid]; for(int i = 0; i < ndir; i++) { - REAL3_T first_step = shDir0[slid*samplm_nr + i]; + REAL3_T first_step = shDir0[slid*SAMPLM_NR + i]; REAL3_T *__restrict__ currSline = sline + slineOff*MAX_SLINE_LEN*2; @@ -977,20 +906,13 @@ __global__ void genStreamlinesMergeBoot_k( BDIM_Y, MODEL_T>( &st, - max_angle, - tc_threshold, - step_size, - relative_peak_thres, - min_separation_angle, seed, MAKE_REAL3(-first_step.x, -first_step.y, -first_step.z), MAKE_REAL3(1, 1, 1), - dimx, dimy, dimz, dimt, dataf, + dataf, metric_map, - samplm_nr, sphere_vertices, sphere_edges, - num_edges, min_signal, delta_nr, H, @@ -1016,20 +938,13 @@ __global__ void genStreamlinesMergeBoot_k( BDIM_Y, MODEL_T>( &st, - max_angle, - tc_threshold, - step_size, - relative_peak_thres, - min_separation_angle, seed, first_step, MAKE_REAL3(1, 1, 1), - dimx, dimy, dimz, dimt, dataf, + dataf, metric_map, - samplm_nr, sphere_vertices, sphere_edges, - num_edges, min_signal, delta_nr, H, diff --git a/cuslines/cuda_c/generate_streamlines_cuda.cu b/cuslines/cuda_c/generate_streamlines_cuda.cu index f5629e0..68c32d0 100644 --- a/cuslines/cuda_c/generate_streamlines_cuda.cu +++ b/cuslines/cuda_c/generate_streamlines_cuda.cu @@ -36,6 +36,7 @@ #include "utils.cu" #include "tracking_helpers.cu" #include "boot.cu" +#include "ptt_init.cu" #include "ptt.cu" #define MAX_NUM_DIR (128) @@ -52,18 
+53,10 @@ template __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, const REAL_T *__restrict__ pmf, - const REAL_T max_angle, - const REAL_T relative_peak_thres, - const REAL_T min_separation_angle, REAL3_T dir, - const int dimx, - const int dimy, - const int dimz, - const int dimt, const REAL3_T point, const REAL3_T *__restrict__ sphere_vertices, const int2 *__restrict__ sphere_edges, - const int num_edges, REAL3_T *__restrict__ dirs) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -71,14 +64,12 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32; const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1))); - const int n32dimt = ((dimt+31)/32)*32; - extern __shared__ REAL_T __sh[]; - REAL_T *__pmf_data_sh = __sh + tidy*n32dimt; + REAL_T *__pmf_data_sh = __sh + tidy*N32DIMT; // pmf = self.pmf_gen.get_pmf_c(&point[0], pmf) __syncwarp(WMASK); - const int rv = trilinear_interp_d(dimx, dimy, dimz, dimt, -1, pmf, point, __pmf_data_sh); + const int rv = trilinear_interp_d(pmf, point, __pmf_data_sh); __syncwarp(WMASK); if (rv != 0) { return 0; @@ -88,14 +79,14 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, // if pmf[i] > max_pmf: // max_pmf = pmf[i] // absolute_pmf_threshold = pmf_threshold * max_pmf - const REAL_T absolpmf_thresh = PMF_THRESHOLD_P * max_d(dimt, __pmf_data_sh, REAL_MIN); + const REAL_T absolpmf_thresh = PMF_THRESHOLD_P * max_d(DIMT, __pmf_data_sh, REAL_MIN); __syncwarp(WMASK); // for i in range(_len): // if pmf[i] < absolute_pmf_threshold: // pmf[i] = 0.0 #pragma unroll - for(int i = tidx; i < dimt; i += BDIM_X) { + for(int i = tidx; i < DIMT; i += BDIM_X) { if (__pmf_data_sh[i] < absolpmf_thresh) { __pmf_data_sh[i] = 0.0; } @@ -103,23 +94,19 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, __syncwarp(WMASK); if (IS_START) { - int *__shInd = reinterpret_cast(__sh + BDIM_Y*n32dimt) + 
tidy*n32dimt; + int *__shInd = reinterpret_cast(__sh + BDIM_Y*N32DIMT) + tidy*N32DIMT; return peak_directions_d(__pmf_data_sh, dirs, sphere_vertices, sphere_edges, - num_edges, - dimt, - __shInd, - relative_peak_thres, - min_separation_angle); + __shInd); } else { REAL_T __tmp; #ifdef DEBUG __syncwarp(WMASK); if (tidx == 0) { - printArray("__pmf_data_sh initial", 8, dimt, __pmf_data_sh); + printArray("__pmf_data_sh initial", 8, DIMT, __pmf_data_sh); printf("absolpmf_thresh %10.8f\n", absolpmf_thresh); printf("---> dir %10.8f, %10.8f, %10.8f\n", dir.x, dir.y, dir.z); printf("---> point %10.8f, %10.8f, %10.8f\n", point.x, point.y, point.z); @@ -132,7 +119,7 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, } __syncwarp(WMASK); if (tidx == 31) { - printArray("__pmf_data_sh initial l31", 8, dimt, __pmf_data_sh); + printArray("__pmf_data_sh initial l31", 8, DIMT, __pmf_data_sh); printf("absolpmf_thresh %10.8f l31\n", absolpmf_thresh); printf("---> dir %10.8f, %10.8f, %10.8f l31\n", dir.x, dir.y, dir.z); printf("---> point %10.8f, %10.8f, %10.8f l31\n", point.x, point.y, point.z); @@ -153,10 +140,10 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, // cos_sim = cos_sim * -1 // if cos_sim < self.cos_similarity: // pmf[i] = 0 - const REAL_T cos_similarity = COS(max_angle); + const REAL_T cos_similarity = COS(MAX_ANGLE); #pragma unroll - for(int i = tidx; i < dimt; i += BDIM_X) { + for(int i = tidx; i < DIMT; i += BDIM_X) { const REAL_T dot = dir.x*sphere_vertices[i].x+ dir.y*sphere_vertices[i].y+ dir.z*sphere_vertices[i].z; @@ -170,18 +157,18 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, #ifdef DEBUG __syncwarp(WMASK); if (tidx == 0) { - printArray("__pmf_data_sh after filtering", 8, dimt, __pmf_data_sh); + printArray("__pmf_data_sh after filtering", 8, DIMT, __pmf_data_sh); } __syncwarp(WMASK); #endif // cumsum(pmf, pmf, _len) - prefix_sum_sh_d(__pmf_data_sh, dimt); + prefix_sum_sh_d(__pmf_data_sh, DIMT); 
#ifdef DEBUG __syncwarp(WMASK); if (tidx == 0) { - printArray("__pmf_data_sh after cumsum", 8, dimt, __pmf_data_sh); + printArray("__pmf_data_sh after cumsum", 8, DIMT, __pmf_data_sh); } __syncwarp(WMASK); #endif @@ -189,7 +176,7 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, // last_cdf = pmf[_len - 1] // if last_cdf == 0: // return 1 - REAL_T last_cdf = __pmf_data_sh[dimt - 1]; + REAL_T last_cdf = __pmf_data_sh[DIMT - 1]; if (last_cdf == 0) { return 0; } @@ -202,7 +189,7 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, // Both these implementations work #if 1 int low = 0; - int high = dimt - 1; + int high = DIMT - 1; while ((high - low) >= BDIM_X) { const int mid = (low + high) / 2; if (__pmf_data_sh[mid] < selected_cdf) { @@ -215,10 +202,10 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, const int __msk = __ballot_sync(WMASK, __ballot); const int indProb = low + __ffs(__msk) - 1; #else - int indProb = dimt - 1; - for (int ii = 0; ii < dimt; ii+=BDIM_X) { + int indProb = DIMT - 1; + for (int ii = 0; ii < DIMT; ii+=BDIM_X) { int __is_greater = 0; - if (ii+tidx < dimt) { + if (ii+tidx < DIMT) { __is_greater = selected_cdf < __pmf_data_sh[ii+tidx]; } const int __msk = __ballot_sync(WMASK, __is_greater); @@ -234,7 +221,7 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, if (tidx == 0) { printf("last_cdf %10.8f\n", last_cdf); printf("selected_cdf %10.8f\n", selected_cdf); - printf("indProb %i out of %i\n", indProb, dimt); + printf("indProb %i out of %i\n", indProb, DIMT); } __syncwarp(WMASK); #endif @@ -269,19 +256,19 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, if (tidx == 0) { printf("last_cdf %10.8f\n", last_cdf); printf("selected_cdf %10.8f\n", selected_cdf); - printf("indProb %i out of %i\n", indProb, dimt); + printf("indProb %i out of %i\n", indProb, DIMT); } __syncwarp(WMASK); if (tidx == 15) { printf("last_cdf %10.8f l15\n", last_cdf); 
printf("selected_cdf %10.8f l15\n", selected_cdf); - printf("indProb %i out of %i l15\n", indProb, dimt); + printf("indProb %i out of %i l15\n", indProb, DIMT); } __syncwarp(WMASK); if (tidx == 31) { printf("last_cdf %10.8f l31\n", last_cdf); printf("selected_cdf %10.8f l31\n", selected_cdf); - printf("indProb %i out of %i l31\n", indProb, dimt); + printf("indProb %i out of %i l31\n", indProb, DIMT); } __syncwarp(WMASK); #endif @@ -292,28 +279,18 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, template __device__ int tracker_d(curandStatePhilox4_32_10_t *st, - const REAL_T max_angle, - const REAL_T tc_threshold, - const REAL_T step_size, - const REAL_T relative_peak_thres, - const REAL_T min_separation_angle, REAL3_T seed, REAL3_T first_step, REAL_T* ptt_frame, REAL3_T voxel_size, - const int dimx, - const int dimy, - const int dimz, - const int dimt, - const REAL_T *__restrict__ dataf, - const REAL_T *__restrict__ metric_map, - const int samplm_nr, - const REAL3_T *__restrict__ sphere_vertices, + DATA_T dataf, + const cudaTextureObject_t *__restrict__ metric_map, + DATA_T sphere_vertices, const int2 *__restrict__ sphere_edges, - const int num_edges, int *__restrict__ nsteps, REAL3_T *__restrict__ streamline) { @@ -350,15 +327,10 @@ __device__ int tracker_d(curandStatePhilox4_32_10_t *st, 0>( st, dataf, - max_angle, - relative_peak_thres, - min_separation_angle, direction, - dimx, dimy, dimz, dimt, point, - sphere_vertices, + (REAL3_T *__restrict__) sphere_vertices, sphere_edges, - num_edges, __sh_new_dir + tidy); } else if constexpr (MODEL_T == PTT) { ndir = get_direction_ptt_d( st, dataf, - max_angle, - step_size, direction, ptt_frame, - dimx, dimy, dimz, dimt, point, sphere_vertices, __sh_new_dir + tidy); @@ -389,9 +358,9 @@ __device__ int tracker_d(curandStatePhilox4_32_10_t *st, //return; #endif - point.x += (direction.x / voxel_size.x) * (step_size / step_frac); - point.y += (direction.y / voxel_size.y) * (step_size / step_frac); - 
point.z += (direction.z / voxel_size.z) * (step_size / step_frac); + point.x += (direction.x / voxel_size.x) * (STEP_SIZE / step_frac); + point.y += (direction.y / voxel_size.y) * (STEP_SIZE / step_frac); + point.z += (direction.z / voxel_size.z) * (STEP_SIZE / step_frac); if ((tidx == 0) && ((i % step_frac) == 0)){ streamline[i/step_frac] = point; @@ -403,12 +372,12 @@ __device__ int tracker_d(curandStatePhilox4_32_10_t *st, } __syncwarp(WMASK); - tissue_class = check_point_d(tc_threshold, point, dimx, dimy, dimz, metric_map); + tissue_class = check_point_d(point, metric_map); #if 0 __syncwarp(WMASK); if (tidx == 0) { - printf("step_size %f\n", step_size); + printf("step_size %f\n", STEP_SIZE); printf("direction %f, %f, %f\n", direction.x, direction.y, direction.z); printf("direction addr read %p, slid %i\n", __shDir, blockIdx.x*blockDim.y+threadIdx.y); printf("voxel_size %f, %f, %f\n", voxel_size.x, voxel_size.y, voxel_size.z); @@ -417,7 +386,7 @@ __device__ int tracker_d(curandStatePhilox4_32_10_t *st, } __syncwarp(WMASK); if (tidx == 15) { - printf("step_size %f l15\n", step_size); + printf("step_size %f l15\n", STEP_SIZE); printf("direction %f, %f, %f l15\n", direction.x, direction.y, direction.z); printf("direction addr read %p, slid %i l15\n", __shDir, blockIdx.x*blockDim.y+threadIdx.y); printf("voxel_size %f, %f, %f l15\n", voxel_size.x, voxel_size.y, voxel_size.z); @@ -426,7 +395,7 @@ __device__ int tracker_d(curandStatePhilox4_32_10_t *st, } __syncwarp(WMASK); if (tidx == 31) { - printf("step_size %f l31\n", step_size); + printf("step_size %f l31\n", STEP_SIZE); printf("direction %f, %f, %f l31\n", direction.x, direction.y, direction.z); printf("direction addr read %p, slid %i l31\n", __shDir, blockIdx.x*blockDim.y+threadIdx.y); printf("voxel_size %f, %f, %f l31\n", voxel_size.x, voxel_size.y, voxel_size.z); @@ -457,20 +426,11 @@ template -__global__ void getNumStreamlinesProb_k(const REAL_T max_angle, - const REAL_T relative_peak_thres, - const REAL_T 
min_separation_angle, - const long long rndSeed, - const int nseed, +__global__ void getNumStreamlinesProb_k(const int nseed, const REAL3_T *__restrict__ seeds, - const int dimx, - const int dimy, - const int dimz, - const int dimt, const REAL_T *__restrict__ dataf, const REAL3_T *__restrict__ sphere_vertices, const int2 *__restrict__ sphere_edges, - const int num_edges, REAL3_T *__restrict__ shDir0, int *slineOutOff) { @@ -482,24 +442,19 @@ __global__ void getNumStreamlinesProb_k(const REAL_T max_angle, return; } - REAL3_T *__restrict__ __shDir = shDir0+slid*dimt; + REAL3_T *__restrict__ __shDir = shDir0+slid*DIMT; curandStatePhilox4_32_10_t st; - curand_init(rndSeed, gid, 0, &st); + curand_init(RNG_SEED, gid, 0, &st); int ndir = get_direction_prob_d( &st, dataf, - max_angle, - relative_peak_thres, - min_separation_angle, MAKE_REAL3(0,0,0), - dimx, dimy, dimz, dimt, seeds[slid], sphere_vertices, sphere_edges, - num_edges, __shDir); if (tidx == 0) { slineOutOff[slid] = ndir; @@ -511,28 +466,17 @@ __global__ void getNumStreamlinesProb_k(const REAL_T max_angle, template __global__ void genStreamlinesMergeProb_k( - const REAL_T max_angle, - const REAL_T tc_threshold, - const REAL_T step_size, - const REAL_T relative_peak_thres, - const REAL_T min_separation_angle, - const long long rndSeed, const int rndOffset, const int nseed, const REAL3_T *__restrict__ seeds, - const int dimx, - const int dimy, - const int dimz, - const int dimt, - const REAL_T *__restrict__ dataf, - const REAL_T *__restrict__ metric_map, - const int samplm_nr, - const REAL3_T *__restrict__ sphere_vertices, + DATA_T dataf, + const cudaTextureObject_t *__restrict__ metric_map, + DATA_T sphere_vertices, const int2 *__restrict__ sphere_edges, - const int num_edges, const int *__restrict__ slineOutOff, REAL3_T *__restrict__ shDir0, int *__restrict__ slineSeed, @@ -549,13 +493,13 @@ __global__ void genStreamlinesMergeProb_k( __shared__ REAL_T frame_sh[((MODEL_T == PTT) ? 
BDIM_Y*18 : 1)]; // Only used by PTT, TODO: way to remove this in other cases REAL_T* __ptt_frame = frame_sh + tidy*18; - // const int hr_side = dimt-1; + // const int hr_side = DIMT-1; curandStatePhilox4_32_10_t st; // const int gbid = blockIdx.y*gridDim.x + blockIdx.x; const size_t gid = blockIdx.x * blockDim.y * blockDim.x + blockDim.x * threadIdx.y + threadIdx.x; - //curand_init(rndSeed, slid+rndOffset, DIV_UP(hr_side, BDIM_X)*tidx, &st); // each thread uses DIV_UP(HR_SIDE/BDIM_X) - curand_init(rndSeed, gid+1, 0, &st); // each thread uses DIV_UP(hr_side/BDIM_X) + //curand_init(RNG_SEED, slid+rndOffset, DIV_UP(hr_side, BDIM_X)*tidx, &st); // each thread uses DIV_UP(HR_SIDE/BDIM_X) + curand_init(RNG_SEED, gid+1, 0, &st); // each thread uses DIV_UP(hr_side/BDIM_X) // elements of the same sequence if (slid >= nseed) { return; @@ -578,7 +522,7 @@ __global__ void genStreamlinesMergeProb_k( int slineOff = slineOutOff[slid]; for(int i = 0; i < ndir; i++) { - REAL3_T first_step = shDir0[slid*samplm_nr + i]; + REAL3_T first_step = shDir0[slid*SAMPLM_NR + i]; REAL3_T *__restrict__ currSline = sline + slineOff*MAX_SLINE_LEN*2; @@ -594,13 +538,10 @@ __global__ void genStreamlinesMergeProb_k( if (MODEL_T == PTT) { if (!init_frame_ptt_d( &st, - dataf, - max_angle, - step_size, + (cudaTextureObject_t*) dataf, first_step, - dimx, dimy, dimz, dimt, seed, - sphere_vertices, + (cudaTextureObject_t*) sphere_vertices, __ptt_frame )) { // this fails rarely if (tidx == 0) { @@ -617,22 +558,16 @@ __global__ void genStreamlinesMergeProb_k( int stepsB; const int tissue_classB = tracker_d(&st, - max_angle, - tc_threshold, - step_size, - relative_peak_thres, - min_separation_angle, + MODEL_T, + DATA_T>(&st, seed, MAKE_REAL3(-first_step.x, -first_step.y, -first_step.z), __ptt_frame, MAKE_REAL3(1, 1, 1), - dimx, dimy, dimz, dimt, dataf, + dataf, metric_map, - samplm_nr, sphere_vertices, sphere_edges, - num_edges, &stepsB, currSline); //if (tidx == 0) { @@ -651,22 +586,16 @@ __global__ void 
genStreamlinesMergeProb_k( int stepsF; const int tissue_classF = tracker_d(&st, - max_angle, - tc_threshold, - step_size, - relative_peak_thres, - min_separation_angle, + MODEL_T, + DATA_T>(&st, seed, first_step, __ptt_frame + 9, MAKE_REAL3(1, 1, 1), - dimx, dimy, dimz, dimt, dataf, + dataf, metric_map, - samplm_nr, sphere_vertices, sphere_edges, - num_edges, &stepsF, currSline + stepsB-1); if (tidx == 0) { diff --git a/cuslines/cuda_c/globals.h b/cuslines/cuda_c/globals.h index 71bcd73..505f5cc 100644 --- a/cuslines/cuda_c/globals.h +++ b/cuslines/cuda_c/globals.h @@ -29,8 +29,6 @@ #ifndef __GLOBALS_H__ #define __GLOBALS_H__ -#define REAL_SIZE 4 - #if REAL_SIZE == 4 #define REAL float @@ -68,14 +66,6 @@ #define ACOS acos #endif -// TODO: half this in when WMGMI seeding -#define MAX_SLINE_LEN (501) -#define PMF_THRESHOLD_P ((REAL)0.05) - -#define THR_X_BL (64) -#define THR_X_SL (32) - -#define MAX_SLINES_PER_SEED (10) #define MIN(x,y) (((x)<(y))?(x):(y)) #define MAX(x,y) (((x)>(y))?(x):(y)) @@ -83,14 +73,6 @@ #define DIV_UP(a,b) (((a)+((b)-1))/(b)) -#define EXCESS_ALLOC_FACT 2 - -#define NORM_EPS ((REAL)1e-8) - -#if 0 - #define DEBUG -#endif - enum ModelType { OPDT = 0, CSA = 1, diff --git a/cuslines/cuda_c/ptt.cu b/cuslines/cuda_c/ptt.cu index 5684272..b3caa21 100644 --- a/cuslines/cuda_c/ptt.cu +++ b/cuslines/cuda_c/ptt.cu @@ -1,82 +1,46 @@ -template -__device__ __forceinline__ void norm3_d(REAL_T *num, int fail_ind) { - const REAL_T scale = SQRT(num[0] * num[0] + num[1] * num[1] + num[2] * num[2]); - - if (scale > NORM_EPS) { - num[0] /= scale; - num[1] /= scale; - num[2] /= scale; +template +__device__ __forceinline__ void norm3_d(float *num) { + const float inv_scale = rnorm3df(num[0], num[1], num[2]); + + if (isfinite(inv_scale) && inv_scale > 0.0f && inv_scale < PTT_INV_NORM_EPS) { + num[0] *= inv_scale; + num[1] *= inv_scale; + num[2] *= inv_scale; } else { num[0] = num[1] = num[2] = 0; - num[fail_ind] = 1.0; // this can happen randomly during propogation, 
though is exceedingly rare + num[FAIL_IND] = 1.0; // this can happen randomly during propogation, though is exceedingly rare } } -template -__device__ __forceinline__ void crossnorm3_d(REAL_T *dest, const REAL_T *src1, const REAL_T *src2, int fail_ind) { +template +__device__ __forceinline__ void crossnorm3_d(float *dest, const float *src1, const float *src2) { dest[0] = src1[1] * src2[2] - src1[2] * src2[1]; dest[1] = src1[2] * src2[0] - src1[0] * src2[2]; dest[2] = src1[0] * src2[1] - src1[1] * src2[0]; - norm3_d(dest, fail_ind); + norm3_d(dest); } -template -__device__ REAL_T interp4_d(const REAL3_T pos, const REAL_T* frame, const REAL_T *__restrict__ pmf, - const int dimx, const int dimy, const int dimz, const int dimt, - const REAL3_T *__restrict__ odf_sphere_vertices) { - const int tidx = threadIdx.x; - - const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32; - const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1))); - - int closest_odf_idx = 0; - REAL_T __max_cos = REAL_T(0); - - #pragma unroll - for (int ii = tidx; ii < dimt; ii+= BDIM_X) { // TODO: I need to think about better ways of parallelizing this - REAL_T cos_sim = FABS( - odf_sphere_vertices[ii].x * frame[0] \ - + odf_sphere_vertices[ii].y * frame[1] \ - + odf_sphere_vertices[ii].z * frame[2]); - if (cos_sim > __max_cos) { - __max_cos = cos_sim; - closest_odf_idx = ii; - } - } - __syncwarp(WMASK); - - #pragma unroll - for(int i = BDIM_X/2; i; i /= 2) { - const REAL_T __tmp = __shfl_xor_sync(WMASK, __max_cos, i, BDIM_X); - const int __tmp_idx = __shfl_xor_sync(WMASK, closest_odf_idx, i, BDIM_X); - if (__tmp > __max_cos || - (__tmp == __max_cos && __tmp_idx < closest_odf_idx)) { - __max_cos = __tmp; - closest_odf_idx = __tmp_idx; - } - } - __syncwarp(WMASK); - -#if 0 - if (closest_odf_idx >= dimt || closest_odf_idx < 0) { - printf("Error: closest_odf_idx out of bounds: %d (dimt: %d)\n", closest_odf_idx, dimt); - } -#endif - - // TODO: maybe this should be texture memory, I am not 
so sure - const int rv = trilinear_interp_d(dimx, dimy, dimz, dimt, closest_odf_idx, pmf, pos, &__max_cos); - - if (rv != 0) { - return 0; // No support - } else { - return __max_cos; - } +__device__ float interp4_d(const float3 pos, const float* frame, + const cudaTextureObject_t *__restrict__ pmf, + const cudaTextureObject_t *__restrict__ sphere_vertices_lut) { + float3 uvw = { // Map from -1,1 to 0,1 for texture lookup + fmaf(frame[0], 0.5f, 0.5f), + fmaf(frame[1], 0.5f, 0.5f), + fmaf(frame[2], 0.5f, 0.5f) + }; + const int odf_idx = static_cast(tex3D(*sphere_vertices_lut, uvw.z, uvw.y, uvw.x)); + + const int grid_col = odf_idx & WIDTH_MASK; + const int grid_row = odf_idx >> LOG2_WIDTH; + + const float x_query = (float)(grid_col * DIMX) + pos.x; + const float y_query = (float)(grid_row * DIMY) + pos.y; + return tex3D(*pmf, x_query, y_query, pos.z); } -template -__device__ void prepare_propagator_d(REAL_T k1, REAL_T k2, REAL_T arclength, - REAL_T *propagator) { +__device__ void prepare_propagator_d(float k1, float k2, float arclength, + float *propagator) { if ((FABS(k1) < K_SMALL) && (FABS(k2) < K_SMALL)) { propagator[0] = arclength; propagator[1] = 0; @@ -94,10 +58,10 @@ __device__ void prepare_propagator_d(REAL_T k1, REAL_T k2, REAL_T arclength, if (FABS(k2) < K_SMALL) { k2 = K_SMALL; } - const REAL_T k = SQRT(k1*k1+k2*k2); - const REAL_T sinkt = SIN(k*arclength); - const REAL_T coskt = COS(k*arclength); - const REAL_T kk = 1/(k*k); + const float k = SQRT(k1*k1+k2*k2); + const float sinkt = SIN(k*arclength); + const float coskt = COS(k*arclength); + const float kk = 1/(k*k); propagator[0] = sinkt/k; propagator[1] = k1*(1-coskt)*kk; @@ -111,26 +75,25 @@ __device__ void prepare_propagator_d(REAL_T k1, REAL_T k2, REAL_T arclength, } } -template -__device__ void random_normal(curandStatePhilox4_32_10_t *st, REAL_T* probing_frame) { +__device__ void random_normal(curandStatePhilox4_32_10_t *st, float* probing_frame) { probing_frame[3] = curand_normal(st); 
probing_frame[4] = curand_normal(st); probing_frame[5] = curand_normal(st); - REAL_T dot = probing_frame[3]*probing_frame[0] + float dot = probing_frame[3]*probing_frame[0] + probing_frame[4]*probing_frame[1] + probing_frame[5]*probing_frame[2]; probing_frame[3] -= dot*probing_frame[0]; probing_frame[4] -= dot*probing_frame[1]; probing_frame[5] -= dot*probing_frame[2]; - REAL_T n2 = probing_frame[3]*probing_frame[3] + float n2 = probing_frame[3]*probing_frame[3] + probing_frame[4]*probing_frame[4] + probing_frame[5]*probing_frame[5]; - if (n2 < NORM_EPS) { - REAL_T abs_x = FABS(probing_frame[0]); - REAL_T abs_y = FABS(probing_frame[1]); - REAL_T abs_z = FABS(probing_frame[2]); + if (n2 < PTT_NORM_EPS) { + float abs_x = FABS(probing_frame[0]); + float abs_y = FABS(probing_frame[1]); + float abs_z = FABS(probing_frame[2]); if (abs_x <= abs_y && abs_x <= abs_z) { probing_frame[3] = 0.0; @@ -150,19 +113,19 @@ __device__ void random_normal(curandStatePhilox4_32_10_t *st, REAL_T* probing_fr } } -template -__device__ void get_probing_frame_d(const REAL_T* frame, curandStatePhilox4_32_10_t *st, REAL_T* probing_frame) { +template +__device__ void get_probing_frame_d(const float* frame, curandStatePhilox4_32_10_t *st, float* probing_frame) { if (IS_INIT) { for (int ii = 0; ii < 3; ii++) { // tangent probing_frame[ii] = frame[ii]; } - norm3_d(probing_frame, 0); + norm3_d<0>(probing_frame); random_normal(st, probing_frame); - norm3_d(probing_frame + 3, 1); // norm + norm3_d<1>(probing_frame + 3); // norm // calculate binorm - crossnorm3_d(probing_frame + 2*3, probing_frame, probing_frame + 3, 2); // binorm + crossnorm3_d<2>(probing_frame + 2*3, probing_frame, probing_frame + 3); // binorm } else { for (int ii = 0; ii < 9; ii++) { probing_frame[ii] = frame[ii]; @@ -170,9 +133,8 @@ __device__ void get_probing_frame_d(const REAL_T* frame, curandStatePhilox4_32_1 } } -template -__device__ void propagate_frame_d(REAL_T* propagator, REAL_T* frame, REAL_T* direc) { - REAL_T __tmp[3]; 
+__device__ void propagate_frame_d(float* propagator, float* frame, float* direc) { + float __tmp[3]; for (int ii = 0; ii < 3; ii++) { direc[ii] = propagator[0]*frame[ii] + propagator[1]*frame[3+ii] + propagator[2]*frame[6+ii]; @@ -180,61 +142,46 @@ __device__ void propagate_frame_d(REAL_T* propagator, REAL_T* frame, REAL_T* dir frame[2*3 + ii] = propagator[6]*frame[ii] + propagator[7]*frame[3+ii] + propagator[8]*frame[6+ii]; } - norm3_d(__tmp, 0); // normalize tangent - crossnorm3_d(frame + 3, frame + 2*3, __tmp, 1); // calc normal - crossnorm3_d(frame + 2*3, __tmp, frame + 3, 2); // calculate binorm from tangent, norm + norm3_d<0>(__tmp); // normalize tangent + crossnorm3_d<1>(frame + 3, frame + 2*3, __tmp); // calc normal + crossnorm3_d<2>(frame + 2*3, __tmp, frame + 3); // calculate binorm from tangent, norm for (int ii = 0; ii < 3; ii++) { frame[ii] = __tmp[ii]; } } -template -__device__ REAL_T calculate_data_support_d(REAL_T support, - const REAL3_T pos, const REAL_T *__restrict__ pmf, - const int dimx, const int dimy, const int dimz, const int dimt, - const REAL_T probe_step_size, - const REAL_T absolpmf_thresh, - const REAL3_T *__restrict__ odf_sphere_vertices, - REAL_T* probing_prop_sh, - REAL_T* direc_sh, - REAL3_T* probing_pos_sh, - REAL_T* k1_sh, REAL_T* k2_sh, - REAL_T* probing_frame_sh) { - const int tidx = threadIdx.x; - - const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32; - const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1))); - - if (tidx == 0) { - prepare_propagator_d( - *k1_sh, *k2_sh, - probe_step_size, probing_prop_sh); - probing_pos_sh->x = pos.x; - probing_pos_sh->y = pos.y; - probing_pos_sh->z = pos.z; - } - __syncwarp(WMASK); +__device__ float calculate_data_support_d(float support, + const float3 pos, const cudaTextureObject_t *__restrict__ pmf, + const cudaTextureObject_t *__restrict__ sphere_vertices_lut, + float k1, float k2, + float* probing_frame) { + + float probing_prop[9]; + float direc[3]; + float3 
probing_pos; + prepare_propagator_d( + k1, k2, + PROBE_STEP_SIZE, probing_prop); + probing_pos.x = pos.x; + probing_pos.y = pos.y; + probing_pos.z = pos.z; for (int ii = 0; ii < PROBE_QUALITY; ii++) { // we spend about 2/3 of our time in this loop when doing PTT - if (tidx == 0) { - propagate_frame_d( - probing_prop_sh, - probing_frame_sh, - direc_sh); - - probing_pos_sh->x += direc_sh[0]; - probing_pos_sh->y += direc_sh[1]; - probing_pos_sh->z += direc_sh[2]; - } - __syncwarp(WMASK); + propagate_frame_d( + probing_prop, + probing_frame, + direc); + + probing_pos.x += direc[0]; + probing_pos.y += direc[1]; + probing_pos.z += direc[2]; - const REAL_T fod_amp = interp4_d( // This is the most expensive call - *probing_pos_sh, probing_frame_sh, pmf, - dimx, dimy, dimz, dimt, - odf_sphere_vertices); + const float fod_amp = interp4_d( // This is the most expensive call + probing_pos, probing_frame, pmf, + sphere_vertices_lut); - if (!ALLOW_WEAK_LINK && (fod_amp < absolpmf_thresh)) { + if (!ALLOW_WEAK_LINK && (fod_amp < PMF_THRESHOLD_P)) { return 0; } support += fod_amp; @@ -244,20 +191,15 @@ __device__ REAL_T calculate_data_support_d(REAL_T support, template + bool IS_INIT> __device__ int get_direction_ptt_d( curandStatePhilox4_32_10_t *st, - const REAL_T *__restrict__ pmf, - const REAL_T max_angle, - const REAL_T step_size, - REAL3_T dir, - REAL_T *__frame_sh, - const int dimx, const int dimy, const int dimz, const int dimt, - REAL3_T pos, - const REAL3_T *__restrict__ odf_sphere_vertices, - REAL3_T *__restrict__ dirs) { + const cudaTextureObject_t *__restrict__ pmf, + float3 dir, + float *__frame_sh, + float3 pos, + const cudaTextureObject_t *__restrict__ sphere_vertices_lut, + float3 *__restrict__ dirs) { // Aydogan DB, Shi Y. Parallel Transport Tractography. IEEE Trans // Med Imaging. 2021 Feb;40(2):635-647. doi: 10.1109/TMI.2020.3034038. // Epub 2021 Feb 2. PMID: 33104507; PMCID: PMC7931442. 
@@ -272,39 +214,13 @@ __device__ int get_direction_ptt_d( const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32; const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1))); - __shared__ REAL_T face_cdf_sh[BDIM_Y*DISC_FACE_CNT]; - __shared__ REAL_T vert_pdf_sh[BDIM_Y*DISC_VERT_CNT]; + __shared__ float face_cdf_sh[BDIM_Y*DISC_FACE_CNT]; + __shared__ float vert_pdf_sh[BDIM_Y*DISC_VERT_CNT]; - __shared__ REAL_T probing_frame_sh[BDIM_Y*9]; - __shared__ REAL_T k1_probe_sh[BDIM_Y]; - __shared__ REAL_T k2_probe_sh[BDIM_Y]; + float *__face_cdf_sh = face_cdf_sh + tidy*DISC_FACE_CNT; + float *__vert_pdf_sh = vert_pdf_sh + tidy*DISC_VERT_CNT; - __shared__ REAL_T probing_prop_sh[BDIM_Y*9]; - __shared__ REAL_T direc_sh[BDIM_Y*3]; - __shared__ REAL3_T probing_pos_sh[BDIM_Y]; - - REAL_T *__face_cdf_sh = face_cdf_sh + tidy*DISC_FACE_CNT; - REAL_T *__vert_pdf_sh = vert_pdf_sh + tidy*DISC_VERT_CNT; - - REAL_T *__probing_frame_sh = probing_frame_sh + tidy*9; - REAL_T *__k1_probe_sh = k1_probe_sh + tidy; - REAL_T *__k2_probe_sh = k2_probe_sh + tidy; - - REAL_T *__probing_prop_sh = probing_prop_sh + tidy*9; - REAL_T *__direc_sh = direc_sh + tidy*3; - REAL3_T *__probing_pos_sh = probing_pos_sh + tidy; - - const REAL_T probe_step_size = ((step_size / PROBE_FRAC) / (PROBE_QUALITY - 1)); - const REAL_T max_curvature = 2.0 * SIN(max_angle / 2.0) / (step_size / PROBE_FRAC); // This seems to work well - const REAL_T absolpmf_thresh = PMF_THRESHOLD_P * max_d(dimt, pmf, REAL_MIN); - -#if 0 - printf("absolpmf_thresh: %f, max_curvature: %f, probe_step_size: %f\n", absolpmf_thresh, max_curvature, probe_step_size); - printf("max_angle: %f\n", max_angle); - printf("step_size: %f\n", step_size); -#endif - - REAL_T __tmp; + float __tmp; __syncwarp(WMASK); if (IS_INIT) { @@ -315,31 +231,27 @@ __device__ int get_direction_ptt_d( } } - const REAL_T first_val = interp4_d( + const float first_val = interp4_d( pos, __frame_sh, pmf, - dimx, dimy, dimz, dimt, - odf_sphere_vertices); + 
sphere_vertices_lut); __syncwarp(WMASK); // Calculate __vert_pdf_sh - bool support_found = false; - for (int ii = 0; ii < DISC_VERT_CNT; ii++) { - if (tidx == 0) { - *__k1_probe_sh = DISC_VERT[ii*2] * max_curvature; - *__k2_probe_sh = DISC_VERT[ii*2+1] * max_curvature; - get_probing_frame_d(__frame_sh, st, __probing_frame_sh); - } - __syncwarp(WMASK); + float probing_frame[9]; + float k1_probe, k2_probe; + bool support_found = 0; + for (int ii = tidx; ii < DISC_VERT_CNT; ii += BDIM_X) { + k1_probe = DISC_VERT[ii*2] * MAX_CURVATURE; + k2_probe = DISC_VERT[ii*2+1] * MAX_CURVATURE; - const REAL_T this_support = calculate_data_support_d( + get_probing_frame_d(__frame_sh, st, probing_frame); + + const float this_support = calculate_data_support_d( first_val, - pos, pmf, dimx, dimy, dimz, dimt, - probe_step_size, - absolpmf_thresh, - odf_sphere_vertices, - __probing_prop_sh, __direc_sh, __probing_pos_sh, - __k1_probe_sh, __k2_probe_sh, - __probing_frame_sh); + pos, pmf, + sphere_vertices_lut, + k1_probe, k2_probe, + probing_frame); #if 0 if (threadIdx.y == 1 && ii == 0) { @@ -347,18 +259,15 @@ __device__ int get_direction_ptt_d( } #endif - if (this_support < PROBE_QUALITY * absolpmf_thresh) { - if (tidx == 0) { - __vert_pdf_sh[ii] = 0; - } + if (this_support < PROBE_QUALITY * PMF_THRESHOLD_P) { + __vert_pdf_sh[ii] = 0; } else { - if (tidx == 0) { - __vert_pdf_sh[ii] = this_support; - } + __vert_pdf_sh[ii] = this_support; support_found = 1; } } - if (support_found == 0) { + const int __msk = __ballot_sync(WMASK, support_found); + if (__msk == 0) { return 0; } @@ -380,7 +289,7 @@ __device__ int get_direction_ptt_d( for (int ii = tidx; ii < DISC_FACE_CNT; ii+=BDIM_X) { bool all_verts_valid = 1; for (int jj = 0; jj < 3; jj++) { - REAL_T vert_val = __vert_pdf_sh[DISC_FACE[ii*3 + jj]]; + float vert_val = __vert_pdf_sh[DISC_FACE[ii*3 + jj]]; if (vert_val == 0) { all_verts_valid = IS_INIT; // On init, even go with faces that are not fully supported } @@ -402,7 +311,7 @@ 
__device__ int get_direction_ptt_d( // Prefix sum __face_cdf_sh and return 0 if all 0 prefix_sum_sh_d(__face_cdf_sh, DISC_FACE_CNT); - REAL_T last_cdf = __face_cdf_sh[DISC_FACE_CNT - 1]; + float last_cdf = __face_cdf_sh[DISC_FACE_CNT - 1]; if (last_cdf == 0) { return 0; @@ -417,109 +326,100 @@ __device__ int get_direction_ptt_d( #endif // Sample random valid faces randomly - for (int ii = 0; ii < TRIES_PER_REJECTION_SAMPLING; ii++) { - if (tidx == 0) { - REAL_T r1 = curand_uniform(st); - REAL_T r2 = curand_uniform(st); - if (r1 + r2 > 1) { - r1 = 1 - r1; - r2 = 1 - r2; - } - - __tmp = curand_uniform(st) * last_cdf; - int jj; - for (jj = 0; jj < DISC_FACE_CNT; jj++) { // TODO: parallelize this - if (__face_cdf_sh[jj] >= __tmp) - break; - } - - const REAL_T vx0 = max_curvature * DISC_VERT[DISC_FACE[jj*3]*2]; - const REAL_T vx1 = max_curvature * DISC_VERT[DISC_FACE[jj*3+1]*2]; - const REAL_T vx2 = max_curvature * DISC_VERT[DISC_FACE[jj*3+2]*2]; + float r1, r2; + for (int ii = 0; ii < TRIES_PER_REJECTION_SAMPLING / BDIM_X; ii++) { + r1 = curand_uniform(st); + r2 = curand_uniform(st); + if (r1 + r2 > 1) { + r1 = 1 - r1; + r2 = 1 - r2; + } + + __tmp = curand_uniform(st) * last_cdf; + int jj; + for (jj = 0; jj < DISC_FACE_CNT - 1; jj++) { + if (__face_cdf_sh[jj] >= __tmp) + break; + } + + const float vx0 = MAX_CURVATURE * DISC_VERT[DISC_FACE[jj*3]*2]; + const float vx1 = MAX_CURVATURE * DISC_VERT[DISC_FACE[jj*3+1]*2]; + const float vx2 = MAX_CURVATURE * DISC_VERT[DISC_FACE[jj*3+2]*2]; + + const float vy0 = MAX_CURVATURE * DISC_VERT[DISC_FACE[jj*3]*2 + 1]; + const float vy1 = MAX_CURVATURE * DISC_VERT[DISC_FACE[jj*3+1]*2 + 1]; + const float vy2 = MAX_CURVATURE * DISC_VERT[DISC_FACE[jj*3+2]*2 + 1]; + + k1_probe = vx0 + r1 * (vx1 - vx0) + r2 * (vx2 - vx0); + k2_probe = vy0 + r1 * (vy1 - vy0) + r2 * (vy2 - vy0); + + get_probing_frame_d(__frame_sh, st, probing_frame); + + const float this_support = calculate_data_support_d(first_val, + pos, pmf, + sphere_vertices_lut, + 
k1_probe, k2_probe, + probing_frame); - const REAL_T vy0 = max_curvature * DISC_VERT[DISC_FACE[jj*3]*2 + 1]; - const REAL_T vy1 = max_curvature * DISC_VERT[DISC_FACE[jj*3+1]*2 + 1]; - const REAL_T vy2 = max_curvature * DISC_VERT[DISC_FACE[jj*3+2]*2 + 1]; - *__k1_probe_sh = vx0 + r1 * (vx1 - vx0) + r2 * (vx2 - vx0); - *__k2_probe_sh = vy0 + r1 * (vy1 - vy0) + r2 * (vy2 - vy0); - get_probing_frame_d(__frame_sh, st, __probing_frame_sh); - } - __syncwarp(WMASK); - - const REAL_T this_support = calculate_data_support_d( - first_val, - pos, pmf, dimx, dimy, dimz, dimt, - probe_step_size, - absolpmf_thresh, - odf_sphere_vertices, - __probing_prop_sh, __direc_sh, __probing_pos_sh, - __k1_probe_sh, __k2_probe_sh, - __probing_frame_sh); __syncwarp(WMASK); - if (this_support < PROBE_QUALITY * absolpmf_thresh) { - continue; + int winning_lane = -1; // -1 indicates nobody won + int __msk = __ballot_sync(WMASK, this_support >= PROBE_QUALITY * PMF_THRESHOLD_P); + if (__msk != 0) { + winning_lane = __ffs(__msk) - 1; } + if (winning_lane != -1) { + if (tidx == winning_lane) { + if (IS_INIT) { + dirs[0] = dir; + } else { + float __prop[9]; + float __dir[3]; + prepare_propagator_d(k1_probe, k2_probe, STEP_SIZE/STEP_FRAC, __prop); + get_probing_frame_d<0>(__frame_sh, st, probing_frame); + propagate_frame_d(__prop, probing_frame, __dir); + norm3_d<0>(__dir); // this will be scaled by the generic stepping code + dirs[0] = MAKE_REAL3(__dir[0], __dir[1], __dir[2]); + } - if (tidx == 0) { - if (IS_INIT) { - dirs[0] = dir; - } else { - // propagate, but only 1/STEP_FRAC of a step - prepare_propagator_d( - *__k1_probe_sh, *__k2_probe_sh, - step_size/STEP_FRAC, __probing_prop_sh); - get_probing_frame_d<0>(__frame_sh, st, __probing_frame_sh); - propagate_frame_d(__probing_prop_sh, __probing_frame_sh, __direc_sh); - norm3_d(__direc_sh, 0); // this will be scaled by the generic stepping code - dirs[0] = MAKE_REAL3(__direc_sh[0], __direc_sh[1], __direc_sh[2]); + for (int jj = 0; jj < 9; jj++) { + 
__frame_sh[jj] = probing_frame[jj]; + } } + __syncwarp(WMASK); + return 1; } - - if (tidx < 9) { - __frame_sh[tidx] = __probing_frame_sh[tidx]; - } - __syncwarp(WMASK); - return 1; } return 0; } template + int BDIM_Y> __device__ bool init_frame_ptt_d( curandStatePhilox4_32_10_t *st, - const REAL_T *__restrict__ pmf, - const REAL_T max_angle, - const REAL_T step_size, - REAL3_T first_step, - const int dimx, const int dimy, const int dimz, const int dimt, - REAL3_T seed, - const REAL3_T *__restrict__ sphere_vertices, - REAL_T* __frame) { + const cudaTextureObject_t *__restrict__ pmf, + float3 first_step, + float3 seed, + const cudaTextureObject_t *__restrict__ sphere_vertices_lut, + float* __frame) { const int tidx = threadIdx.x; const int lid = (threadIdx.y*BDIM_X + tidx) % 32; const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1))); bool init_norm_success; - REAL3_T tmp; + float3 tmp; // Here we probabilistic find a good intial normal for this initial direction init_norm_success = (bool) get_direction_ptt_d( st, pmf, - max_angle, - step_size, MAKE_REAL3(-first_step.x, -first_step.y, -first_step.z), __frame, - dimx, dimy, dimz, dimt, seed, - sphere_vertices, + sphere_vertices_lut, &tmp); __syncwarp(WMASK); @@ -528,13 +428,10 @@ __device__ bool init_frame_ptt_d( init_norm_success = (bool) get_direction_ptt_d( st, pmf, - max_angle, - step_size, MAKE_REAL3(first_step.x, first_step.y, first_step.z), __frame, - dimx, dimy, dimz, dimt, seed, - sphere_vertices, + sphere_vertices_lut, &tmp); __syncwarp(WMASK); diff --git a/cuslines/cuda_c/ptt.cuh b/cuslines/cuda_c/ptt.cuh index 9126250..329f9c7 100644 --- a/cuslines/cuda_c/ptt.cuh +++ b/cuslines/cuda_c/ptt.cuh @@ -5,13 +5,14 @@ #include "globals.h" #define STEP_FRAC (20) // divides output step size (usually 0.5) into this many internal steps -#define PROBE_FRAC (2) // divides output step size (usually 0.5) to find probe length -#define PROBE_QUALITY (4) // Number of probing steps #define SAMPLING_QUALITY 
(2) // can be 2-7 -#define ALLOW_WEAK_LINK (0) +#define ALLOW_WEAK_LINK (1) #define TRIES_PER_REJECTION_SAMPLING (1024) #define K_SMALL ((REAL) 0.0001) +#define PTT_NORM_EPS static_cast(1e-8) +#define PTT_INV_NORM_EPS static_cast(1e8) + #if SAMPLING_QUALITY == 2 #define DISC_VERT_CNT DISC_2_VERT_CNT #define DISC_FACE_CNT DISC_2_FACE_CNT diff --git a/cuslines/cuda_c/ptt_init.cu b/cuslines/cuda_c/ptt_init.cu new file mode 100644 index 0000000..9d45a9e --- /dev/null +++ b/cuslines/cuda_c/ptt_init.cu @@ -0,0 +1,63 @@ +template +__global__ void getNumStreamlinesPtt_k( const int nseed, + const REAL3_T *__restrict__ seeds, + const cudaTextureObject_t *__restrict__ pmf, + const REAL3_T *__restrict__ sphere_vertices, + const int2 *__restrict__ sphere_edges, + REAL3_T *__restrict__ shDir0, + int *slineOutOff) { + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + const int slid = blockIdx.x*blockDim.y + threadIdx.y; + const size_t gid = blockIdx.x * blockDim.y * blockDim.x + blockDim.x * threadIdx.y + threadIdx.x; + + const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32; + const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1))); + + if (slid >= nseed) { + return; + } + + REAL3_T *__restrict__ __shDir = shDir0+slid*DIMT; + curandStatePhilox4_32_10_t st; + curand_init(RNG_SEED, gid, 0, &st); + + extern __shared__ REAL_T __sh[]; + REAL_T *__pmf_data_sh = __sh + tidy*N32DIMT; + + REAL3_T point = seeds[slid]; + + #pragma unroll + for (int i = tidx; i < DIMT; i += BDIM_X) { + const int grid_col = i & WIDTH_MASK; + const int grid_row = i >> LOG2_WIDTH; + + const REAL_T x_query = (REAL_T)(grid_col * DIMX) + point.x; + const REAL_T y_query = (REAL_T)(grid_row * DIMY) + point.y; + __pmf_data_sh[i] = tex3D(*pmf, x_query, y_query, point.z); + if (__pmf_data_sh[i] < PMF_THRESHOLD_P) { + __pmf_data_sh[i] = 0.0; + } + } + __syncwarp(WMASK); + + int *__shInd = reinterpret_cast(__sh + BDIM_Y*N32DIMT) + tidy*N32DIMT; + int ndir = peak_directions_d< + 
BDIM_X, + BDIM_Y>(__pmf_data_sh, + __shDir, + sphere_vertices, + sphere_edges, + __shInd); + + if (tidx == 0) { + slineOutOff[slid] = ndir; + } + + return; +} diff --git a/cuslines/cuda_c/tracking_helpers.cu b/cuslines/cuda_c/tracking_helpers.cu index 21d5f67..b9d9e20 100644 --- a/cuslines/cuda_c/tracking_helpers.cu +++ b/cuslines/cuda_c/tracking_helpers.cu @@ -1,43 +1,18 @@ using namespace cuwsort; -template -__device__ REAL_T interpolation_helper_d(const REAL_T*__restrict__ dataf, const REAL_T wgh[3][2], const long long coo[3][2], int dimy, int dimz, int dimt, int t) { - REAL_T __tmp = 0; - #pragma unroll - for (int i = 0; i < 2; i++) { - #pragma unroll - for (int j = 0; j < 2; j++) { - #pragma unroll - for (int k = 0; k < 2; k++) { - __tmp += wgh[0][i] * wgh[1][j] * wgh[2][k] * - dataf[coo[0][i] * dimy * dimz * dimt + - coo[1][j] * dimz * dimt + - coo[2][k] * dimt + - t]; - } - } - } - return __tmp; -} - template -__device__ int trilinear_interp_d(const int dimx, - const int dimy, - const int dimz, - const int dimt, - int dimt_idx, // If -1, get all - const REAL_T *__restrict__ dataf, +__device__ int trilinear_interp_d(const REAL_T *__restrict__ dataf, const REAL3_T point, REAL_T *__restrict__ __vox_data) { const REAL_T HALF = static_cast(0.5); // all thr compute the same here - if (point.x < -HALF || point.x+HALF >= dimx || - point.y < -HALF || point.y+HALF >= dimy || - point.z < -HALF || point.z+HALF >= dimz) { + if (point.x < -HALF || point.x+HALF >= DIMX || + point.y < -HALF || point.y+HALF >= DIMY || + point.z < -HALF || point.z+HALF >= DIMZ) { return -1; } @@ -53,66 +28,48 @@ __device__ int trilinear_interp_d(const int dimx, wgh[0][1] = point.x - fl.x; wgh[0][0] = ONE-wgh[0][1]; coo[0][0] = MAX(0, fl.x); - coo[0][1] = MIN(dimx-1, coo[0][0]+1); + coo[0][1] = MIN(DIMX-1, coo[0][0]+1); wgh[1][1] = point.y - fl.y; wgh[1][0] = ONE-wgh[1][1]; coo[1][0] = MAX(0, fl.y); - coo[1][1] = MIN(dimy-1, coo[1][0]+1); + coo[1][1] = MIN(DIMY-1, coo[1][0]+1); wgh[2][1] = 
point.z - fl.z; wgh[2][0] = ONE-wgh[2][1]; coo[2][0] = MAX(0, fl.z); - coo[2][1] = MIN(dimz-1, coo[2][0]+1); - - if (dimt_idx == -1) { - for (int t = threadIdx.x; t < dimt; t += BDIM_X) { - __vox_data[t] = interpolation_helper_d(dataf, wgh, coo, dimy, dimz, dimt, t); + coo[2][1] = MIN(DIMZ-1, coo[2][0]+1); + + for (int t = threadIdx.x; t < DIMT; t += BDIM_X) { + __vox_data[t] = 0; + #pragma unroll + for (int i = 0; i < 2; i++) { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int k = 0; k < 2; k++) { + __vox_data[t] += wgh[0][i] * wgh[1][j] * wgh[2][k] * + dataf[coo[0][i] * DIMY * DIMZ * DIMT + + coo[1][j] * DIMZ * DIMT + + coo[2][k] * DIMT + + t]; + } + } } - } else { - *__vox_data = interpolation_helper_d(dataf, wgh, coo, dimy, dimz, dimt, dimt_idx); } - - // if (threadIdx.x == 0) { - // printf("point: %f, %f, %f\n", point.x, point.y, point.z); - // printf("dimt_idx: %d\n", dimt_idx); - // // for(int i = 0; i < dimt; i++) { - // // printf("__vox_data[%d]: %f\n", i, __vox_data[i]); - // // } - // } return 0; } -template -__device__ int check_point_d(const REAL_T tc_threshold, - const REAL3_T point, - const int dimx, - const int dimy, - const int dimz, - const REAL_T *__restrict__ metric_map) { - - const int tidy = threadIdx.y; - - const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32; - const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1))); +template +__device__ int check_point_d(const REAL3_T point, + const cudaTextureObject_t *__restrict__ metric_map) { + float val = tex3D(*metric_map, (float) point.z, (float) point.y, (float) point.x); - __shared__ REAL_T __shInterpOut[BDIM_Y]; + if (val == -1.0f) { + return OUTSIDEIMAGE; + } - const int rv = trilinear_interp_d(dimx, dimy, dimz, 1, 0, metric_map, point, __shInterpOut+tidy); - __syncwarp(WMASK); -#if 0 - if (threadIdx.y == 1 && threadIdx.x == 0) { - printf("__shInterpOut[tidy]: %f, TC_THRESHOLD_P: %f\n", __shInterpOut[tidy], TC_THRESHOLD_P); - } -#endif - if (rv != 
0) { - return OUTSIDEIMAGE; - } - //return (__shInterpOut[tidy] > TC_THRESHOLD_P) ? TRACKPOINT : ENDPOINT; - return (__shInterpOut[tidy] > tc_threshold) ? TRACKPOINT : ENDPOINT; + return (val > TC_THRESHOLD) ? TRACKPOINT : ENDPOINT; } template(samplm_nr, odf, REAL_MAX); + REAL_T odf_min = min_d(SAMPLM_NR, odf, REAL_MAX); odf_min = MAX(0, odf_min); __syncwarp(WMASK); @@ -152,8 +105,8 @@ __device__ int peak_directions_d(const REAL_T *__restrict__ odf, // selecting only the indices corrisponding to maxima Ms // such that M-odf_min >= relative_peak_thres //#pragma unroll - for(int j = 0; j < num_edges; j += BDIM_X) { - if (j+tidx < num_edges) { + for(int j = 0; j < NUM_EDGES; j += BDIM_X) { + if (j+tidx < NUM_EDGES) { const int u_ind = sphere_edges[j+tidx].x; const int v_ind = sphere_edges[j+tidx].y; @@ -179,7 +132,7 @@ __device__ int peak_directions_d(const REAL_T *__restrict__ odf, } __syncwarp(WMASK); - const REAL_T compThres = relative_peak_thres*max_mask_transl_d(samplm_nr, __shInd, odf, -odf_min, REAL_MIN); + const REAL_T compThres = RELATIVE_PEAK_THRESH*max_mask_transl_d(SAMPLM_NR, __shInd, odf, -odf_min, REAL_MIN); #if 1 /* if (!tidy && !tidx) { @@ -193,9 +146,9 @@ __device__ int peak_directions_d(const REAL_T *__restrict__ odf, // compact indices of positive values to the right int n = 0; - for(int j = 0; j < samplm_nr; j += BDIM_X) { + for(int j = 0; j < SAMPLM_NR; j += BDIM_X) { - const int __v = (j+tidx < samplm_nr) ? __shInd[j+tidx] : -1; + const int __v = (j+tidx < SAMPLM_NR) ? 
__shInd[j+tidx] : -1; const int __keep = (__v > 0) && ((odf[j+tidx]-odf_min) >= compThres); const int __msk = __ballot_sync(WMASK, __keep); @@ -242,7 +195,7 @@ __device__ int peak_directions_d(const REAL_T *__restrict__ odf, // remove_similar_vertices() // PRELIMINARY INEFFICIENT, SINGLE TH, IMPLEMENTATION if (tidx == 0) { - const REAL_T cos_similarity = COS(min_separation_angle); + const REAL_T cos_similarity = COS(MIN_SEPARATION_ANGLE); dirs[0] = sphere_vertices[__shInd[0]]; diff --git a/cuslines/cuda_python/__init__.py b/cuslines/cuda_python/__init__.py index fd05c1e..674a6a0 100644 --- a/cuslines/cuda_python/__init__.py +++ b/cuslines/cuda_python/__init__.py @@ -1,9 +1,9 @@ -from .cu_tractography import GPUTracker from .cu_direction_getters import ( + BootDirectionGetter, ProbDirectionGetter, PttDirectionGetter, - BootDirectionGetter, ) +from .cu_tractography import GPUTracker __all__ = [ "GPUTracker", diff --git a/cuslines/cuda_python/_globals.py b/cuslines/cuda_python/_globals.py index c19368e..2358104 100644 --- a/cuslines/cuda_python/_globals.py +++ b/cuslines/cuda_python/_globals.py @@ -1,10 +1,8 @@ -# AUTO-GENERATED FROM globals.h — DO NOT EDIT - EXCESS_ALLOC_FACT = 2 MAX_SLINES_PER_SEED = 10 -MAX_SLINE_LEN = 501 -NORM_EPS = 1e-08 +MAX_SLINE_LEN = 501 # TODO: half this in when WMGMI seeding, and/or this needs to be set dynamically by user PMF_THRESHOLD_P = 0.05 REAL_SIZE = 4 THR_X_BL = 64 THR_X_SL = 32 +NORM_EPS = 1e-12 diff --git a/cuslines/cuda_python/cu_direction_getters.py b/cuslines/cuda_python/cu_direction_getters.py index 36d2c66..b376222 100644 --- a/cuslines/cuda_python/cu_direction_getters.py +++ b/cuslines/cuda_python/cu_direction_getters.py @@ -1,26 +1,32 @@ -import numpy as np -from abc import ABC, abstractmethod import logging +import math +from abc import ABC, abstractmethod from importlib.resources import files from time import time -from cuslines.boot_utils import prepare_opdt, prepare_csa - -from cuda.core import Device, LaunchConfig, 
Program, launch, ProgramOptions -from cuda.pathfinder import find_nvidia_header_directory -from cuda.cccl import get_include_paths -from cuda.bindings import runtime, driver +import numpy as np +from cuda.bindings import driver, runtime from cuda.bindings.runtime import cudaMemcpyKind +from cuda.cccl import get_include_paths +from cuda.core import Device, LaunchConfig, Program, ProgramOptions, launch +from cuda.pathfinder import find_nvidia_header_directory +from scipy.spatial import KDTree +from cuslines.boot_utils import prepare_csa, prepare_opdt from cuslines.cuda_python.cutils import ( - REAL_SIZE, + BLOCK_Y, + REAL3_DTYPE_AS_STR, REAL_DTYPE, REAL_DTYPE_AS_STR, - REAL3_DTYPE_AS_STR, - checkCudaErrors, ModelType, + checkCudaErrors, + EXCESS_ALLOC_FACT, + MAX_SLINES_PER_SEED, + MAX_SLINE_LEN, + PMF_THRESHOLD_P, + REAL_SIZE, + THR_X_BL, THR_X_SL, - BLOCK_Y, ) logger = logging.getLogger("GPUStreamlines") @@ -35,13 +41,16 @@ def getNumStreamlines(self, n, nseeds_gpu, block, grid, sp): def generateStreamlines(self, n, nseeds_gpu, block, grid, sp): pass + def set_macros(self, gpu_tracker): + pass + def allocate_on_gpu(self, n): pass def deallocate_on_gpu(self, n): pass - def compile_program(self, debug: bool = False): + def compile_program(self, gpu_tracker, debug: bool = False): start_time = time() logger.info("Compiling GPUStreamlines") @@ -57,11 +66,43 @@ def compile_program(self, debug: bool = False): else: program_opts = {"ptxas_options": ["-O3"]} + n32dimt = ((gpu_tracker.dimt + 31) // 32) * 32 + self.macros = { + "__NVRTC__": None, + "DIMX": str(int(gpu_tracker.dimx)), + "DIMY": str(int(gpu_tracker.dimy)), + "DIMZ": str(int(gpu_tracker.dimz)), + "DIMT": str(int(gpu_tracker.dimt)), + "N32DIMT": str(int(n32dimt)), + "STEP_SIZE": str(float(gpu_tracker.step_size)), + "MAX_ANGLE": str(float(gpu_tracker.max_angle)), + "TC_THRESHOLD": str(float(gpu_tracker.tc_threshold)), + "RELATIVE_PEAK_THRESH": str(float(gpu_tracker.relative_peak_thresh)), + "MIN_SEPARATION_ANGLE": 
str(float(gpu_tracker.min_separation_angle)), + "RNG_SEED": str(int(gpu_tracker.rng_seed)), + "SAMPLM_NR": str(int(gpu_tracker.samplm_nr)), + "NUM_EDGES": str(int(gpu_tracker.nedges)), + "EXCESS_ALLOC_FACT": str(int(EXCESS_ALLOC_FACT)), + "MAX_SLINES_PER_SEED": str(int(MAX_SLINES_PER_SEED)), + "MAX_SLINE_LEN": str(int(MAX_SLINE_LEN)), + "PMF_THRESHOLD_P": str(float(PMF_THRESHOLD_P)), + "REAL_SIZE": str(int(REAL_SIZE)), + "THR_X_BL": str(int(THR_X_BL)), + "THR_X_SL": str(int(THR_X_SL)), + } + self.set_macros(gpu_tracker) + optional_macros = ["log2_width", "width_mask", "probe_step_size", "max_curvature", "probe_quality", "probe_frac"] + for name in optional_macros: + if name.upper() not in self.macros: + self.macros[name.upper()] = "0" + if debug: + self.macros["DEBUG"] = None + program_options = ProgramOptions( name="cuslines", use_fast_math=True, std="c++17", - define_macro="__NVRTC__", + define_macro=[f"{k}={v}" if v is not None else k for k, v in self.macros.items()], include_path=[ str(cuslines_cuda), find_nvidia_header_directory("cudart"), @@ -132,19 +173,36 @@ def __init__( self.getnum_kernel_name = f"getNumStreamlinesBoot_k<{THR_X_SL},{BLOCK_Y},{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>" self.genstreamlines_kernel_name = f"genStreamlinesMergeBoot_k<{THR_X_SL},{BLOCK_Y},{model_type.upper()},{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>" - self.compile_program() @classmethod - def from_dipy_opdt(cls, gtab, sphere, sh_order_max=6, full_basis=False, - sh_lambda=0.006, min_signal=1): - return cls(**prepare_opdt(gtab, sphere, sh_order_max, full_basis, - sh_lambda, min_signal)) + def from_dipy_opdt( + cls, + gtab, + sphere, + sh_order_max=6, + full_basis=False, + sh_lambda=0.006, + min_signal=1, + ): + return cls( + **prepare_opdt( + gtab, sphere, sh_order_max, full_basis, sh_lambda, min_signal + ) + ) @classmethod - def from_dipy_csa(cls, gtab, sphere, sh_order_max=6, full_basis=False, - sh_lambda=0.006, min_signal=1): - return cls(**prepare_csa(gtab, sphere, 
sh_order_max, full_basis, - sh_lambda, min_signal)) + def from_dipy_csa( + cls, + gtab, + sphere, + sh_order_max=6, + full_basis=False, + sh_lambda=0.006, + min_signal=1, + ): + return cls( + **prepare_csa(gtab, sphere, sh_order_max, full_basis, sh_lambda, min_signal) + ) def allocate_on_gpu(self, n): self.H_d.append(checkCudaErrors(runtime.cudaMalloc(REAL_SIZE * self.H.size))) @@ -247,17 +305,9 @@ def getNumStreamlines(self, n, nseeds_gpu, block, grid, sp): config, ker, self.model_type, - sp.gpu_tracker.max_angle, self.min_signal, - sp.gpu_tracker.relative_peak_thresh, - sp.gpu_tracker.min_separation_angle, - sp.gpu_tracker.rng_seed, nseeds_gpu, sp.seeds_d[n], - sp.gpu_tracker.dimx, - sp.gpu_tracker.dimy, - sp.gpu_tracker.dimz, - sp.gpu_tracker.dimt, sp.gpu_tracker.dataf_d[n], self.H_d[n], self.R_d[n], @@ -265,11 +315,9 @@ def getNumStreamlines(self, n, nseeds_gpu, block, grid, sp): self.delta_b_d[n], self.delta_q_d[n], self.b0s_mask_d[n], - sp.gpu_tracker.samplm_nr, self.sampling_matrix_d[n], sp.gpu_tracker.sphere_vertices_d[n], sp.gpu_tracker.sphere_edges_d[n], - sp.gpu_tracker.nedges, sp.shDirTemp0_d[n], sp.slinesOffs_d[n], ) @@ -283,25 +331,13 @@ def generateStreamlines(self, n, nseeds_gpu, block, grid, sp): sp.gpu_tracker.streams[n], config, ker, - sp.gpu_tracker.max_angle, - sp.gpu_tracker.tc_threshold, - sp.gpu_tracker.step_size, - sp.gpu_tracker.relative_peak_thresh, - sp.gpu_tracker.min_separation_angle, - sp.gpu_tracker.rng_seed, sp.gpu_tracker.rng_offset + n * nseeds_gpu, nseeds_gpu, sp.seeds_d[n], - sp.gpu_tracker.dimx, - sp.gpu_tracker.dimy, - sp.gpu_tracker.dimz, - sp.gpu_tracker.dimt, sp.gpu_tracker.dataf_d[n], - sp.gpu_tracker.metric_map_d[n], - sp.gpu_tracker.samplm_nr, + sp.gpu_tracker.metric_map_d[n].getPtr(), sp.gpu_tracker.sphere_vertices_d[n], sp.gpu_tracker.sphere_edges_d[n], - sp.gpu_tracker.nedges, self.min_signal, self.delta_nr, self.H_d[n], @@ -322,8 +358,7 @@ class ProbDirectionGetter(GPUDirectionGetter): def __init__(self): 
checkCudaErrors(driver.cuInit(0)) self.getnum_kernel_name = f"getNumStreamlinesProb_k<{THR_X_SL},{BLOCK_Y},{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>" - self.genstreamlines_kernel_name = f"genStreamlinesMergeProb_k<{THR_X_SL},{BLOCK_Y},PROB,{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>" - self.compile_program() + self.genstreamlines_kernel_name = f"genStreamlinesMergeProb_k<{THR_X_SL},{BLOCK_Y},PROB,const {REAL_DTYPE_AS_STR} *__restrict__,{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>" def getNumStreamlines(self, n, nseeds_gpu, block, grid, sp): ker = self.module.get_kernel(self.getnum_kernel_name) @@ -333,59 +368,40 @@ def getNumStreamlines(self, n, nseeds_gpu, block, grid, sp): ) config = LaunchConfig(block=block, grid=grid, shmem_size=shared_memory) + if isinstance(sp.gpu_tracker.dataf_d[n], runtime.cudaTextureObject_t): + dataf_d_n = sp.gpu_tracker.dataf_d[n].getPtr() + else: + dataf_d_n = sp.gpu_tracker.dataf_d[n] + launch( sp.gpu_tracker.streams[n], config, ker, - sp.gpu_tracker.max_angle, - sp.gpu_tracker.relative_peak_thresh, - sp.gpu_tracker.min_separation_angle, - sp.gpu_tracker.rng_seed, nseeds_gpu, sp.seeds_d[n], - sp.gpu_tracker.dimx, - sp.gpu_tracker.dimy, - sp.gpu_tracker.dimz, - sp.gpu_tracker.dimt, - sp.gpu_tracker.dataf_d[n], + dataf_d_n, sp.gpu_tracker.sphere_vertices_d[n], sp.gpu_tracker.sphere_edges_d[n], - sp.gpu_tracker.nedges, sp.shDirTemp0_d[n], sp.slinesOffs_d[n], ) - def _shared_mem_bytes(self, sp): - return REAL_SIZE * BLOCK_Y * sp.gpu_tracker.n32dimt - def generateStreamlines(self, n, nseeds_gpu, block, grid, sp): ker = self.module.get_kernel(self.genstreamlines_kernel_name) - shared_memory = self._shared_mem_bytes(sp) + shared_memory = REAL_SIZE * BLOCK_Y * sp.gpu_tracker.n32dimt config = LaunchConfig(block=block, grid=grid, shmem_size=shared_memory) launch( sp.gpu_tracker.streams[n], config, ker, - sp.gpu_tracker.max_angle, - sp.gpu_tracker.tc_threshold, - sp.gpu_tracker.step_size, - sp.gpu_tracker.relative_peak_thresh, - 
sp.gpu_tracker.min_separation_angle, - sp.gpu_tracker.rng_seed, sp.gpu_tracker.rng_offset + n * nseeds_gpu, nseeds_gpu, sp.seeds_d[n], - sp.gpu_tracker.dimx, - sp.gpu_tracker.dimy, - sp.gpu_tracker.dimz, - sp.gpu_tracker.dimt, sp.gpu_tracker.dataf_d[n], - sp.gpu_tracker.metric_map_d[n], - sp.gpu_tracker.samplm_nr, + sp.gpu_tracker.metric_map_d[n].getPtr(), sp.gpu_tracker.sphere_vertices_d[n], sp.gpu_tracker.sphere_edges_d[n], - sp.gpu_tracker.nedges, sp.slinesOffs_d[n], sp.shDirTemp0_d[n], sp.slineSeed_d[n], @@ -395,11 +411,171 @@ def generateStreamlines(self, n, nseeds_gpu, block, grid, sp): class PttDirectionGetter(ProbDirectionGetter): - def __init__(self): + def __init__(self, odf_lut_res: int = 128, probe_frac: int = 2, probe_quality: int = 4): + """ + Use Parallel Transport Tractography + + Parameters + ---------- + odf_lut_res: int + Resolution of the ODF lookup table. + Default: 128 + probe_frac: int + Divides output step size (usually 0.5) to find probe length. + Default: 2 + probe_quality : int + Number of probing steps. 
+ Default: 4 + """ checkCudaErrors(driver.cuInit(0)) - self.getnum_kernel_name = f"getNumStreamlinesProb_k<{THR_X_SL},{BLOCK_Y},{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>" - self.genstreamlines_kernel_name = f"genStreamlinesMergeProb_k<{THR_X_SL},{BLOCK_Y},PTT,{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>" - self.compile_program() + self.getnum_kernel_name = f"getNumStreamlinesPtt_k<{THR_X_SL},{BLOCK_Y},{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>" + self.genstreamlines_kernel_name = f"genStreamlinesMergeProb_k<{THR_X_SL},{BLOCK_Y},PTT,const cudaTextureObject_t *__restrict__,{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>" + self.odf_lut_res = odf_lut_res + self.sphere_vertices_lut_h = None + self.sphere_vertices_lut_d = [] + self.sphere_vertices_lut_array_d = [] + + self.probe_frac = probe_frac + self.probe_quality = probe_quality + + def set_macros(self, gpu_tracker): + self.macros["LOG2_WIDTH"] = str(int(self.log2_width)) + self.macros["WIDTH_MASK"] = str(int(self.width_mask)) + self.macros["PROBE_FRAC"] = str(float(self.probe_frac)) + self.macros["PROBE_QUALITY"] = str(float(self.probe_quality)) + self.macros["PROBE_STEP_SIZE"] = str(float(((gpu_tracker.step_size / self.probe_frac) / (self.probe_quality)))) + self.macros["MAX_CURVATURE"] = str(float(self.probe_frac * 2.0 * np.sin(gpu_tracker.max_angle / 2.0) / (gpu_tracker.step_size))) - def _shared_mem_bytes(self, sp): - return 0 + def allocate_on_gpu(self, n): + if REAL_SIZE != 4: + raise ValueError( + ("PTT on CUDA uses texture memory which only supports 32-bit floats") + ) + + channelDesc = checkCudaErrors( + runtime.cudaCreateChannelDesc( + 32, 0, 0, 0, runtime.cudaChannelFormatKind.cudaChannelFormatKindFloat + ) + ) + extent = runtime.make_cudaExtent( + self.odf_lut_res, self.odf_lut_res, self.odf_lut_res + ) + sphere_vertices_array = checkCudaErrors( + runtime.cudaMalloc3DArray(channelDesc, extent, 0) + ) + + copyParams = runtime.cudaMemcpy3DParms() + copyParams.srcPtr = runtime.make_cudaPitchedPtr( + 
self.sphere_vertices_lut_h.ctypes.data, + self.odf_lut_res * 4, + self.odf_lut_res, + self.odf_lut_res, + ) + + copyParams.dstArray = sphere_vertices_array + copyParams.extent = extent + copyParams.kind = cudaMemcpyKind.cudaMemcpyHostToDevice + checkCudaErrors(runtime.cudaMemcpy3D(copyParams)) + + resDesc = runtime.cudaResourceDesc() + resDesc.resType = runtime.cudaResourceType.cudaResourceTypeArray + resDesc.res.array.array = sphere_vertices_array + + texDesc = runtime.cudaTextureDesc() + texDesc.addressMode[0] = runtime.cudaTextureAddressMode.cudaAddressModeClamp + texDesc.addressMode[1] = runtime.cudaTextureAddressMode.cudaAddressModeClamp + texDesc.addressMode[2] = runtime.cudaTextureAddressMode.cudaAddressModeClamp + texDesc.filterMode = runtime.cudaTextureFilterMode.cudaFilterModePoint + texDesc.readMode = runtime.cudaTextureReadMode.cudaReadModeElementType + texDesc.normalizedCoords = 1 + + texObj = checkCudaErrors( + runtime.cudaCreateTextureObject(resDesc, texDesc, None) + ) + self.sphere_vertices_lut_d.append(texObj) + self.sphere_vertices_lut_array_d.append(sphere_vertices_array) + + def deallocate_on_gpu(self, n): + if self.sphere_vertices_lut_d[n]: + checkCudaErrors( + runtime.cudaDestroyTextureObject(self.sphere_vertices_lut_d[n]) + ) + if self.sphere_vertices_lut_array_d[n]: + checkCudaErrors(runtime.cudaFreeArray(self.sphere_vertices_lut_array_d[n])) + + def prepare_data(self, dataf, stop_map, stop_threshold, sphere_vertices): + dimx, dimy, dimz, dimt = dataf.shape + dataf = dataf.clip(min=0) + + # zeros outside of tracking mask helps with probing + dataf[stop_map < stop_threshold, :] = 0 + + # normalize ODFs to max of 1 + odf_sums = dataf.max(axis=3, keepdims=True) + nonzero_mask = odf_sums > 0 + np.divide(dataf, odf_sums, out=dataf, where=nonzero_mask) + + # This rearrangement is for cuda texture memory + # In particular, for texture memory, we want each dimension + # to be less than 65,535, so we tile t across x and y + # additionally, we then 
make the tiles in the x dim + # a power of 2 to ensure it is fast to calculate indices + # into the tiles + ideal_tiles_per_row = math.ceil(math.sqrt(dimt)) + self.log2_width = math.ceil(math.log2(ideal_tiles_per_row)) + tiles_per_row = 1 << self.log2_width + self.width_mask = tiles_per_row - 1 + tiles_per_col = math.ceil(dimt / tiles_per_row) + total_slots = tiles_per_row * tiles_per_col + if dimt < total_slots: + padding = np.zeros((dimx, dimy, dimz, total_slots - dimt), dtype=np.float32) + data_f_rearranged = np.concatenate([dataf, padding], axis=3) + else: + data_f_rearranged = dataf + + data_f_rearranged = data_f_rearranged.reshape(dimx, dimy, dimz, tiles_per_col, tiles_per_row) + data_f_rearranged = data_f_rearranged.transpose(2, 3, 1, 4, 0).reshape( + dimz, + tiles_per_col * dimy, + tiles_per_row * dimx + ) + data_f_rearranged = np.ascontiguousarray(data_f_rearranged, dtype=np.float32) + + # Generate a 3D LUT that maps each point in a 128x128x128 grid to + # the index of the closest sphere vertex + coords = np.linspace(-1, 1, self.odf_lut_res) + grid_x, grid_y, grid_z = np.meshgrid(coords, coords, coords, indexing="ij") + grid_points = np.stack([grid_x.ravel(), grid_y.ravel(), grid_z.ravel()], axis=1) + + tree = KDTree(sphere_vertices) + _, closest_indices = tree.query(grid_points) + lut = closest_indices.reshape( + (self.odf_lut_res, self.odf_lut_res, self.odf_lut_res) + ) + lut = np.ascontiguousarray(lut, dtype=np.float32) + self.sphere_vertices_lut_h = lut + + return data_f_rearranged + + def generateStreamlines(self, n, nseeds_gpu, block, grid, sp): + ker = self.module.get_kernel(self.genstreamlines_kernel_name) + shared_memory = 0 + config = LaunchConfig(block=block, grid=grid, shmem_size=shared_memory) + + launch( + sp.gpu_tracker.streams[n], + config, + ker, + sp.gpu_tracker.rng_offset + n * nseeds_gpu, + nseeds_gpu, + sp.seeds_d[n], + sp.gpu_tracker.dataf_d[n].getPtr(), + sp.gpu_tracker.metric_map_d[n].getPtr(), + 
self.sphere_vertices_lut_d[n].getPtr(), + sp.gpu_tracker.sphere_edges_d[n], + sp.slinesOffs_d[n], + sp.shDirTemp0_d[n], + sp.slineSeed_d[n], + sp.slineLen_d[n], + sp.sline_d[n], + ) diff --git a/cuslines/cuda_python/cu_propagate_seeds.py b/cuslines/cuda_python/cu_propagate_seeds.py index 79735f9..a7fc7bb 100644 --- a/cuslines/cuda_python/cu_propagate_seeds.py +++ b/cuslines/cuda_python/cu_propagate_seeds.py @@ -1,26 +1,25 @@ -import numpy as np -import math import gc +import logging +import math + +import numpy as np from cuda.bindings import runtime from cuda.bindings.runtime import cudaMemcpyKind - -from nibabel.streamlines.array_sequence import ArraySequence, MEGABYTE -import logging +from nibabel.streamlines.array_sequence import MEGABYTE, ArraySequence from cuslines.cuda_python.cutils import ( - REAL_SIZE, - REAL_DTYPE, - REAL3_DTYPE, - MAX_SLINE_LEN, + DEV_PTR, EXCESS_ALLOC_FACT, - THR_X_SL, + MAX_SLINE_LEN, + REAL3_DTYPE, + REAL_DTYPE, + REAL_SIZE, THR_X_BL, - DEV_PTR, - div_up, + THR_X_SL, checkCudaErrors, + div_up, ) - logger = logging.getLogger("GPUStreamlines") @@ -204,10 +203,6 @@ def _cleanup(self): self.nSlines_old = self.nSlines self.gpu_tracker.rng_offset += self.nseeds - # TODO: performance: better queuing/batching of seeds, - # if more performance needed, - # given exponential nature of streamlines - # May be better to do in cuda code directly def propagate(self, seeds): self.nseeds = len(seeds) self.nseeds_per_gpu = ( @@ -264,6 +259,4 @@ def _yield_slines(): return _yield_slines() def as_array_sequence(self): - return ArraySequence( - self.as_generator(), - self.get_buffer_size()) + return ArraySequence(self.as_generator(), self.get_buffer_size()) diff --git a/cuslines/cuda_python/cu_tractography.py b/cuslines/cuda_python/cu_tractography.py index 02e9d31..7a46590 100644 --- a/cuslines/cuda_python/cu_tractography.py +++ b/cuslines/cuda_python/cu_tractography.py @@ -1,29 +1,27 @@ -from cuda.bindings import runtime -from cuda.bindings.runtime import 
cudaMemcpyKind -# TODO: consider cuda core over cuda bindings +import logging +from math import radians import numpy as np +from cuda.bindings import runtime +from cuda.bindings.runtime import cudaMemcpyKind +from dipy.io.stateful_tractogram import Space, StatefulTractogram +from nibabel.streamlines.array_sequence import ArraySequence +from nibabel.streamlines.tractogram import Tractogram from tqdm import tqdm -import logging -from math import radians +from trx.trx_file_memmap import TrxFile -from cuslines.cuda_python.cutils import ( - REAL_SIZE, - REAL_DTYPE, - checkCudaErrors, -) from cuslines.cuda_python.cu_direction_getters import ( - GPUDirectionGetter, BootDirectionGetter, + GPUDirectionGetter, + PttDirectionGetter, ) from cuslines.cuda_python.cu_propagate_seeds import SeedBatchPropagator - -from trx.trx_file_memmap import TrxFile - -from nibabel.streamlines.tractogram import Tractogram -from nibabel.streamlines.array_sequence import ArraySequence, MEGABYTE - -from dipy.io.stateful_tractogram import Space, StatefulTractogram +from cuslines.cuda_python.cutils import ( + REAL_DTYPE, + REAL_SIZE, + checkCudaErrors, + allocate_texture, +) logger = logging.getLogger("GPUStreamlines") @@ -103,12 +101,21 @@ def __init__( Number of seeds to process in each chunk per GPU default: 25000 """ - self.dataf = np.ascontiguousarray(dataf, dtype=REAL_DTYPE) - self.metric_map = np.ascontiguousarray(stop_map, dtype=REAL_DTYPE) + self.dimx, self.dimy, self.dimz, self.dimt = dataf.shape + if hasattr(dg, "prepare_data"): + self.dataf = dg.prepare_data( + dataf, + stop_map, + stop_threshold, + sphere_vertices, + ) + else: + self.dataf = np.ascontiguousarray(dataf, dtype=REAL_DTYPE) + + self.metric_map = np.ascontiguousarray(stop_map, dtype=np.float32) self.sphere_vertices = np.ascontiguousarray(sphere_vertices, dtype=REAL_DTYPE) self.sphere_edges = np.ascontiguousarray(sphere_edges, dtype=np.int32) - self.dimx, self.dimy, self.dimz, self.dimt = dataf.shape self.nedges = 
int(sphere_edges.shape[0]) if isinstance(dg, BootDirectionGetter): self.samplm_nr = int(dg.sampling_matrix.shape[0]) @@ -137,7 +144,9 @@ def __init__( logger.info("Creating GPUTracker with %d GPUs...", self.ngpus) self.dataf_d = [] + self.dataf_array = [] self.metric_map_d = [] + self.metric_map_array = [] self.sphere_vertices_d = [] self.sphere_edges_d = [] @@ -145,11 +154,10 @@ def __init__( self.managed_data = [] self.seed_propagator = SeedBatchPropagator( - gpu_tracker=self, - minlen=min_pts, - maxlen=max_pts + gpu_tracker=self, minlen=min_pts, maxlen=max_pts ) self._allocated = False + self.dg.compile_program(self) def __enter__(self): self._allocate() @@ -170,38 +178,30 @@ def _allocate(self): for ii in range(self.ngpus): checkCudaErrors(runtime.cudaSetDevice(ii)) - # TODO: performance: dataf could be managed or texture memory instead? - self.dataf_d.append( - checkCudaErrors(runtime.cudaMalloc(REAL_SIZE * self.dataf.size)) - ) - self.metric_map_d.append( - checkCudaErrors(runtime.cudaMalloc(REAL_SIZE * self.metric_map.size)) - ) - self.sphere_vertices_d.append( - checkCudaErrors( - runtime.cudaMalloc(REAL_SIZE * self.sphere_vertices.size) + if isinstance(self.dg, PttDirectionGetter): + if REAL_SIZE != 4: + raise ValueError( + ("PTT on CUDA only supports 32-bit floats") + ) + dataf_d, dataf_array = allocate_texture(self.dataf) + self.dataf_d.append(dataf_d) + self.dataf_array.append(dataf_array) + else: + self.dataf_d.append( + checkCudaErrors(runtime.cudaMalloc(REAL_SIZE * self.dataf.size)) ) - ) - self.sphere_edges_d.append( checkCudaErrors( - runtime.cudaMalloc(np.int32().nbytes * self.sphere_edges.size) + runtime.cudaMemcpy( + self.dataf_d[ii], + self.dataf.ctypes.data, + REAL_SIZE * self.dataf.size, + cudaMemcpyKind.cudaMemcpyHostToDevice, + ) ) - ) - checkCudaErrors( - runtime.cudaMemcpy( - self.dataf_d[ii], - self.dataf.ctypes.data, - REAL_SIZE * self.dataf.size, - cudaMemcpyKind.cudaMemcpyHostToDevice, - ) - ) - checkCudaErrors( - runtime.cudaMemcpy( 
- self.metric_map_d[ii], - self.metric_map.ctypes.data, - REAL_SIZE * self.metric_map.size, - cudaMemcpyKind.cudaMemcpyHostToDevice, + self.sphere_vertices_d.append( + checkCudaErrors( + runtime.cudaMalloc(REAL_SIZE * self.sphere_vertices.size) ) ) checkCudaErrors( @@ -212,6 +212,16 @@ def _allocate(self): cudaMemcpyKind.cudaMemcpyHostToDevice, ) ) + + metric_map_d, metric_map_array = allocate_texture(self.metric_map, address_mode="border") + self.metric_map_d.append(metric_map_d) + self.metric_map_array.append(metric_map_array) + + self.sphere_edges_d.append( + checkCudaErrors( + runtime.cudaMalloc(np.int32().nbytes * self.sphere_edges.size) + ) + ) checkCudaErrors( runtime.cudaMemcpy( self.sphere_edges_d[ii], @@ -229,10 +239,18 @@ def __exit__(self, exc_type, exc, tb): for n in range(self.ngpus): checkCudaErrors(runtime.cudaSetDevice(n)) - if self.dataf_d[n]: - checkCudaErrors(runtime.cudaFree(self.dataf_d[n])) + if isinstance(self.dg, PttDirectionGetter): + if self.dataf_d[n]: + checkCudaErrors(runtime.cudaDestroyTextureObject(self.dataf_d[n])) + if self.dataf_array[n]: + checkCudaErrors(runtime.cudaFreeArray(self.dataf_array[n])) + else: + if self.dataf_d[n]: + checkCudaErrors(runtime.cudaFree(self.dataf_d[n])) if self.metric_map_d[n]: - checkCudaErrors(runtime.cudaFree(self.metric_map_d[n])) + checkCudaErrors(runtime.cudaDestroyTextureObject(self.metric_map_d[n])) + if self.metric_map_array[n]: + checkCudaErrors(runtime.cudaFreeArray(self.metric_map_array[n])) if self.sphere_vertices_d[n]: checkCudaErrors(runtime.cudaFree(self.sphere_vertices_d[n])) if self.sphere_edges_d[n]: @@ -276,16 +294,18 @@ def generate_trx(self, seeds, ref_img): n_sls_guess = sl_per_seed_guess * seeds.shape[0] # trx files use memory mapping - trx_reference = TrxFile( - reference=ref_img + trx_reference = TrxFile(reference=ref_img) + trx_reference.streamlines._data = trx_reference.streamlines._data.astype( + np.float32 + ) + trx_reference.streamlines._offsets = 
trx_reference.streamlines._offsets.astype( + np.uint64 ) - trx_reference.streamlines._data = trx_reference.streamlines._data.astype(np.float32) - trx_reference.streamlines._offsets = trx_reference.streamlines._offsets.astype(np.uint64) trx_file = TrxFile( nb_streamlines=n_sls_guess, nb_vertices=n_sls_guess * sl_len_guess, - init_as=trx_reference + init_as=trx_reference, ) offsets_idx = 0 sls_data_idx = 0 diff --git a/cuslines/cuda_python/cutils.py b/cuslines/cuda_python/cutils.py index db4115a..4ad733c 100644 --- a/cuslines/cuda_python/cutils.py +++ b/cuslines/cuda_python/cutils.py @@ -1,8 +1,9 @@ -from cuda.bindings import driver, nvrtc +from enum import IntEnum import numpy as np - -from enum import IntEnum +from cuda.bindings import driver, nvrtc +from cuda.bindings import runtime +from cuda.bindings.runtime import cudaMemcpyKind from cuslines.cuda_python._globals import * @@ -62,3 +63,54 @@ def checkCudaErrors(result): def div_up(a, b): return (a + b - 1) // b + + +def allocate_texture(data, address_mode="clamp"): + channelDesc = checkCudaErrors( + runtime.cudaCreateChannelDesc( + 32, 0, 0, 0, runtime.cudaChannelFormatKind.cudaChannelFormatKindFloat + ) + ) + + dim0, dim1, dim2 = data.shape + extent = runtime.make_cudaExtent(dim2, dim1, dim0) + dataf_array = checkCudaErrors(runtime.cudaMalloc3DArray(channelDesc, extent, 0)) + + copyParams = runtime.cudaMemcpy3DParms() + copyParams.srcPtr = runtime.make_cudaPitchedPtr( + data.ctypes.data, + dim2 * 4, + dim2, + dim1, + ) + + copyParams.dstArray = dataf_array + copyParams.extent = extent + copyParams.kind = cudaMemcpyKind.cudaMemcpyHostToDevice + checkCudaErrors(runtime.cudaMemcpy3D(copyParams)) + + resDesc = runtime.cudaResourceDesc() + resDesc.resType = runtime.cudaResourceType.cudaResourceTypeArray + resDesc.res.array.array = dataf_array + + texDesc = runtime.cudaTextureDesc() + if address_mode == "clamp": + address_mode = runtime.cudaTextureAddressMode.cudaAddressModeClamp + elif address_mode == "border": + 
address_mode = runtime.cudaTextureAddressMode.cudaAddressModeBorder + texDesc.borderColor[0] = -1.0; + texDesc.borderColor[1] = -1.0; + texDesc.borderColor[2] = -1.0; + else: + raise ValueError(f"Unsupported address_mode: {address_mode}") + texDesc.addressMode[0] = address_mode + texDesc.addressMode[1] = address_mode + texDesc.addressMode[2] = address_mode + texDesc.filterMode = runtime.cudaTextureFilterMode.cudaFilterModeLinear + texDesc.readMode = runtime.cudaTextureReadMode.cudaReadModeElementType + texDesc.normalizedCoords = 0 + + texObj = checkCudaErrors( + runtime.cudaCreateTextureObject(resDesc, texDesc, None) + ) + return texObj, dataf_array diff --git a/pyproject.toml b/pyproject.toml index e3b3461..1406cf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,33 @@ build-backend = "setuptools.build_meta" [tool.setuptools_scm] +[tool.ruff] +target-version = "py37" +include = [ + "cuslines/cuda_python/*.py"] +exclude = ["docs/source/conf.py", + "examples/"] + +[tool.ruff.lint] +select = [ + "F", + "E", + "C", + "W", + "B", + "I", +] +ignore = [ + "B905", + "C901", + "E203", + "F821", + "B021", + "C408", + "I001", + "B027", +] + [project] name = "cuslines" dynamic = ["version"] @@ -15,9 +42,16 @@ dependencies = [ "nibabel", "tqdm", "dipy", - "trx-python" + "trx-python", + "scipy", ] +[project.urls] +Homepage = "https://github.com/dipy/GPUStreamlines" + +[tool.setuptools.package-data] +cuslines = ["cuda_c/*", "metal_shaders/*", "wgsl_shaders/*"] + [project.optional-dependencies] cu13 = [ "nvidia-cuda-runtime", @@ -44,6 +78,10 @@ webgpu = [ "wgpu>=0.18", ] +dev = [ + "ruff>=0.14.10" +] + [tool.setuptools.packages.find] where = ["."] include = ["cuslines*"] diff --git a/run_gpu_streamlines.py b/run_gpu_streamlines.py index f4fb8cf..90600a3 100644 --- a/run_gpu_streamlines.py +++ b/run_gpu_streamlines.py @@ -28,32 +28,33 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import argparse +import os.path as op import random import time -import numpy as np - import dipy.reconst.dti as dti +import nibabel as nib +import numpy as np +from dipy.core.gradients import gradient_table, unique_bvals_magnitude +from dipy.data import default_sphere, get_fnames, read_stanford_pve_maps, small_sphere +from dipy.direction import ( + BootDirectionGetter as cpu_BootDirectionGetter, +) +from dipy.direction import ( + ProbabilisticDirectionGetter as cpu_ProbDirectionGetter, +) +from dipy.direction import ( + PTTDirectionGetter as cpu_PTTDirectionGetter, +) from dipy.io import read_bvals_bvecs from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.io.streamline import save_tractogram -from dipy.tracking import utils -from dipy.core.gradients import gradient_table, unique_bvals_magnitude -from dipy.data import default_sphere, small_sphere -from dipy.direction import ( - BootDirectionGetter as cpu_BootDirectionGetter, - ProbabilisticDirectionGetter as cpu_ProbDirectionGetter, - PTTDirectionGetter as cpu_PTTDirectionGetter) -from dipy.reconst.shm import OpdtModel, CsaOdfModel from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel, auto_response_ssst +from dipy.reconst.shm import CsaOdfModel, OpdtModel +from dipy.tracking import utils from dipy.tracking.local_tracking import LocalTracking from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion -from dipy.data import get_fnames -from dipy.data import read_stanford_pve_maps - -import nibabel as nib from nibabel.orientations import aff2axcodes - from trx.io import save as save_trx from cuslines import ( @@ -70,197 +71,336 @@ np.random.seed(0) random.seed(0) -#Get Gradient values + +# Get Gradient values def get_gtab(fbval, fbvec): bvals, bvecs = read_bvals_bvecs(fbval, fbvec) gtab = gradient_table(bvals=bvals, bvecs=bvecs) return gtab + def get_img(ep2_seq): img = nib.load(ep2_seq) return img + print("parsing arguments") parser = argparse.ArgumentParser() 
-parser.add_argument("nifti_file", nargs='?', default='hardi', help="path to the DWI nifti file") -parser.add_argument("bvals", nargs='?', default='hardi', help="path to the bvals") -parser.add_argument("bvecs", nargs='?', default='hardi', help="path to the bvecs") -parser.add_argument("mask_nifti", nargs='?', default='hardi', help="path to the mask file") -parser.add_argument("roi_nifti", nargs='?', default='hardi', help="path to the ROI file") -parser.add_argument("--device", type=str, default ='gpu', choices=['cpu', 'gpu', 'metal', 'webgpu'], help="Whether to use cpu, gpu (auto-detect), metal, or webgpu") -parser.add_argument("--sphere", type=str, default='default', choices=['default', 'small'], help="Which sphere to use for direction getting") -parser.add_argument("--output-prefix", type=str, default ='', help="path to the output file") -parser.add_argument("--chunk-size", type=int, default=25000, help="how many seeds to process per sweep, per GPU") -parser.add_argument("--nseeds", type=int, default=100000, help="how many seeds to process in total") -parser.add_argument("--ngpus", type=int, default=1, help="number of GPUs to use if using gpu") +parser.add_argument( + "nifti_file", nargs="?", default="hardi", help="path to the DWI nifti file" +) +parser.add_argument("bvals", nargs="?", default="hardi", help="path to the bvals") +parser.add_argument("bvecs", nargs="?", default="hardi", help="path to the bvecs") +parser.add_argument( + "mask_nifti", nargs="?", default="hardi", help="path to the mask file" +) +parser.add_argument( + "roi_nifti", nargs="?", default="hardi", help="path to the ROI file" +) +parser.add_argument( + "--device", + type=str, + default="gpu", + choices=["cpu", "gpu", "metal", "webgpu"], + help="Whether to use cpu, gpu (auto-detect), metal, or webgpu", +) +parser.add_argument( + "--sphere", + type=str, + default="default", + choices=["default", "small"], + help="Which sphere to use for direction getting", +) +parser.add_argument( + 
"--output-prefix", type=str, default="", help="path to the output file" +) +parser.add_argument( + "--chunk-size", + type=int, + default=25000, + help="how many seeds to process per sweep, per GPU", +) +parser.add_argument( + "--nseeds", type=int, default=100000, help="how many seeds to process in total" +) +parser.add_argument( + "--ngpus", type=int, default=1, help="number of GPUs to use if using gpu" +) parser.add_argument("--write-method", type=str, default="trk", help="Can be trx or trk") -parser.add_argument("--max-angle", type=float, default=60, help="max angle (in degrees)") +parser.add_argument( + "--max-angle", type=float, default=60, help="max angle (in degrees)" +) parser.add_argument("--min-signal", type=float, default=1.0, help="default: 1.0") parser.add_argument("--step-size", type=float, default=0.5, help="default: 0.5") -parser.add_argument("--sh-order",type=int,default=4,help="sh_order") -parser.add_argument("--fa-threshold",type=float,default=0.1,help="FA threshold") -parser.add_argument("--relative-peak-threshold",type=float,default=0.25,help="relative peak threshold") -parser.add_argument("--min-separation-angle",type=float,default=45,help="min separation angle (in degrees)") -parser.add_argument("--sm-lambda",type=float,default=0.006,help="smoothing lambda") -parser.add_argument("--model", type=str, default="default", choices=['default', 'opdt', 'csa', 'csd'], help="model to use") -parser.add_argument("--dg", type=str, default="boot", choices=['boot', 'prob', 'ptt'], help="direction getting scheme to use") +parser.add_argument("--sh-order", type=int, default=4, help="sh_order") +parser.add_argument("--fa-threshold", type=float, default=0.1, help="FA threshold") +parser.add_argument( + "--relative-peak-threshold", + type=float, + default=0.25, + help="relative peak threshold", +) +parser.add_argument( + "--min-separation-angle", + type=float, + default=45, + help="min separation angle (in degrees)", +) +parser.add_argument("--sm-lambda", 
type=float, default=0.006, help="smoothing lambda") +parser.add_argument( + "--model", + type=str, + default="default", + choices=["default", "opdt", "csa", "csd"], + help="model to use", +) +parser.add_argument( + "--dg", + type=str, + default="boot", + choices=["boot", "prob", "ptt"], + help="direction getting scheme to use", +) +parser.add_argument( + "--cache-dir", type=str, default="", help="cache directory for FA and ODFs" +) +parser.add_argument("--seed-seed", type=int, default=None, help="seed for seeding") args = parser.parse_args() if args.model == "default": - if args.dg == "boot": - args.model = "opdt" - else: - args.model = "csd" + if args.dg == "boot": + args.model = "opdt" + else: + args.model = "csd" if args.device == "metal": - if BACKEND != "metal": - raise RuntimeError("Metal backend requested but not available. " - "Install: pip install 'cuslines[metal]'") - if args.ngpus > 1: - print("WARNING: Metal backend supports only 1 GPU, ignoring --ngpus %d" % args.ngpus) - args.ngpus = 1 - args.device = "gpu" # use the GPU code path + if BACKEND != "metal": + raise RuntimeError( + "Metal backend requested but not available. " + "Install: pip install 'cuslines[metal]'" + ) + if args.ngpus > 1: + print( + "WARNING: Metal backend supports only 1 GPU, ignoring --ngpus %d" + % args.ngpus + ) + args.ngpus = 1 + args.device = "gpu" # use the GPU code path elif args.device == "webgpu": - try: - from cuslines.webgpu import ( - WebGPUTracker as GPUTracker, - WebGPUProbDirectionGetter as ProbDirectionGetter, - WebGPUPttDirectionGetter as PttDirectionGetter, - WebGPUBootDirectionGetter as BootDirectionGetter, - ) - except ImportError: - raise RuntimeError("WebGPU backend requested but not available. 
" - "Install: pip install 'cuslines[webgpu]'") - if args.ngpus > 1: - print("WARNING: WebGPU backend supports only 1 GPU, ignoring --ngpus %d" % args.ngpus) - args.ngpus = 1 - print("Using webgpu backend") - args.device = "gpu" # use the GPU code path + try: + from cuslines.webgpu import ( + WebGPUBootDirectionGetter as BootDirectionGetter, + ) + from cuslines.webgpu import ( + WebGPUProbDirectionGetter as ProbDirectionGetter, + ) + from cuslines.webgpu import ( + WebGPUPttDirectionGetter as PttDirectionGetter, + ) + from cuslines.webgpu import ( + WebGPUTracker as GPUTracker, + ) + except ImportError: + raise RuntimeError( + "WebGPU backend requested but not available. " + "Install: pip install 'cuslines[webgpu]'" + ) + if args.ngpus > 1: + print( + "WARNING: WebGPU backend supports only 1 GPU, ignoring --ngpus %d" + % args.ngpus + ) + args.ngpus = 1 + print("Using webgpu backend") + args.device = "gpu" # use the GPU code path elif args.device == "gpu": - print("Using %s backend" % BACKEND) + print("Using %s backend" % BACKEND) if args.device == "cpu" and args.write_method != "trk": - print("WARNING: only trk write method is implemented for cpu testing.") - write_method = "trk" + print("WARNING: only trk write method is implemented for cpu testing.") + write_method = "trk" else: - write_method = args.write_method + write_method = args.write_method -if 'hardi' in [args.nifti_file, args.bvals, args.bvecs, args.mask_nifti, args.roi_nifti]: - if not all(arg == 'hardi' for arg in [args.nifti_file, args.bvals, args.bvecs, args.mask_nifti, args.roi_nifti]): - raise ValueError("If any of the arguments is 'hardi', all must be 'hardi'") - # Get Stanford HARDI data - hardi_nifti_fname, hardi_bval_fname, hardi_bvec_fname = get_fnames( - name='stanford_hardi') - csf, gm, wm = read_stanford_pve_maps() - wm_data = wm.get_fdata() +if "hardi" in [ + args.nifti_file, + args.bvals, + args.bvecs, + args.mask_nifti, + args.roi_nifti, +]: + if not all( + arg == "hardi" + for arg in [ + 
args.nifti_file, + args.bvals, + args.bvecs, + args.mask_nifti, + args.roi_nifti, + ] + ): + raise ValueError("If any of the arguments is 'hardi', all must be 'hardi'") + # Get Stanford HARDI data + hardi_nifti_fname, hardi_bval_fname, hardi_bvec_fname = get_fnames( + name="stanford_hardi" + ) + csf, gm, wm = read_stanford_pve_maps() + wm_data = wm.get_fdata() - img = get_img(hardi_nifti_fname) - voxel_order = "".join(aff2axcodes(img.affine)) + img = get_img(hardi_nifti_fname) + voxel_order = "".join(aff2axcodes(img.affine)) - gtab = get_gtab(hardi_bval_fname, hardi_bvec_fname) + gtab = get_gtab(hardi_bval_fname, hardi_bvec_fname) - data = img.get_fdata() - roi_data = (wm_data > 0.5) - mask = roi_data + data = img.get_fdata() + roi_data = wm_data > 0.5 + mask = roi_data else: - img = get_img(args.nifti_file) - voxel_order = "".join(aff2axcodes(img.affine)) - gtab = get_gtab(args.bvals, args.bvecs) - roi = get_img(args.roi_nifti) - mask = get_img(args.mask_nifti) - data = img.get_fdata() - roi_data = roi.get_fdata() - mask = mask.get_fdata() + img = get_img(args.nifti_file) + voxel_order = "".join(aff2axcodes(img.affine)) + gtab = get_gtab(args.bvals, args.bvecs) + roi = get_img(args.roi_nifti) + mask = get_img(args.mask_nifti) + data = img.get_fdata() + roi_data = roi.get_fdata() + mask = mask.get_fdata() -tenmodel = dti.TensorModel(gtab, fit_method='WLS') -print('Fitting Tensor') -tenfit = tenmodel.fit(data, mask=mask) -print('Computing anisotropy measures (FA,MD,RGB)') -FA = tenfit.fa + +fa_cache_file = op.join(args.cache_dir, "fa.npy") +if args.cache_dir != "" and op.exists(fa_cache_file): + print("Loading FA from cache") + FA = np.load(fa_cache_file) +else: + tenmodel = dti.TensorModel(gtab, fit_method="WLS") + print("Fitting Tensor") + tenfit = tenmodel.fit(data, mask=mask) + print("Computing anisotropy measures (FA,MD,RGB)") + FA = tenfit.fa + if args.cache_dir != "": + np.save(fa_cache_file, FA) # Setup tissue_classifier args tissue_classifier = 
ThresholdStoppingCriterion(FA, args.fa_threshold) # Create seeds for ROI -seed_mask = np.asarray(utils.random_seeds_from_mask( - roi_data, seeds_count=args.nseeds, - seed_count_per_voxel=False, - affine=np.eye(4))) +seed_mask = np.asarray( + utils.random_seeds_from_mask( + roi_data, + seeds_count=args.nseeds, + seed_count_per_voxel=False, + random_seed=args.seed_seed, + affine=np.eye(4), + ) +) # Setup model if args.sphere == "small": - sphere = small_sphere + sphere = small_sphere else: - sphere = default_sphere + sphere = default_sphere if args.model == "opdt": - if args.device == "cpu": - model = OpdtModel(gtab, sh_order=args.sh_order, smooth=args.sm_lambda, min_signal=args.min_signal) - dg = cpu_BootDirectionGetter - else: - dg = BootDirectionGetter.from_dipy_opdt( - gtab, - sphere, - sh_order_max=args.sh_order, - sh_lambda=args.sm_lambda, - min_signal=args.min_signal) + if args.device == "cpu": + model = OpdtModel( + gtab, + sh_order=args.sh_order, + smooth=args.sm_lambda, + min_signal=args.min_signal, + ) + dg = cpu_BootDirectionGetter + else: + dg = BootDirectionGetter.from_dipy_opdt( + gtab, + sphere, + sh_order_max=args.sh_order, + sh_lambda=args.sm_lambda, + min_signal=args.min_signal, + ) elif args.model == "csa": - if args.device == "cpu": - model = CsaOdfModel(gtab, sh_order=args.sh_order, smooth=args.sm_lambda, min_signal=args.min_signal) - dg = cpu_BootDirectionGetter - else: - dg = BootDirectionGetter.from_dipy_csa( - gtab, - sphere, - sh_order_max=args.sh_order, - sh_lambda=args.sm_lambda, - min_signal=args.min_signal) + if args.device == "cpu": + model = CsaOdfModel( + gtab, + sh_order=args.sh_order, + smooth=args.sm_lambda, + min_signal=args.min_signal, + ) + dg = cpu_BootDirectionGetter + else: + dg = BootDirectionGetter.from_dipy_csa( + gtab, + sphere, + sh_order_max=args.sh_order, + sh_lambda=args.sm_lambda, + min_signal=args.min_signal, + ) else: - print("Running CSD model...") - unique_bvals = unique_bvals_magnitude(gtab.bvals) - if 
len(unique_bvals[unique_bvals > 0]) > 1: - low_shell_idx = gtab.bvals <= unique_bvals[unique_bvals > 0][0] - response_gtab = gradient_table( # reinit as single shell for this CSD - gtab.bvals[low_shell_idx], - gtab.bvecs[low_shell_idx]) - data = data[..., low_shell_idx] - else: - response_gtab = gtab - response, _ = auto_response_ssst( - response_gtab, - data, - roi_radii=10, - fa_thr=0.7) - model = ConstrainedSphericalDeconvModel(response_gtab, response, sh_order=args.sh_order) - fit = model.fit(data, mask=(FA >= args.fa_threshold)) - data = fit.odf(sphere).clip(min=0) - if args.dg == "ptt": - if args.device == "cpu": - dg = cpu_PTTDirectionGetter() - else: - # Set FOD to 0 outside mask for probing - data[FA < args.fa_threshold, :] = 0 - dg = PttDirectionGetter() - elif args.dg == "prob": - if args.device == "cpu": - dg = cpu_ProbDirectionGetter() - else: - dg = ProbDirectionGetter() - else: - raise ValueError("Unknown direction getter type: {}".format(args.dg)) + csd_odf_cache_file = op.join(args.cache_dir, "csd_odf.npy") + if args.cache_dir != "" and op.exists(csd_odf_cache_file): + print("Loading CSD ODF from cache") + data = np.load(csd_odf_cache_file) + else: + print("Running CSD model...") + unique_bvals = unique_bvals_magnitude(gtab.bvals) + if len(unique_bvals[unique_bvals > 0]) > 1: + low_shell_idx = gtab.bvals <= unique_bvals[unique_bvals > 0][0] + response_gtab = gradient_table( # reinit as single shell for this CSD + gtab.bvals[low_shell_idx], gtab.bvecs[low_shell_idx] + ) + data = data[..., low_shell_idx] + else: + response_gtab = gtab + response, _ = auto_response_ssst(response_gtab, data, roi_radii=10, fa_thr=0.7) + model = ConstrainedSphericalDeconvModel( + response_gtab, response, sh_order=args.sh_order + ) + fit = model.fit(data, mask=(FA >= args.fa_threshold)) + data = fit.odf(sphere).clip(min=0) + + if args.cache_dir != "": + np.save(csd_odf_cache_file, data) + if args.dg == "ptt": + if args.device == "cpu": + dg = cpu_PTTDirectionGetter() + 
else: + # Set FOD to 0 outside mask for probing + data[FA < args.fa_threshold, :] = 0 + dg = PttDirectionGetter() + elif args.dg == "prob": + if args.device == "cpu": + dg = cpu_ProbDirectionGetter() + else: + dg = ProbDirectionGetter() + else: + raise ValueError("Unknown direction getter type: {}".format(args.dg)) # Setup direction getter args if args.device == "cpu": - if args.dg != "boot": - dg = dg.from_pmf(data, max_angle=args.max_angle, sphere=sphere, relative_peak_threshold=args.relative_peak_threshold, min_separation_angle=args.min_separation_angle) - else: - dg = dg.from_data(data, model, max_angle=args.max_angle, sphere=sphere, sh_order=args.sh_order, relative_peak_threshold=args.relative_peak_threshold, min_separation_angle=args.min_separation_angle) + if args.dg != "boot": + dg = dg.from_pmf( + data, + max_angle=args.max_angle, + sphere=sphere, + relative_peak_threshold=args.relative_peak_threshold, + min_separation_angle=args.min_separation_angle, + ) + else: + dg = dg.from_data( + data, + model, + max_angle=args.max_angle, + sphere=sphere, + sh_order=args.sh_order, + relative_peak_threshold=args.relative_peak_threshold, + min_separation_angle=args.min_separation_angle, + ) - ts = time.time() - streamline_generator = LocalTracking(dg, tissue_classifier, seed_mask, affine=np.eye(4), step_size=args.step_size) - sft = StatefulTractogram(streamline_generator, img, Space.VOX) - n_sls = len(sft.streamlines) - te = time.time() + ts = time.time() + streamline_generator = LocalTracking( + dg, tissue_classifier, seed_mask, affine=np.eye(4), step_size=args.step_size + ) + sft = StatefulTractogram(streamline_generator, img, Space.VOX) + n_sls = len(sft.streamlines) + te = time.time() else: with GPUTracker( dg, @@ -269,13 +409,13 @@ def get_img(ep2_seq): args.fa_threshold, sphere.vertices, sphere.edges, - max_angle=args.max_angle * np.pi/180, + max_angle=args.max_angle * np.pi / 180, step_size=args.step_size, relative_peak_thresh=args.relative_peak_threshold, - 
min_separation_angle=args.min_separation_angle * np.pi/180, + min_separation_angle=args.min_separation_angle * np.pi / 180, ngpus=args.ngpus, rng_seed=0, - chunk_size=args.chunk_size + chunk_size=args.chunk_size, ) as gpu_tracker: ts = time.time() if args.output_prefix and write_method == "trx": @@ -285,14 +425,16 @@ def get_img(ep2_seq): sft = gpu_tracker.generate_sft(seed_mask, img) n_sls = len(sft.streamlines) te = time.time() -print("Generated {} streamlines from {} seeds, time: {} s".format(n_sls, - seed_mask.shape[0], - te-ts)) +print( + "Generated {} streamlines from {} seeds, time: {} s".format( + n_sls, seed_mask.shape[0], te - ts + ) +) if args.output_prefix: - if write_method == "trx": - fname = "{}.trx".format(args.output_prefix) - save_trx(trx_file, fname) - else: - fname = "{}.trk".format(args.output_prefix) - save_tractogram(sft, fname) + if write_method == "trx": + fname = "{}.trx".format(args.output_prefix) + save_trx(trx_file, fname) + else: + fname = "{}.trk".format(args.output_prefix) + save_tractogram(sft, fname) diff --git a/setup.py b/setup.py deleted file mode 100644 index b3cd873..0000000 --- a/setup.py +++ /dev/null @@ -1,53 +0,0 @@ -from setuptools import setup -from setuptools.command.build_py import build_py -from pathlib import Path -import os.path as op -import re - - -def defines_to_python(src, dst): - root = Path(__file__).parent - - src = root / src - dst = root / dst - - INT_DEFINE = re.compile( - r"#define\s+(\w+)\s+\(?\s*([0-9]+)\s*\)?" 
- ) - - REAL_CAST_DEFINE = re.compile( - r"#define\s+(\w+)\s+\(\(REAL\)\s*([0-9eE\.\+\-]+)\s*\)" - ) - - defines = {} - - for line in src.read_text().splitlines(): - if m := INT_DEFINE.match(line): - defines[m.group(1)] = int(m.group(2)) - elif m := REAL_CAST_DEFINE.match(line): - defines[m.group(1)] = float(m.group(2)) - - dst.parent.mkdir(parents=True, exist_ok=True) - - with dst.open("w") as f: - f.write("# AUTO-GENERATED FROM globals.h — DO NOT EDIT\n\n") - for k, v in sorted(defines.items()): - f.write(f"{k} = {v}\n") - -class build_py_with_cuda(build_py): - def run(self): - globals_src = op.join("cuslines", "cuda_c", "globals.h") - globals_dst = op.join("cuslines", "cuda_python", "_globals.py") - defines_to_python(globals_src, globals_dst) - - super().run() - -setup( - cmdclass={"build_py": build_py_with_cuda}, - package_data={ - "cuslines": ["cuda_c/*", "metal_shaders/*", "wgsl_shaders/*"], - }, - project_urls={ - "Homepage": "https://github.com/dipy/GPUStreamlines", - } -)