Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 87 additions & 33 deletions distconv/distconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def check_is_distconv_supported(
stride: List[int],
padding: List[int],
dilation: List[int],
transpose: bool,
output_padding: List[int],
) -> None:
"""
Check if the distributed convolution is supported with the given parameters.
Expand All @@ -152,31 +154,42 @@ def check_is_distconv_supported(
stride (List[int]): The stride of the convolution.
padding (List[int]): The padding added to the input tensor.
dilation (List[int]): The dilation applied to the kernel.
transpose (bool): Whether the convolution is a transposed convolution.
output_padding (List[int]): The output padding for transposed convolution.

Raises:
Exception: If dilation is not 1.
Exception: If input size is not divisible by stride.
Exception: If kernel size is odd and padding is not equivalent to "same".
Exception: If kernel size is even and padding is not zero.
Exception: If kernel size is even and stride is not divisible by kernel size.
Exception: If local input size is not equal to stride times output size.
Exception: If local output size is not equal to stride times input size for transposed convolution.
"""
shard_dim = tensor_shard_dim - 2
kernel_size = weight.size(tensor_shard_dim)
if dilation[shard_dim] != 1:
raise Exception("DistConv: dilation must be 1")
if tensor.size(tensor_shard_dim) % stride[shard_dim] != 0:
raise Exception("DistConv: input size must be divisible by stride")
if kernel_size % 2 == 1:
if (kernel_size // 2) != padding[shard_dim]:

input_size = tensor.size(tensor_shard_dim)

if not transpose:
output_size = (input_size + 2 * padding[shard_dim] - kernel_size) // stride[
shard_dim
] + 1

if output_size * stride[shard_dim] != input_size:
raise Exception(
'DistConv: when kernel size is odd, padding must be equivalent to "same"'
"DistConv: The input size along the shard dimension must equal the stride times the output size for the local tensors.\n"
+ "This indicates incompatible kernel size, stride, and/or padding for the given input shape and parallel strategy."
)
else:
if padding[shard_dim] != 0:
raise Exception("DistConv: when kernel size is even, padding must be zero")
if stride[shard_dim] % kernel_size != 0:
output_size = (
(input_size - 1) * stride[shard_dim]
- 2 * padding[shard_dim]
+ kernel_size
+ output_padding[shard_dim]
)

if output_size != input_size * stride[shard_dim]:
raise Exception(
"DistConv: when kernel size is even, stride must be divisble by kernel size"
"DistConv: The output size along the shard dimension must equal the stride times the input size for the local tensors.\n"
+ "This indicates incompatible kernel size, stride, padding, and/or output padding for the given input shape and parallel strategy."
)


Expand Down Expand Up @@ -383,18 +396,25 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
args = list(args)

# Unpack the necessary arguments
tensor, weight, bias, stride, padding, dilation = args[:6]
tensor, weight, bias, stride, padding, dilation, transpose, output_padding = args[
:8
]

# Extract the parallel strategy and shard dimension from the input tensor
parallel_strategy = tensor._parallel_strategy
shard_dim = parallel_strategy.shard_dim
is_periodic = tensor._is_periodic
for i, shard_dim_i in enumerate(shard_dim):
if is_periodic[i]:
assert padding[shard_dim_i - 2] == 0, (
"Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
)
padding[shard_dim_i - 2] = tensor._periodic_shard_padding[i]
if transpose:
padding[shard_dim_i - 2] -= (
stride[shard_dim_i - 2] * tensor._periodic_shard_padding[i]
)
else:
assert padding[shard_dim_i - 2] == 0, (
"Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
)
padding[shard_dim_i - 2] = tensor._periodic_shard_padding[i]

# Unwrap the underlying tensor from the DCTensor
torch_tensor = tensor._tensor
Expand All @@ -404,7 +424,14 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
halo_sizes = []
for i, shard_dim_i in enumerate(shard_dim):
check_is_distconv_supported(
shard_dim_i, torch_tensor, weight, stride, padding, dilation
shard_dim_i,
torch_tensor,
weight,
stride,
padding,
dilation,
transpose,
output_padding,
)

# Determine the halo size for halo exchange
Expand All @@ -420,10 +447,14 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
# Save the tensor with its halo for the backward pass.
tensor._tensor_with_halo = tensor_with_halo

if transpose:
padding[shard_dim_i - 2] += stride[shard_dim_i - 2] * halo_size
else:
padding[shard_dim_i - 2] = 0
# Update the arguments with the tensor including halos and adjusted padding
args[0] = tensor_with_halo
padding[shard_dim_i - 2] = 0
args[4] = padding
args[7] = output_padding

