From a34b8236dfc7dcac7170ae25e0fd19a2cedf6956 Mon Sep 17 00:00:00 2001 From: Josh Williams Date: Tue, 1 Jul 2025 10:21:10 +0100 Subject: [PATCH 1/5] Add transpose support convolution to distconv --- distconv/distconv.py | 90 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 14 deletions(-) diff --git a/distconv/distconv.py b/distconv/distconv.py index 7e9ed6b..c4662e9 100644 --- a/distconv/distconv.py +++ b/distconv/distconv.py @@ -169,7 +169,7 @@ def check_is_distconv_supported( if kernel_size % 2 == 1: if (kernel_size // 2) != padding[shard_dim]: raise Exception( - 'DistConv: when kernel size is odd, padding must be equivalent to "same"' + f'DistConv: when kernel size is odd ({kernel_size}), padding must be equivalent to "same" but found ({padding})' ) else: if padding[shard_dim] != 0: @@ -383,18 +383,31 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": args = list(args) # Unpack the necessary arguments - tensor, weight, bias, stride, padding, dilation = args[:6] + tensor, weight, bias, stride, padding, dilation, transpose, output_padding = args[ + :8 + ] + padding_orig = copy(padding) # Extract the parallel strategy and shard dimension from the input tensor parallel_strategy = tensor._parallel_strategy shard_dim = parallel_strategy.shard_dim + shard_ind = parallel_strategy.shard_ind + world_size = parallel_strategy.world_size + kernel_size = weight.size(shard_dim) is_periodic = tensor._is_periodic for i, shard_dim_i in enumerate(shard_dim): if is_periodic[i]: - assert padding[shard_dim_i - 2] == 0, ( - "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" - ) - padding[shard_dim_i - 2] = tensor._periodic_shard_padding[i] + if transpose: + assert padding[shard_dim_i - 2] == dilation[shard_dim_i - 2] * ( + kernel_size - 1 + ), f"padding is incorrect" + padding[shard_dim_i - 2] = tensor._periodic_shard_padding[i] + padding_orig[shard_dim_i - 2] = tensor._periodic_shard_padding[i] + 
else: + assert padding[shard_dim_i - 2] == 0, ( + "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" + ) + padding[shard_dim_i - 2] = tensor._periodic_shard_padding[i] # Unwrap the underlying tensor from the DCTensor torch_tensor = tensor._tensor @@ -420,10 +433,24 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": # Save the tensor with its halo for the backward pass. tensor._tensor_with_halo = tensor_with_halo + if transpose: + padding[shard_dim_i - 2] = dilation[shard_dim_i - 2] * (kernel_size - 1) + for dim_i in range(tensor.ndim - 2): + if dim_i + 2 == shard_dim_i: + padding[shard_dim_i - 2] += (kernel_size - 1 - padding_orig[dim_i]) * ( + stride[dim_i] - 1 + ) + # modify output_padding for strided transpose convolution + if shard_ind < world_size - 1: + output_padding[dim_i] += ( + tensor.size(shard_dim_i) + 2 * padding_orig[dim_i] - kernel_size + ) % stride[dim_i] + else: + padding[shard_dim_i - 2] = 0 # Update the arguments with the tensor including halos and adjusted padding args[0] = tensor_with_halo - padding[shard_dim_i - 2] = 0 args[4] = padding + args[7] = output_padding tensor._tensor = tensor_with_halo for i, shard_dim_i in enumerate(shard_dim): @@ -456,20 +483,40 @@ def distconv_backward( args = list(args) # Unpack the necessary arguments - grad_out_tensor, input_tensor, weight, bias_size, stride, padding, dilation = args[ - :7 - ] + ( + grad_out_tensor, + input_tensor, + weight, + bias_size, + stride, + padding, + dilation, + transpose, + output_padding, + ) = args[:9] + padding_orig = copy(padding) # Extract the parallel strategy and shard dimension from the gradient output tensor parallel_strategy = grad_out_tensor._parallel_strategy shard_dim = parallel_strategy.shard_dim + shard_ind = parallel_strategy.shard_ind + world_size = parallel_strategy.world_size is_periodic = input_tensor._is_periodic + padding_orig = [0] * len(padding) for i, shard_dim_i in enumerate(shard_dim): + kernel_size 
= weight.size(shard_dim_i) if is_periodic[i]: - assert padding[shard_dim_i - 2] == 0, ( - "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" - ) - padding[shard_dim_i - 2] = input_tensor._periodic_shard_padding[i] + if transpose: + assert padding[shard_dim_i - 2] == dilation[shard_dim_i - 2] * ( + kernel_size - 1 + ), f"shard-dim padding incorrect" + padding[shard_dim_i - 2] = input_tensor._periodic_shard_padding[i] + padding_orig[shard_dim_i - 2] = input_tensor._periodic_shard_padding[i] + else: + assert padding[shard_dim_i - 2] == 0, ( + "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" + ) + padding[shard_dim_i - 2] = input_tensor._periodic_shard_padding[i] # Unwrap the underlying tensors from the DCTensors grad_out_tensor = grad_out_tensor._tensor @@ -487,6 +534,20 @@ def distconv_backward( halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0 halo_sizes.append(halo_size) padding[shard_dim_i - 2] = 0 + if transpose: + padding[shard_dim_i - 2] = dilation[shard_dim_i - 2] * (kernel_size - 1) + for dim_i in range(input_tensor.ndim - 2): + if dim_i + 2 == shard_dim_i: + padding[dim_i] += padding_orig[dim_i] * (stride[dim_i] - 1) + if shard_ind < world_size - 1: + crop_amount = ( + input_tensor.size(shard_dim_i) + + 2 * padding_orig[dim_i] + - kernel_size + ) % stride[dim_i] + output_padding[dim_i] = crop_amount + else: + padding[shard_dim_i - 2] = 0 # Get the input tensor including halos if available, otherwise perform forward halo exchange if input_tensor._tensor_with_halo is not None: @@ -506,6 +567,7 @@ def distconv_backward( args[0] = grad_out_tensor args[1] = input_tensor_with_halo args[5] = padding + args[8] = output_padding # Perform the backward convolution operation grad_in_tensor, grad_weight, grad_bias = func(*args, **kwargs) From 3116a4dfe869abd823bcb1eae429c11d356f6434 Mon Sep 17 00:00:00 2001 From: Josh Williams Date: Tue, 1 Jul 2025 10:21:29 +0100 Subject: [PATCH 2/5] Add transpose 
convolution unit test for zero and periodic padding --- tests/test_convtranspose.py | 189 ++++++++++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 tests/test_convtranspose.py diff --git a/tests/test_convtranspose.py b/tests/test_convtranspose.py new file mode 100644 index 0000000..29f059d --- /dev/null +++ b/tests/test_convtranspose.py @@ -0,0 +1,189 @@ +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from utils import cleanup_parallel_strategy, fp32_allclose + +from distconv import DCTensor, DistConvDDP, ParallelStrategy + + +def all_gather_vlen(tensor: torch.Tensor, group=None, dim=0) -> list[torch.Tensor]: + """Gather tensors with the same number of dimensions but different lengths. + + Credit: https://stackoverflow.com/a/78934638 + """ + world_size = dist.get_world_size(group=group) + # Gather lengths first + shape = torch.as_tensor(tensor.shape, device=tensor.device) + shapes = [torch.empty_like(shape) for _ in range(world_size)] + dist.all_gather(shapes, shape, group=group) + # Gather data + inputs = [tensor] * world_size + outputs = [ + torch.empty(*_shape, dtype=tensor.dtype, device=tensor.device) + for _shape in shapes + ] + dist.all_to_all(outputs, inputs, group=group) + return torch.cat(outputs, dim=dim) + + +@pytest.fixture(scope="module") +def parallel_strategy(device: torch.device): + ps = ParallelStrategy(num_shards=4, device_type=device.type) + yield ps + cleanup_parallel_strategy(ps) + + +def find_padding(kernel_size): + if kernel_size % 2 != 0: + return kernel_size // 2 + else: + return 0 + + +def generate_configs(): + configs = [] + for ndims in [1, 2, 3]: + for shard_dim in range(ndims): + for kernel_size in [1, 3, 5]: + for stride in [1, 2, 4]: + configs.append((ndims, shard_dim, kernel_size, stride)) + + return "ndims,shard_dim,kernel_size,stride", configs + + +@pytest.mark.parametrize(*generate_configs()) +def test_transposeconv_zerospadding( + parallel_strategy: 
ParallelStrategy, + ndims: int, + shard_dim: int, + kernel_size: int, + padding: int, + stride: int, + device: torch.device, +): + """ + Test distributed convolution with different number of dimensions, kernel sizes, and strides. + Checks the output and gradients of the distributed convolution against the non-distributed + convolution. + + Args: + parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution. + ndims (int): Number of dimensions for the convolution (1, 2, or 3). + shard_dim (int): Dimension along which the tensor is sharded. + kernel_size (int): Size of the convolution kernel. + stride (int): Stride of the convolution. + device (torch.device): Torch device to run test with. + """ + # Set the shard dimension for the parallel strategy + parallel_strategy.shard_dim = 2 + shard_dim + padding = find_padding(kernel_size) + + # Initialize the input tensor and convolution layer + shape = [1, 4] + [64] * ndims + x = torch.randn(*shape, device=device, requires_grad=True) + conv_class = getattr(nn, f"ConvTranspose{ndims}d") + conv = conv_class(4, 8, kernel_size=kernel_size, padding=padding, stride=stride).to( + device + ) + + # Perform forward and backward pass for reference (non-distributed) convolution + conv.zero_grad() + ref_y = conv(x) + ref_y.sum().backward() + ref_x_grad = x.grad + ref_conv_grad = conv.weight.grad + + # Perform forward and backward pass for distributed convolution + conv.zero_grad() + dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) + dcx = DCTensor.distribute(x, parallel_strategy) + dcy = dist_conv(dcx) + dcy_merge = all_gather_vlen(dcy, dim=(parallel_strategy.shard_dim)) + dc_loss = dcy.sum() + dist.all_reduce(dc_loss) + dc_loss.backward() + x_grad = dcx.grad.to_replicate() + dc_conv_grad = conv.weight.grad + + assert fp32_allclose(ref_y, dcy_merge) + assert fp32_allclose(ref_x_grad, x_grad) + assert fp32_allclose(ref_conv_grad, dc_conv_grad) + + 
+@pytest.mark.parametrize(*generate_configs()) +def test_transposeconv_circularpadding( + parallel_strategy: ParallelStrategy, + ndims: int, + shard_dim: int, + kernel_size: int, + stride: int, + device: torch.device, +): + """ + Test distributed convolution with different number of dimensions, kernel sizes, and strides. + Checks the output and gradients of the distributed convolution against the non-distributed + convolution. + + Args: + parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution. + ndims (int): Number of dimensions for the convolution (1, 2, or 3). + shard_dim (int): Dimension along which the tensor is sharded. + kernel_size (int): Size of the convolution kernel. + stride (int): Stride of the convolution. + device (torch.device): Torch device to run test with. + """ + # Set the shard dimension for the parallel strategy + parallel_strategy.shard_dim = 2 + shard_dim + padding = find_padding(kernel_size) + + # Initialize the input tensor and convolution layer + shape = [1, 4] + [64] * ndims + x = torch.randn(*shape, device=device, requires_grad=True) + + conv_kwargs = dict(kernel_size=kernel_size, stride=stride) + + # set periodic padding case for reference + new_padding = [padding, padding] * ndims + x_periodic = torch.nn.functional.pad(input=x, pad=new_padding, mode="circular") + ref_padding = kernel_size - 1 + + conv_class = getattr(nn, f"ConvTranspose{ndims}d") + conv = ( + conv_class(4, 8, padding=ref_padding, **conv_kwargs) + .to(device) + .requires_grad_(False) + ) + conv.requires_grad_(True) + + # Perform forward and backward pass for reference (non-distributed) convolution + conv.zero_grad() + ref_y = conv(x_periodic) + for i in range(0, ndims): + crop_amount = (kernel_size - 1 - padding) * (stride - 1) + ref_y = ref_y.narrow(i + 2, crop_amount, ref_y.shape[i + 2] - 2 * crop_amount) + ref_y.sum().backward() + ref_x_grad = x.grad + ref_conv_grad = conv.weight.grad + + # Perform forward and backward pass for 
distributed convolution + conv.zero_grad() + dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) + dcx = DCTensor.distribute(x, parallel_strategy) + dcx_periodic = torch.nn.functional.pad(input=dcx, pad=new_padding, mode="circular") + dcy = dist_conv(dcx_periodic) + for i in range(0, ndims): + if i != shard_dim: + crop_amount = (kernel_size - 1 - padding) * (stride - 1) + dcy = dcy.narrow(i + 2, crop_amount, dcy.shape[i + 2] - 2 * crop_amount) + dcy_merge = all_gather_vlen(dcy.contiguous(), dim=(parallel_strategy.shard_dim)) + dc_loss = dcy.sum() + dist.all_reduce(dc_loss) + dc_loss.backward() + x_grad = dcx.grad.to_replicate() + dc_conv_grad = conv.weight.grad + + # Validate the results + assert fp32_allclose(ref_y, dcy_merge) + assert fp32_allclose(ref_x_grad, x_grad) + assert fp32_allclose(ref_conv_grad, dc_conv_grad) From 95690981a63f8fc9f8eccbe1365117b6b5339114 Mon Sep 17 00:00:00 2001 From: Pier Fiedorowicz Date: Wed, 9 Jul 2025 16:37:20 -0700 Subject: [PATCH 3/5] Improve and simplify constraint checking and update tests Ruff formatting --- distconv/distconv.py | 114 +++++++++++++++----------------- tests/test_basic.py | 4 +- tests/test_convtranspose.py | 70 ++++++++++---------- tests/test_ddp_with_distconv.py | 4 +- tests/test_periodic.py | 8 ++- tests/test_strides.py | 4 +- tests/utils.py | 1 - 7 files changed, 100 insertions(+), 105 deletions(-) diff --git a/distconv/distconv.py b/distconv/distconv.py index c4662e9..27e1b36 100644 --- a/distconv/distconv.py +++ b/distconv/distconv.py @@ -141,6 +141,8 @@ def check_is_distconv_supported( stride: List[int], padding: List[int], dilation: List[int], + transpose: bool, + output_padding: List[int], ) -> None: """ Check if the distributed convolution is supported with the given parameters. @@ -152,31 +154,42 @@ def check_is_distconv_supported( stride (List[int]): The stride of the convolution. padding (List[int]): The padding added to the input tensor. 
dilation (List[int]): The dilation applied to the kernel. + transpose (bool): Is transposed convolution. + output_padding (List[int]): The output padding for transposed convolution. Raises: - Exception: If dilation is not 1. - Exception: If input size is not divisible by stride. - Exception: If kernel size is odd and padding is not equivalent to "same". - Exception: If kernel size is even and padding is not zero. - Exception: If kernel size is even and stride is not divisible by kernel size. + Exception: If local input size is not equal to stride times output size. + Exception: If local output size is not equal to stride times input size for transposed convolution. """ shard_dim = tensor_shard_dim - 2 kernel_size = weight.size(tensor_shard_dim) if dilation[shard_dim] != 1: raise Exception("DistConv: dilation must be 1") - if tensor.size(tensor_shard_dim) % stride[shard_dim] != 0: - raise Exception("DistConv: input size must be divisible by stride") - if kernel_size % 2 == 1: - if (kernel_size // 2) != padding[shard_dim]: + + input_size = tensor.size(tensor_shard_dim) + + if not transpose: + output_size = (input_size + 2 * padding[shard_dim] - kernel_size) // stride[ + shard_dim + ] + 1 + + if output_size * stride[shard_dim] != input_size: raise Exception( - f'DistConv: when kernel size is odd ({kernel_size}), padding must be equivalent to "same" but found ({padding})' + "DistConv: The input size along the shard dimension must equal the stride times the output size for the local tensors.\n" + "This indicates incompatible kernel size, stride, and/or padding for the given input shape and parallel strategy." 
) else: - if padding[shard_dim] != 0: - raise Exception("DistConv: when kernel size is even, padding must be zero") - if stride[shard_dim] % kernel_size != 0: + output_size = ( + (input_size - 1) * stride[shard_dim] + - 2 * padding[shard_dim] + + kernel_size + + output_padding[shard_dim] + ) + + if output_size != input_size * stride[shard_dim]: raise Exception( - "DistConv: when kernel size is even, stride must be divisble by kernel size" + "DistConv: The output size along the shard dimension must equal the stride times the input size for the local tensors.\n" + + "This indicates incompatible kernel size, stride, padding, and/or output padding for the given input shape and parallel strategy." ) @@ -386,23 +399,17 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": tensor, weight, bias, stride, padding, dilation, transpose, output_padding = args[ :8 ] - padding_orig = copy(padding) # Extract the parallel strategy and shard dimension from the input tensor parallel_strategy = tensor._parallel_strategy shard_dim = parallel_strategy.shard_dim - shard_ind = parallel_strategy.shard_ind - world_size = parallel_strategy.world_size - kernel_size = weight.size(shard_dim) is_periodic = tensor._is_periodic for i, shard_dim_i in enumerate(shard_dim): if is_periodic[i]: if transpose: - assert padding[shard_dim_i - 2] == dilation[shard_dim_i - 2] * ( - kernel_size - 1 - ), f"padding is incorrect" - padding[shard_dim_i - 2] = tensor._periodic_shard_padding[i] - padding_orig[shard_dim_i - 2] = tensor._periodic_shard_padding[i] + padding[shard_dim_i - 2] -= ( + stride[shard_dim_i - 2] * tensor._periodic_shard_padding[i] + ) else: assert padding[shard_dim_i - 2] == 0, ( "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" @@ -417,7 +424,14 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": halo_sizes = [] for i, shard_dim_i in enumerate(shard_dim): check_is_distconv_supported( - shard_dim_i, 
torch_tensor, weight, stride, padding, dilation + shard_dim_i, + torch_tensor, + weight, + stride, + padding, + dilation, + transpose, + output_padding, ) # Determine the halo size for halo exchange @@ -434,17 +448,7 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": tensor._tensor_with_halo = tensor_with_halo if transpose: - padding[shard_dim_i - 2] = dilation[shard_dim_i - 2] * (kernel_size - 1) - for dim_i in range(tensor.ndim - 2): - if dim_i + 2 == shard_dim_i: - padding[shard_dim_i - 2] += (kernel_size - 1 - padding_orig[dim_i]) * ( - stride[dim_i] - 1 - ) - # modify output_padding for strided transpose convolution - if shard_ind < world_size - 1: - output_padding[dim_i] += ( - tensor.size(shard_dim_i) + 2 * padding_orig[dim_i] - kernel_size - ) % stride[dim_i] + padding[shard_dim_i - 2] += stride[shard_dim_i - 2] * halo_size else: padding[shard_dim_i - 2] = 0 # Update the arguments with the tensor including halos and adjusted padding @@ -494,24 +498,17 @@ def distconv_backward( transpose, output_padding, ) = args[:9] - padding_orig = copy(padding) # Extract the parallel strategy and shard dimension from the gradient output tensor parallel_strategy = grad_out_tensor._parallel_strategy shard_dim = parallel_strategy.shard_dim - shard_ind = parallel_strategy.shard_ind - world_size = parallel_strategy.world_size is_periodic = input_tensor._is_periodic - padding_orig = [0] * len(padding) for i, shard_dim_i in enumerate(shard_dim): - kernel_size = weight.size(shard_dim_i) if is_periodic[i]: if transpose: - assert padding[shard_dim_i - 2] == dilation[shard_dim_i - 2] * ( - kernel_size - 1 - ), f"shard-dim padding incorrect" - padding[shard_dim_i - 2] = input_tensor._periodic_shard_padding[i] - padding_orig[shard_dim_i - 2] = input_tensor._periodic_shard_padding[i] + padding[shard_dim_i - 2] -= ( + stride[shard_dim_i - 2] * input_tensor._periodic_shard_padding[i] + ) else: assert padding[shard_dim_i - 2] == 0, ( "Cannot zero-pad a 
tensor marked for periodic padding on the shard dimension" @@ -525,27 +522,22 @@ def distconv_backward( # Check if the distributed convolution is supported with the given parameters halo_sizes = [] for i, shard_dim_i in enumerate(shard_dim): - check_is_distconv_supported( - shard_dim_i, input_torch_tensor, weight, stride, padding, dilation - ) - # Determine the halo size for halo exchange kernel_size = weight.size(shard_dim_i) halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0 halo_sizes.append(halo_size) - padding[shard_dim_i - 2] = 0 + check_is_distconv_supported( + shard_dim_i, + input_torch_tensor, + weight, + stride, + padding, + dilation, + transpose, + output_padding, + ) if transpose: - padding[shard_dim_i - 2] = dilation[shard_dim_i - 2] * (kernel_size - 1) - for dim_i in range(input_tensor.ndim - 2): - if dim_i + 2 == shard_dim_i: - padding[dim_i] += padding_orig[dim_i] * (stride[dim_i] - 1) - if shard_ind < world_size - 1: - crop_amount = ( - input_tensor.size(shard_dim_i) - + 2 * padding_orig[dim_i] - - kernel_size - ) % stride[dim_i] - output_padding[dim_i] = crop_amount + padding[shard_dim_i - 2] += stride[shard_dim_i - 2] * halo_size else: padding[shard_dim_i - 2] = 0 diff --git a/tests/test_basic.py b/tests/test_basic.py index d0748f5..9322cc1 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -2,10 +2,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from utils import cleanup_parallel_strategy, fp32_allclose - from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose + @pytest.fixture(scope="module") def parallel_strategy(device: torch.device): diff --git a/tests/test_convtranspose.py b/tests/test_convtranspose.py index 29f059d..f353955 100644 --- a/tests/test_convtranspose.py +++ b/tests/test_convtranspose.py @@ -2,10 +2,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from utils import cleanup_parallel_strategy, 
fp32_allclose - from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose + def all_gather_vlen(tensor: torch.Tensor, group=None, dim=0) -> list[torch.Tensor]: """Gather tensors with the same number of dimensions but different lengths. @@ -34,11 +34,13 @@ def parallel_strategy(device: torch.device): cleanup_parallel_strategy(ps) -def find_padding(kernel_size): - if kernel_size % 2 != 0: - return kernel_size // 2 - else: - return 0 +def find_padding(kernel_size, stride=1, explicit_padding=False): + ep = kernel_size // 2 if explicit_padding else 0 + pad = (kernel_size + 2 * ep * stride - 1) // 2 + out_pad = stride - 1 + if explicit_padding: + return pad, out_pad, ep + return pad, out_pad def generate_configs(): @@ -58,7 +60,6 @@ def test_transposeconv_zerospadding( ndims: int, shard_dim: int, kernel_size: int, - padding: int, stride: int, device: torch.device, ): @@ -77,20 +78,25 @@ def test_transposeconv_zerospadding( """ # Set the shard dimension for the parallel strategy parallel_strategy.shard_dim = 2 + shard_dim - padding = find_padding(kernel_size) + padding, output_padding = find_padding(kernel_size, stride) # Initialize the input tensor and convolution layer shape = [1, 4] + [64] * ndims x = torch.randn(*shape, device=device, requires_grad=True) conv_class = getattr(nn, f"ConvTranspose{ndims}d") - conv = conv_class(4, 8, kernel_size=kernel_size, padding=padding, stride=stride).to( - device - ) + conv = conv_class( + 4, + 8, + kernel_size=kernel_size, + padding=padding, + stride=stride, + output_padding=output_padding, + ).to(device) # Perform forward and backward pass for reference (non-distributed) convolution conv.zero_grad() ref_y = conv(x) - ref_y.sum().backward() + ref_y.square().mean().backward() ref_x_grad = x.grad ref_conv_grad = conv.weight.grad @@ -99,9 +105,8 @@ def test_transposeconv_zerospadding( dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) dcx = 
DCTensor.distribute(x, parallel_strategy) dcy = dist_conv(dcx) - dcy_merge = all_gather_vlen(dcy, dim=(parallel_strategy.shard_dim)) - dc_loss = dcy.sum() - dist.all_reduce(dc_loss) + dcy_merge = dcy.to_replicate() + dc_loss = dcy.to_ddp().square().mean() dc_loss.backward() x_grad = dcx.grad.to_replicate() dc_conv_grad = conv.weight.grad @@ -135,22 +140,25 @@ def test_transposeconv_circularpadding( """ # Set the shard dimension for the parallel strategy parallel_strategy.shard_dim = 2 + shard_dim - padding = find_padding(kernel_size) + padding, output_padding, explicit_padding = find_padding( + kernel_size, stride, explicit_padding=True + ) # Initialize the input tensor and convolution layer shape = [1, 4] + [64] * ndims x = torch.randn(*shape, device=device, requires_grad=True) - conv_kwargs = dict(kernel_size=kernel_size, stride=stride) + conv_kwargs = dict( + kernel_size=kernel_size, stride=stride, output_padding=output_padding + ) # set periodic padding case for reference - new_padding = [padding, padding] * ndims - x_periodic = torch.nn.functional.pad(input=x, pad=new_padding, mode="circular") - ref_padding = kernel_size - 1 + explicit_padding = [explicit_padding, explicit_padding] * ndims + x_periodic = torch.nn.functional.pad(input=x, pad=explicit_padding, mode="circular") conv_class = getattr(nn, f"ConvTranspose{ndims}d") conv = ( - conv_class(4, 8, padding=ref_padding, **conv_kwargs) + conv_class(4, 8, padding=padding, **conv_kwargs) .to(device) .requires_grad_(False) ) @@ -159,10 +167,7 @@ def test_transposeconv_circularpadding( # Perform forward and backward pass for reference (non-distributed) convolution conv.zero_grad() ref_y = conv(x_periodic) - for i in range(0, ndims): - crop_amount = (kernel_size - 1 - padding) * (stride - 1) - ref_y = ref_y.narrow(i + 2, crop_amount, ref_y.shape[i + 2] - 2 * crop_amount) - ref_y.sum().backward() + ref_y.square().mean().backward() ref_x_grad = x.grad ref_conv_grad = conv.weight.grad @@ -170,15 +175,12 @@ def 
test_transposeconv_circularpadding( conv.zero_grad() dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) dcx = DCTensor.distribute(x, parallel_strategy) - dcx_periodic = torch.nn.functional.pad(input=dcx, pad=new_padding, mode="circular") + dcx_periodic = torch.nn.functional.pad( + input=dcx, pad=explicit_padding, mode="circular" + ) dcy = dist_conv(dcx_periodic) - for i in range(0, ndims): - if i != shard_dim: - crop_amount = (kernel_size - 1 - padding) * (stride - 1) - dcy = dcy.narrow(i + 2, crop_amount, dcy.shape[i + 2] - 2 * crop_amount) - dcy_merge = all_gather_vlen(dcy.contiguous(), dim=(parallel_strategy.shard_dim)) - dc_loss = dcy.sum() - dist.all_reduce(dc_loss) + dcy_merge = dcy.to_replicate() + dc_loss = dcy.to_ddp().square().mean() dc_loss.backward() x_grad = dcx.grad.to_replicate() dc_conv_grad = conv.weight.grad diff --git a/tests/test_ddp_with_distconv.py b/tests/test_ddp_with_distconv.py index f683d21..189c38e 100644 --- a/tests/test_ddp_with_distconv.py +++ b/tests/test_ddp_with_distconv.py @@ -2,11 +2,11 @@ import torch import torch.distributed as dist import torch.nn as nn +from distconv import DCTensor, DistConvDDP, ParallelStrategy from torch.distributed.tensor import Replicate, Shard, distribute_tensor from torch.nn.parallel import DistributedDataParallel as DDP -from utils import cleanup_parallel_strategy, fp32_allclose -from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose @pytest.fixture(scope="module") diff --git a/tests/test_periodic.py b/tests/test_periodic.py index af6e309..9d5dd14 100644 --- a/tests/test_periodic.py +++ b/tests/test_periodic.py @@ -1,11 +1,13 @@ +from math import ceil + import pytest import torch import torch.distributed as dist import torch.nn as nn -from utils import cleanup_parallel_strategy, fp32_allclose - from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose + def 
generate_configs(): configs = [] @@ -50,7 +52,7 @@ def test_periodic( conv_kwargs = dict( kernel_size=kernel_size, - padding=kernel_size // 2, + padding=ceil((kernel_size - stride) / 2), bias=False, stride=stride, padding_mode="circular", diff --git a/tests/test_strides.py b/tests/test_strides.py index efa82a6..d9e421e 100644 --- a/tests/test_strides.py +++ b/tests/test_strides.py @@ -2,10 +2,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from utils import cleanup_parallel_strategy, fp32_allclose - from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose + @pytest.fixture(scope="module") def parallel_strategy(device: torch.device): diff --git a/tests/utils.py b/tests/utils.py index 891254b..4e6fb39 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,5 @@ import torch import torch.distributed as dist - from distconv import ParallelStrategy From 773a5c7eb115354a28faf1c859b47886fff8e7c9 Mon Sep 17 00:00:00 2001 From: Josh Williams Date: Wed, 25 Mar 2026 11:47:12 +0000 Subject: [PATCH 4/5] Remove unused gather function in convtranspose test --- tests/test_convtranspose.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/test_convtranspose.py b/tests/test_convtranspose.py index f353955..54b5cce 100644 --- a/tests/test_convtranspose.py +++ b/tests/test_convtranspose.py @@ -7,25 +7,6 @@ from utils import cleanup_parallel_strategy, fp32_allclose -def all_gather_vlen(tensor: torch.Tensor, group=None, dim=0) -> list[torch.Tensor]: - """Gather tensors with the same number of dimensions but different lengths. 
- - Credit: https://stackoverflow.com/a/78934638 - """ - world_size = dist.get_world_size(group=group) - # Gather lengths first - shape = torch.as_tensor(tensor.shape, device=tensor.device) - shapes = [torch.empty_like(shape) for _ in range(world_size)] - dist.all_gather(shapes, shape, group=group) - # Gather data - inputs = [tensor] * world_size - outputs = [ - torch.empty(*_shape, dtype=tensor.dtype, device=tensor.device) - for _shape in shapes - ] - dist.all_to_all(outputs, inputs, group=group) - return torch.cat(outputs, dim=dim) - @pytest.fixture(scope="module") def parallel_strategy(device: torch.device): From f79e4634742eb36c6ef0d94006a8416f3fcc99e2 Mon Sep 17 00:00:00 2001 From: Josh Williams Date: Wed, 25 Mar 2026 11:51:30 +0000 Subject: [PATCH 5/5] Reformat test imports --- tests/test_basic.py | 4 ++-- tests/test_convtranspose.py | 4 +--- tests/test_ddp_with_distconv.py | 4 ++-- tests/test_periodic.py | 4 ++-- tests/test_strides.py | 4 ++-- tests/utils.py | 1 + 6 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index 9322cc1..d0748f5 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -2,10 +2,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from distconv import DCTensor, DistConvDDP, ParallelStrategy - from utils import cleanup_parallel_strategy, fp32_allclose +from distconv import DCTensor, DistConvDDP, ParallelStrategy + @pytest.fixture(scope="module") def parallel_strategy(device: torch.device): diff --git a/tests/test_convtranspose.py b/tests/test_convtranspose.py index 54b5cce..202a10d 100644 --- a/tests/test_convtranspose.py +++ b/tests/test_convtranspose.py @@ -1,11 +1,9 @@ import pytest import torch -import torch.distributed as dist import torch.nn as nn -from distconv import DCTensor, DistConvDDP, ParallelStrategy - from utils import cleanup_parallel_strategy, fp32_allclose +from distconv import DCTensor, DistConvDDP, ParallelStrategy 
@pytest.fixture(scope="module") diff --git a/tests/test_ddp_with_distconv.py b/tests/test_ddp_with_distconv.py index 189c38e..f683d21 100644 --- a/tests/test_ddp_with_distconv.py +++ b/tests/test_ddp_with_distconv.py @@ -2,12 +2,12 @@ import torch import torch.distributed as dist import torch.nn as nn -from distconv import DCTensor, DistConvDDP, ParallelStrategy from torch.distributed.tensor import Replicate, Shard, distribute_tensor from torch.nn.parallel import DistributedDataParallel as DDP - from utils import cleanup_parallel_strategy, fp32_allclose +from distconv import DCTensor, DistConvDDP, ParallelStrategy + @pytest.fixture(scope="module") def parallel_strategy(device: torch.device): diff --git a/tests/test_periodic.py b/tests/test_periodic.py index 9d5dd14..9e13a88 100644 --- a/tests/test_periodic.py +++ b/tests/test_periodic.py @@ -4,10 +4,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from distconv import DCTensor, DistConvDDP, ParallelStrategy - from utils import cleanup_parallel_strategy, fp32_allclose +from distconv import DCTensor, DistConvDDP, ParallelStrategy + def generate_configs(): configs = [] diff --git a/tests/test_strides.py b/tests/test_strides.py index d9e421e..efa82a6 100644 --- a/tests/test_strides.py +++ b/tests/test_strides.py @@ -2,10 +2,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from distconv import DCTensor, DistConvDDP, ParallelStrategy - from utils import cleanup_parallel_strategy, fp32_allclose +from distconv import DCTensor, DistConvDDP, ParallelStrategy + @pytest.fixture(scope="module") def parallel_strategy(device: torch.device): diff --git a/tests/utils.py b/tests/utils.py index 4e6fb39..891254b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist + from distconv import ParallelStrategy