tensor._tensor = tensor_with_halo
for i, shard_dim_i in enumerate(shard_dim):
Expand Down Expand Up @@ -456,20 +487,33 @@ def distconv_backward(
args = list(args)

# Unpack the necessary arguments
grad_out_tensor, input_tensor, weight, bias_size, stride, padding, dilation = args[
:7
]
(
grad_out_tensor,
input_tensor,
weight,
bias_size,
stride,
padding,
dilation,
transpose,
output_padding,
) = args[:9]

# Extract the parallel strategy and shard dimension from the gradient output tensor
parallel_strategy = grad_out_tensor._parallel_strategy
shard_dim = parallel_strategy.shard_dim
is_periodic = input_tensor._is_periodic
for i, shard_dim_i in enumerate(shard_dim):
if is_periodic[i]:
assert padding[shard_dim_i - 2] == 0, (
"Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
)
padding[shard_dim_i - 2] = input_tensor._periodic_shard_padding[i]
if transpose:
padding[shard_dim_i - 2] -= (
stride[shard_dim_i - 2] * input_tensor._periodic_shard_padding[i]
)
else:
assert padding[shard_dim_i - 2] == 0, (
"Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
)
padding[shard_dim_i - 2] = input_tensor._periodic_shard_padding[i]

# Unwrap the underlying tensors from the DCTensors
grad_out_tensor = grad_out_tensor._tensor
Expand All @@ -478,15 +522,24 @@ def distconv_backward(
# Check if the distributed convolution is supported with the given parameters
halo_sizes = []
for i, shard_dim_i in enumerate(shard_dim):
check_is_distconv_supported(
shard_dim_i, input_torch_tensor, weight, stride, padding, dilation
)

# Determine the halo size for halo exchange
kernel_size = weight.size(shard_dim_i)
halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0
halo_sizes.append(halo_size)
padding[shard_dim_i - 2] = 0
check_is_distconv_supported(
shard_dim_i,
input_torch_tensor,
weight,
stride,
padding,
dilation,
transpose,
output_padding,
)
if transpose:
padding[shard_dim_i - 2] += stride[shard_dim_i - 2] * halo_size
else:
padding[shard_dim_i - 2] = 0

# Get the input tensor including halos if available, otherwise perform forward halo exchange
if input_tensor._tensor_with_halo is not None:
Expand All @@ -506,6 +559,7 @@ def distconv_backward(
args[0] = grad_out_tensor
args[1] = input_tensor_with_halo
args[5] = padding
args[8] = output_padding

# Perform the backward convolution operation
grad_in_tensor, grad_weight, grad_bias = func(*args, **kwargs)
Expand Down
170 changes: 170 additions & 0 deletions tests/test_convtranspose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import pytest
import torch
import torch.nn as nn
from utils import cleanup_parallel_strategy, fp32_allclose

from distconv import DCTensor, DistConvDDP, ParallelStrategy


@pytest.fixture(scope="module")
def parallel_strategy(device: torch.device):
    """Module-scoped fixture yielding a 4-shard ParallelStrategy on *device*.

    The strategy is cleaned up once all tests in the module have run.
    """
    strategy = ParallelStrategy(num_shards=4, device_type=device.type)
    yield strategy
    cleanup_parallel_strategy(strategy)


def find_padding(kernel_size, stride=1, explicit_padding=False):
    """Derive padding parameters for a transposed convolution.

    Computes the ``padding`` and ``output_padding`` arguments (and,
    optionally, an explicit pre-padding amount) for a transposed
    convolution with the given kernel size and stride.  For odd kernel
    sizes this choice makes the output size equal stride times the input
    size (the "same"-style configuration used by these tests).

    Args:
        kernel_size: Size of the convolution kernel.
        stride: Stride of the transposed convolution.
        explicit_padding: When True, account for explicit (e.g. circular)
            pre-padding of ``kernel_size // 2`` on each side and return
            that amount as a third element.

    Returns:
        ``(padding, output_padding)`` when ``explicit_padding`` is False,
        otherwise ``(padding, output_padding, explicit_pad)``.
    """
    explicit_pad = kernel_size // 2 if explicit_padding else 0
    padding = (kernel_size + 2 * explicit_pad * stride - 1) // 2
    output_padding = stride - 1
    if explicit_padding:
        return padding, output_padding, explicit_pad
    return padding, output_padding


def generate_configs():
    """Enumerate parameter combinations for the transposed-conv tests.

    Produces every combination of dimensionality (1-3), shard dimension
    (each spatial dim), kernel size (1, 3, 5), and stride (1, 2, 4).

    Returns:
        A ``(argnames, argvalues)`` pair suitable for unpacking into
        ``pytest.mark.parametrize``.
    """
    configs = [
        (ndims, shard_dim, kernel_size, stride)
        for ndims in (1, 2, 3)
        for shard_dim in range(ndims)
        for kernel_size in (1, 3, 5)
        for stride in (1, 2, 4)
    ]
    return "ndims,shard_dim,kernel_size,stride", configs


@pytest.mark.parametrize(*generate_configs())
def test_transposeconv_zerospadding(
    parallel_strategy: ParallelStrategy,
    ndims: int,
    shard_dim: int,
    kernel_size: int,
    stride: int,
    device: torch.device,
):
    """
    Test distributed transposed convolution (zero padding) across different
    numbers of dimensions, kernel sizes, and strides.  The output and the
    input/weight gradients of the distributed transposed convolution are
    checked against a non-distributed reference.

    Args:
        parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution.
        ndims (int): Number of dimensions for the convolution (1, 2, or 3).
        shard_dim (int): Dimension along which the tensor is sharded.
        kernel_size (int): Size of the convolution kernel.
        stride (int): Stride of the convolution.
        device (torch.device): Torch device to run test with.
    """
    # Shard along the chosen spatial dimension (offset past batch/channel dims)
    parallel_strategy.shard_dim = 2 + shard_dim
    padding, output_padding = find_padding(kernel_size, stride)

    # Build the input tensor and the transposed convolution layer
    input_shape = [1, 4] + [64] * ndims
    x = torch.randn(*input_shape, device=device, requires_grad=True)
    conv_cls = getattr(nn, f"ConvTranspose{ndims}d")
    conv = conv_cls(
        4,
        8,
        kernel_size=kernel_size,
        padding=padding,
        stride=stride,
        output_padding=output_padding,
    ).to(device)

    # Reference (non-distributed) forward and backward pass
    conv.zero_grad()
    ref_y = conv(x)
    ref_y.square().mean().backward()
    ref_x_grad = x.grad
    ref_conv_grad = conv.weight.grad

    # Distributed forward and backward pass
    conv.zero_grad()
    dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)
    dcx = DCTensor.distribute(x, parallel_strategy)
    dcy = dist_conv(dcx)
    dcy_merge = dcy.to_replicate()
    dcy.to_ddp().square().mean().backward()
    x_grad = dcx.grad.to_replicate()
    dc_conv_grad = conv.weight.grad

    # Distributed results must match the reference
    assert fp32_allclose(ref_y, dcy_merge)
    assert fp32_allclose(ref_x_grad, x_grad)
    assert fp32_allclose(ref_conv_grad, dc_conv_grad)


@pytest.mark.parametrize(*generate_configs())
def test_transposeconv_circularpadding(
    parallel_strategy: ParallelStrategy,
    ndims: int,
    shard_dim: int,
    kernel_size: int,
    stride: int,
    device: torch.device,
):
    """
    Test distributed transposed convolution with circular (periodic) padding
    across different numbers of dimensions, kernel sizes, and strides.  The
    output and the input/weight gradients of the distributed transposed
    convolution are checked against a non-distributed reference that applies
    the same circular pre-padding.

    Args:
        parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution.
        ndims (int): Number of dimensions for the convolution (1, 2, or 3).
        shard_dim (int): Dimension along which the tensor is sharded.
        kernel_size (int): Size of the convolution kernel.
        stride (int): Stride of the convolution.
        device (torch.device): Torch device to run test with.
    """
    # Shard along the chosen spatial dimension (offset past batch/channel dims)
    parallel_strategy.shard_dim = 2 + shard_dim
    padding, output_padding, circular_pad = find_padding(
        kernel_size, stride, explicit_padding=True
    )

    # Build the input tensor
    input_shape = [1, 4] + [64] * ndims
    x = torch.randn(*input_shape, device=device, requires_grad=True)

    conv_kwargs = dict(
        kernel_size=kernel_size, stride=stride, output_padding=output_padding
    )

    # Circularly pre-pad the reference input on every spatial dimension
    pad_spec = [circular_pad, circular_pad] * ndims
    x_periodic = torch.nn.functional.pad(input=x, pad=pad_spec, mode="circular")

    conv_cls = getattr(nn, f"ConvTranspose{ndims}d")
    # NOTE(review): grads are toggled off and immediately back on here; the
    # net effect leaves the layer's parameters requiring gradients.
    conv = (
        conv_cls(4, 8, padding=padding, **conv_kwargs)
        .to(device)
        .requires_grad_(False)
    )
    conv.requires_grad_(True)

    # Reference (non-distributed) forward and backward pass
    conv.zero_grad()
    ref_y = conv(x_periodic)
    ref_y.square().mean().backward()
    ref_x_grad = x.grad
    ref_conv_grad = conv.weight.grad

    # Distributed forward and backward pass with the same circular pre-padding
    conv.zero_grad()
    dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)
    dcx = DCTensor.distribute(x, parallel_strategy)
    dcx_periodic = torch.nn.functional.pad(input=dcx, pad=pad_spec, mode="circular")
    dcy = dist_conv(dcx_periodic)
    dcy_merge = dcy.to_replicate()
    dcy.to_ddp().square().mean().backward()
    x_grad = dcx.grad.to_replicate()
    dc_conv_grad = conv.weight.grad

    # Distributed results must match the reference
    assert fp32_allclose(ref_y, dcy_merge)
    assert fp32_allclose(ref_x_grad, x_grad)
    assert fp32_allclose(ref_conv_grad, dc_conv_grad)
4 changes: 3 additions & 1 deletion tests/test_periodic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from math import ceil

import pytest
import torch
import torch.distributed as dist
Expand Down Expand Up @@ -50,7 +52,7 @@ def test_periodic(

conv_kwargs = dict(
kernel_size=kernel_size,
padding=kernel_size // 2,
padding=ceil((kernel_size - stride) / 2),
bias=False,
stride=stride,
padding_mode="circular",
Expand Down