Skip to content

vllm.distributed.device_communicators.pynccl

_NCCL_SYMM_OPS_REGISTERED module-attribute

# Module-level guard read and written by register_nccl_symmetric_ops(): once
# it is True, later calls return early instead of registering the
# "all_reduce_symmetric_with_copy" custom op again.
_NCCL_SYMM_OPS_REGISTERED = False

logger module-attribute

# Module logger; PyNcclCommunicator uses it to report the loaded NCCL version.
logger = init_logger(__name__)

PyNcclCommunicator

Source code in vllm/distributed/device_communicators/pynccl.py
class PyNcclCommunicator:
    """Drive NCCL collectives on a single CUDA device through ``NCCLLibrary``.

    If the world size is 1, ``envs.VLLM_DISABLE_PYNCCL`` is set, or the NCCL
    library cannot be loaded, the instance is left disabled
    (``available = False``, ``disabled = True``). In that state ``self.nccl``
    and ``self.comm`` are never created, and every operation below returns
    immediately without doing anything.
    """

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. This is either a
                torch.distributed ``ProcessGroup`` whose backend is not
                NCCL (this is asserted), or a ``StatelessProcessGroup``.
                It supplies this communicator's rank and world size and is
                used to broadcast the NCCL unique id from rank 0.
            device: the device to bind the PyNcclCommunicator to. Accepts a
                CUDA device index (an ``int``, mapped to
                ``f"cuda:{device}"``), a device string, or a
                ``torch.device``. ``None`` is not accepted: it fails the
                ``torch.device`` assertion below.
            library_path: the path to the NCCL library. If None, it will
                use the default library path.

        The communicator is disabled when the world size is 1,
        ``envs.VLLM_DISABLE_PYNCCL`` is set, or the NCCL library fails to
        load. Otherwise an NCCL communicator is created on ``device`` and
        warmed up with a one-element all-reduce.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
            self.available = False
            self.disabled = True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        self.nccl_version = self.nccl.ncclGetRawVersion()
        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            # Ship the unique id as a byte tensor over the (non-NCCL) group.
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def _assert_on_device(self, tensor: torch.Tensor, role: str) -> None:
        """Assert that ``tensor`` lives on this communicator's device."""
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the {role} tensor is on {tensor.device}")

    def all_reduce(self,
                   in_tensor: torch.Tensor,
                   out_tensor: Optional[torch.Tensor] = None,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None) -> Optional[torch.Tensor]:
        """All-reduce ``in_tensor`` across the group with ``op``.

        Args:
            in_tensor: local contribution; must be on ``self.device``.
            out_tensor: destination buffer. When None, one is allocated with
                ``torch.empty_like(in_tensor)``. When given, it must also be
                on ``self.device``.
            op: reduction operator (default ``ReduceOp.SUM``).
            stream: CUDA stream to issue the call on (default: current
                stream). Nothing is synchronized here.

        Returns:
            The output tensor, or None when the communicator is disabled.
        """
        if self.disabled:
            return None
        self._assert_on_device(in_tensor, "input")

        if out_tensor is None:
            out_tensor = torch.empty_like(in_tensor)
        else:
            self._assert_on_device(out_tensor, "output")

        if stream is None:
            stream = current_stream()
        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                buffer_type(out_tensor.data_ptr()),
                                in_tensor.numel(),
                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))
        return out_tensor

    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        """Gather ``input_tensor`` from every rank into ``output_tensor``.

        NCCL receives ``input_tensor.numel()`` as the per-rank element
        count. ``output_tensor`` is assumed to hold ``world_size`` times that
        many elements; this is not checked here. No-op when disabled.
        """
        if self.disabled:
            return
        self._assert_on_device(input_tensor, "input")
        self._assert_on_device(output_tensor, "output")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cudaStream_t(stream.cuda_stream))

    def all_gatherv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        stream=None,
    ):
        """Gather variable-sized dim-0 slices from every rank.

        Args:
            output_tensor: receives every rank's ``input_tensor`` laid out
                back to back along dim 0. It must have ``sum(sizes)`` rows.
            input_tensor: this rank's contribution.
            sizes: number of dim-0 rows contributed by each rank, indexed by
                group rank. It must have exactly ``world_size`` entries.
            stream: CUDA stream to issue on (default: current stream).

        The gather is one ``ncclBroadcast`` per rank inside a single NCCL
        group. Rank ``r`` is the root for its slice of ``output_tensor``.
        No-op when disabled.
        """
        if self.disabled:
            return
        self._assert_on_device(input_tensor, "input")
        self._assert_on_device(output_tensor, "output")
        # The position in `sizes` is used as the NCCL root rank, so there
        # must be exactly one entry per rank.
        assert len(sizes) == self.world_size, (
            f"expected one size per rank ({self.world_size}), "
            f"got {len(sizes)}")
        assert output_tensor.shape[0] == sum(sizes), (
            f"output_tensor has {output_tensor.shape[0]} rows but "
            f"sum(sizes) is {sum(sizes)}")
        if stream is None:
            stream = current_stream()
        nccl_dtype = ncclDataTypeEnum.from_torch(input_tensor.dtype)
        cuda_stream = cudaStream_t(stream.cuda_stream)
        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            dst_slice = output_tensor[split_offset:split_offset + split_size]
            self.nccl.ncclBroadcast(
                buffer_type(input_tensor.data_ptr()),
                buffer_type(dst_slice.data_ptr()),
                dst_slice.numel(),
                nccl_dtype,
                root,
                self.comm,
                cuda_stream,
            )
            split_offset += split_size
        self.nccl.ncclGroupEnd()

    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        """Reduce ``input_tensor`` with ``op`` and scatter equal shards.

        NCCL receives ``output_tensor.numel()`` as the per-rank element
        count. No-op when disabled.
        """
        if self.disabled:
            return
        self._assert_on_device(input_tensor, "input")
        self._assert_on_device(output_tensor, "output")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cudaStream_t(stream.cuda_stream))

    def reduce_scatterv(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        sizes: list[int],
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
    ):
        """Reduce variable-sized dim-0 slices of ``input_tensor`` across ranks.

        Slice ``r`` of ``input_tensor`` holds ``sizes[r]`` rows, starting
        after the first ``r`` slices. It is reduced with ``op`` over all
        ranks, and the result is delivered into ``output_tensor`` on rank
        ``r``. The reduction is one ``ncclReduce`` per root inside a single
        NCCL group. ``sizes`` must have one entry per rank and sum to
        ``input_tensor.shape[0]``. No-op when disabled.
        """
        if self.disabled:
            return
        self._assert_on_device(input_tensor, "input")
        self._assert_on_device(output_tensor, "output")
        # The position in `sizes` is used as the NCCL root rank.
        assert len(sizes) == self.world_size, (
            f"expected one size per rank ({self.world_size}), "
            f"got {len(sizes)}")
        # Mirrors the row check in all_gatherv. A short input would give
        # mismatched element counts across ranks.
        assert input_tensor.shape[0] == sum(sizes), (
            f"input_tensor has {input_tensor.shape[0]} rows but "
            f"sum(sizes) is {sum(sizes)}")
        if stream is None:
            stream = current_stream()

        nccl_dtype = ncclDataTypeEnum.from_torch(input_tensor.dtype)
        nccl_op = ncclRedOpTypeEnum.from_torch(op)
        cuda_stream = cudaStream_t(stream.cuda_stream)
        split_offset = 0
        self.nccl.ncclGroupStart()
        for root, split_size in enumerate(sizes):
            chunk = input_tensor[split_offset:split_offset + split_size, ...]
            self.nccl.ncclReduce(
                buffer_type(chunk.data_ptr()),
                buffer_type(output_tensor.data_ptr()), chunk.numel(),
                nccl_dtype, nccl_op, root, self.comm, cuda_stream)
            split_offset += split_size
        self.nccl.ncclGroupEnd()

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Send ``tensor`` to group rank ``dst``. No-op when disabled."""
        if self.disabled:
            return
        self._assert_on_device(tensor, "input")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Receive into ``tensor`` from group rank ``src``. No-op when disabled."""
        if self.disabled:
            return
        self._assert_on_device(tensor, "input")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast ``tensor`` in place from group rank ``src``.

        No-op when disabled.
        """
        if self.disabled:
            return
        self._assert_on_device(tensor, "input")
        if stream is None:
            stream = current_stream()
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, cudaStream_t(stream.cuda_stream))

    def group_start(self):
        """Open an NCCL group (``ncclGroupStart``). No-op when disabled."""
        # A disabled communicator has no `self.nccl`, and all of its ops are
        # no-ops, so grouping is a no-op too instead of an AttributeError.
        if self.disabled:
            return
        self.nccl.ncclGroupStart()

    def group_end(self):
        """Close an NCCL group (``ncclGroupEnd``). No-op when disabled."""
        if self.disabled:
            return
        self.nccl.ncclGroupEnd()

    def register_comm_window(self, tensor: torch.Tensor):
        """Register the bytes backing ``tensor`` as an NCCL window.

        Returns the value of ``NCCLLibrary.ncclCommWindowRegister``. That is
        the handle ``deregister_comm_window`` expects.
        """
        return self.nccl.ncclCommWindowRegister(
            self.comm,
            buffer_type(tensor.data_ptr()),
            tensor.numel() * tensor.element_size(),
            # NOTE(review): window flags = 1, presumably
            # NCCL_WIN_COLL_SYMMETRIC. Confirm against NCCLLibrary.
            1,
        )

    def register_comm_window_raw(self, ptr: int, size: int):
        """Register ``size`` bytes at device address ``ptr`` as an NCCL window."""
        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr),
                                                size, 1)

    def deregister_comm_window(self, window):
        """Deregister a window handle returned by ncclCommWindowRegister."""
        return self.nccl.ncclCommWindowDeregister(self.comm, window)

available instance-attribute

available = True

comm instance-attribute

comm: ncclComm_t = ncclCommInitRank(
    world_size, unique_id, rank
)

device instance-attribute

device = device

disabled instance-attribute

disabled = False

group instance-attribute

group = group

nccl instance-attribute

nccl = NCCLLibrary(library_path)

nccl_version instance-attribute

nccl_version = ncclGetRawVersion()

rank instance-attribute

rank = get_rank(group)

unique_id instance-attribute

unique_id = ncclGetUniqueId()

world_size instance-attribute

world_size = get_world_size(group)

__init__

__init__(
    group: Union[ProcessGroup, StatelessProcessGroup],
    device: Union[int, str, device],
    library_path: Optional[str] = None,
)

Parameters:

Name Type Description Default
group Union[ProcessGroup, StatelessProcessGroup]

the process group to work on. If None, it will use the default process group.

required
device Union[int, str, device]

the device to bind the PyNcclCommunicator to. Accepts a CUDA device index (an int, mapped to f"cuda:{device}"), a device string such as "cuda:0", or a torch.device. None is not accepted.

required
library_path Optional[str]

the path to the NCCL library. If None, it will use the default library path.

None

It is the caller's responsibility to make sure each communicator is bound to a unique device.

Source code in vllm/distributed/device_communicators/pynccl.py
def __init__(
    self,
    group: Union[ProcessGroup, StatelessProcessGroup],
    device: Union[int, str, torch.device],
    library_path: Optional[str] = None,
):
    """
    Args:
        group: the process group to work on. This is either a
            torch.distributed ``ProcessGroup`` whose backend is not NCCL
            (this is asserted), or a ``StatelessProcessGroup``. It supplies
            this communicator's rank and world size and is used to
            broadcast the NCCL unique id from rank 0.
        device: the device to bind the PyNcclCommunicator to. Accepts a
            CUDA device index (an ``int``, mapped to ``f"cuda:{device}"``),
            a device string, or a ``torch.device``. ``None`` is not
            accepted: it fails the ``torch.device`` assertion below.
        library_path: the path to the NCCL library. If None, it will
            use the default library path.

    The communicator is disabled (``available = False``,
    ``disabled = True``) when the world size is 1,
    ``envs.VLLM_DISABLE_PYNCCL`` is set, or the NCCL library fails to load.
    Otherwise an NCCL communicator is created on ``device`` and warmed up
    with a one-element all-reduce.

    It is the caller's responsibility to make sure each communicator
    is bound to a unique device.
    """
    if not isinstance(group, StatelessProcessGroup):
        assert dist.is_initialized()
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "PyNcclCommunicator should be attached to a non-NCCL group.")
        # note: this rank is the rank in the group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)
    else:
        self.rank = group.rank
        self.world_size = group.world_size

    self.group = group

    # if world_size == 1, no need to create communicator
    if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
        self.available = False
        self.disabled = True
        return
    try:
        self.nccl = NCCLLibrary(library_path)
    except Exception:
        # disable because of missing NCCL library
        # e.g. in a non-GPU environment
        self.available = False
        self.disabled = True
        return

    self.available = True
    self.disabled = False

    self.nccl_version = self.nccl.ncclGetRawVersion()
    logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

    if self.rank == 0:
        # get the unique id from NCCL
        self.unique_id = self.nccl.ncclGetUniqueId()
    else:
        # construct an empty unique id
        self.unique_id = ncclUniqueId()

    if not isinstance(group, StatelessProcessGroup):
        # Ship the unique id as a byte tensor over the (non-NCCL) group.
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        ranks = dist.get_process_group_ranks(group)
        # arg `src` in `broadcast` is the global rank
        dist.broadcast(tensor, src=ranks[0], group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
    else:
        self.unique_id = group.broadcast_obj(self.unique_id, src=0)
    if isinstance(device, int):
        device = torch.device(f"cuda:{device}")
    elif isinstance(device, str):
        device = torch.device(device)
    # now `device` is a `torch.device` object
    assert isinstance(device, torch.device)
    self.device = device
    # nccl communicator and stream will use this device
    # `torch.cuda.device` is a context manager that changes the
    # current cuda device to the specified one
    with torch.cuda.device(device):
        self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
            self.world_size, self.unique_id, self.rank)

        stream = current_stream()
        # A small all_reduce for warmup.
        data = torch.zeros(1, device=device)
        self.all_reduce(data)
        stream.synchronize()
        del data

all_gather

all_gather(
    output_tensor: Tensor, input_tensor: Tensor, stream=None
)
Source code in vllm/distributed/device_communicators/pynccl.py
def all_gather(self,
               output_tensor: torch.Tensor,
               input_tensor: torch.Tensor,
               stream=None):
    """Gather ``input_tensor`` from every rank into ``output_tensor``.

    NCCL receives ``input_tensor.numel()`` as the per-rank element count.
    ``output_tensor`` is assumed to hold ``world_size`` times that many
    elements; this is not checked here. No-op when disabled.
    """
    if self.disabled:
        return
    # nccl communicator created on a specific device
    # will only work on tensors on the same device
    # otherwise it will cause "illegal memory access"
    assert input_tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {input_tensor.device}")
    # The same restriction applies to the buffer NCCL writes into.
    assert output_tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the output tensor is on {output_tensor.device}")
    if stream is None:
        stream = current_stream()
    self.nccl.ncclAllGather(
        buffer_type(input_tensor.data_ptr()),
        buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
        ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
        cudaStream_t(stream.cuda_stream))

all_gatherv

all_gatherv(
    output_tensor: Tensor,
    input_tensor: Tensor,
    sizes: list[int],
    stream=None,
)
Source code in vllm/distributed/device_communicators/pynccl.py
def all_gatherv(
    self,
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    sizes: list[int],
    stream=None,
):
    """Gather variable-sized dim-0 slices from every rank.

    Args:
        output_tensor: receives every rank's ``input_tensor`` laid out back
            to back along dim 0. It must have ``sum(sizes)`` rows.
        input_tensor: this rank's contribution.
        sizes: number of dim-0 rows contributed by each rank, indexed by
            group rank. It must have exactly ``world_size`` entries.
        stream: CUDA stream to issue on (default: current stream).

    The gather is one ``ncclBroadcast`` per rank inside a single NCCL group.
    Rank ``r`` is the root for its slice of ``output_tensor``. No-op when
    disabled.
    """
    if self.disabled:
        return
    # nccl communicator created on a specific device
    # will only work on tensors on the same device
    # otherwise it will cause "illegal memory access"
    assert input_tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {input_tensor.device}")
    assert output_tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the output tensor is on {output_tensor.device}")
    # The position in `sizes` is used as the NCCL root rank, so there must
    # be exactly one entry per rank.
    assert len(sizes) == self.world_size, (
        f"expected one size per rank ({self.world_size}), got {len(sizes)}")
    assert output_tensor.shape[0] == sum(sizes), (
        f"output_tensor has {output_tensor.shape[0]} rows but "
        f"sum(sizes) is {sum(sizes)}")
    if stream is None:
        stream = current_stream()
    nccl_dtype = ncclDataTypeEnum.from_torch(input_tensor.dtype)
    cuda_stream = cudaStream_t(stream.cuda_stream)
    split_offset = 0
    self.nccl.ncclGroupStart()
    for root, split_size in enumerate(sizes):
        dst_slice = output_tensor[split_offset:split_offset + split_size]
        self.nccl.ncclBroadcast(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(dst_slice.data_ptr()),
            dst_slice.numel(),
            nccl_dtype,
            root,
            self.comm,
            cuda_stream,
        )
        split_offset += split_size
    self.nccl.ncclGroupEnd()

all_reduce

all_reduce(
    in_tensor: Tensor,
    out_tensor: Tensor = None,
    op: ReduceOp = SUM,
    stream=None,
) -> Tensor
Source code in vllm/distributed/device_communicators/pynccl.py
def all_reduce(self,
               in_tensor: torch.Tensor,
               out_tensor: Optional[torch.Tensor] = None,
               op: ReduceOp = ReduceOp.SUM,
               stream=None) -> Optional[torch.Tensor]:
    """All-reduce ``in_tensor`` across the group with ``op``.

    Args:
        in_tensor: local contribution; must be on ``self.device``.
        out_tensor: destination buffer. When None, one is allocated with
            ``torch.empty_like(in_tensor)``. When given, it must also be on
            ``self.device``.
        op: reduction operator (default ``ReduceOp.SUM``).
        stream: CUDA stream to issue the call on (default: current stream).
            Nothing is synchronized here.

    Returns:
        The output tensor, or None when the communicator is disabled.
    """
    if self.disabled:
        return None
    # nccl communicator created on a specific device
    # will only work on tensors on the same device
    # otherwise it will cause "illegal memory access"
    assert in_tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {in_tensor.device}")

    if out_tensor is None:
        out_tensor = torch.empty_like(in_tensor)
    else:
        assert out_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the output tensor is on {out_tensor.device}")

    if stream is None:
        stream = current_stream()
    self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                            buffer_type(out_tensor.data_ptr()),
                            in_tensor.numel(),
                            ncclDataTypeEnum.from_torch(in_tensor.dtype),
                            ncclRedOpTypeEnum.from_torch(op), self.comm,
                            cudaStream_t(stream.cuda_stream))
    return out_tensor

broadcast

broadcast(tensor: Tensor, src: int, stream=None)
Source code in vllm/distributed/device_communicators/pynccl.py
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
    """Broadcast ``tensor`` in place from group rank ``src`` to every rank.

    Does nothing when the communicator is disabled. Otherwise the call is
    issued on ``stream``, or on the current stream when ``stream`` is None.
    """
    if self.disabled:
        return
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")
    active_stream = current_stream() if stream is None else stream
    is_root = src == self.rank
    # Only the root contributes data; every rank (the root included, as NCCL
    # requires) receives into the tensor's own storage.
    sendbuff = buffer_type(tensor.data_ptr()) if is_root else buffer_type()
    recvbuff = buffer_type(tensor.data_ptr())
    self.nccl.ncclBroadcast(
        sendbuff,
        recvbuff,
        tensor.numel(),
        ncclDataTypeEnum.from_torch(tensor.dtype),
        src,
        self.comm,
        cudaStream_t(active_stream.cuda_stream),
    )

deregister_comm_window

deregister_comm_window(window)
Source code in vllm/distributed/device_communicators/pynccl.py
def deregister_comm_window(self, window):
    """Deregister a window handle obtained from ``ncclCommWindowRegister``.

    Returns whatever ``NCCLLibrary.ncclCommWindowDeregister`` returns.
    """
    deregister = self.nccl.ncclCommWindowDeregister
    return deregister(self.comm, window)

group_end

group_end()
Source code in vllm/distributed/device_communicators/pynccl.py
def group_end(self):
    """Close an NCCL group (``ncclGroupEnd``). No-op when disabled."""
    # A disabled communicator never sets `self.nccl` (see __init__), and all
    # of its ops are no-ops, so closing a group is a no-op too instead of
    # raising AttributeError.
    if self.disabled:
        return
    self.nccl.ncclGroupEnd()

group_start

group_start()
Source code in vllm/distributed/device_communicators/pynccl.py
def group_start(self):
    """Open an NCCL group (``ncclGroupStart``). No-op when disabled."""
    # A disabled communicator never sets `self.nccl` (see __init__), and all
    # of its ops are no-ops, so opening a group is a no-op too instead of
    # raising AttributeError.
    if self.disabled:
        return
    self.nccl.ncclGroupStart()

recv

recv(tensor: Tensor, src: int, stream=None)
Source code in vllm/distributed/device_communicators/pynccl.py
def recv(self, tensor: torch.Tensor, src: int, stream=None):
    """Receive from group rank ``src`` into ``tensor``.

    Does nothing when the communicator is disabled. Otherwise the call is
    issued on ``stream``, or on the current stream when ``stream`` is None.
    """
    if not self.disabled:
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        active_stream = current_stream() if stream is None else stream
        self.nccl.ncclRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(active_stream.cuda_stream),
        )

reduce_scatter

reduce_scatter(
    output_tensor: Tensor,
    input_tensor: Tensor,
    op: ReduceOp = SUM,
    stream=None,
)
Source code in vllm/distributed/device_communicators/pynccl.py
def reduce_scatter(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
    """Reduce ``input_tensor`` with ``op`` and scatter equal shards.

    NCCL receives ``output_tensor.numel()`` as the per-rank element count.
    Both tensors must be on ``self.device``. No-op when disabled.
    """
    if self.disabled:
        return
    # nccl communicator created on a specific device
    # will only work on tensors on the same device
    # otherwise it will cause "illegal memory access"
    assert input_tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {input_tensor.device}")
    assert output_tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the output tensor is on {output_tensor.device}")
    if stream is None:
        stream = current_stream()
    self.nccl.ncclReduceScatter(
        buffer_type(input_tensor.data_ptr()),
        buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
        ncclDataTypeEnum.from_torch(input_tensor.dtype),
        ncclRedOpTypeEnum.from_torch(op), self.comm,
        cudaStream_t(stream.cuda_stream))

reduce_scatterv

reduce_scatterv(
    output_tensor: Tensor,
    input_tensor: Tensor,
    sizes: list[int],
    op: ReduceOp = SUM,
    stream=None,
)
Source code in vllm/distributed/device_communicators/pynccl.py
def reduce_scatterv(
    self,
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    sizes: list[int],
    op: ReduceOp = ReduceOp.SUM,
    stream=None,
):
    """Reduce variable-sized dim-0 slices of ``input_tensor`` across ranks.

    Slice ``r`` of ``input_tensor`` holds ``sizes[r]`` rows, starting after
    the first ``r`` slices. It is reduced with ``op`` over all ranks, and the
    result is delivered into ``output_tensor`` on rank ``r``. The reduction
    is one ``ncclReduce`` per root inside a single NCCL group.

    ``sizes`` must have one entry per rank and sum to
    ``input_tensor.shape[0]``. No-op when disabled.
    """
    if self.disabled:
        return
    # nccl communicator created on a specific device
    # will only work on tensors on the same device
    # otherwise it will cause "illegal memory access"
    assert input_tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {input_tensor.device}")
    assert output_tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the output tensor is on {output_tensor.device}")
    # The position in `sizes` is used as the NCCL root rank.
    assert len(sizes) == self.world_size, (
        f"expected one size per rank ({self.world_size}), got {len(sizes)}")
    # A short input would produce chunks whose element counts disagree
    # across ranks.
    assert input_tensor.shape[0] == sum(sizes), (
        f"input_tensor has {input_tensor.shape[0]} rows but "
        f"sum(sizes) is {sum(sizes)}")
    if stream is None:
        stream = current_stream()

    nccl_dtype = ncclDataTypeEnum.from_torch(input_tensor.dtype)
    nccl_op = ncclRedOpTypeEnum.from_torch(op)
    cuda_stream = cudaStream_t(stream.cuda_stream)
    split_offset = 0
    self.nccl.ncclGroupStart()
    for root, split_size in enumerate(sizes):
        chunk = input_tensor[split_offset:split_offset + split_size, ...]
        self.nccl.ncclReduce(
            buffer_type(chunk.data_ptr()),
            buffer_type(output_tensor.data_ptr()), chunk.numel(),
            nccl_dtype, nccl_op, root, self.comm, cuda_stream)
        split_offset += split_size
    self.nccl.ncclGroupEnd()

register_comm_window

register_comm_window(tensor: Tensor)
Source code in vllm/distributed/device_communicators/pynccl.py
def register_comm_window(self, tensor: torch.Tensor):
    """Register the storage backing ``tensor`` as an NCCL window.

    The registered region starts at ``tensor.data_ptr()`` and spans
    ``numel * element_size`` bytes. Returns whatever
    ``NCCLLibrary.ncclCommWindowRegister`` returns.
    """
    num_bytes = tensor.numel() * tensor.element_size()
    base_ptr = buffer_type(tensor.data_ptr())
    # NOTE(review): window flags = 1, presumably NCCL_WIN_COLL_SYMMETRIC.
    # Confirm against NCCLLibrary.
    window_flags = 1
    return self.nccl.ncclCommWindowRegister(self.comm, base_ptr, num_bytes,
                                            window_flags)

register_comm_window_raw

register_comm_window_raw(ptr: int, size: int)
Source code in vllm/distributed/device_communicators/pynccl.py
def register_comm_window_raw(self, ptr: int, size: int):
    """Register ``size`` bytes at raw device address ``ptr`` as an NCCL window.

    Uses the same window flags value (1) as ``register_comm_window``.
    """
    register = self.nccl.ncclCommWindowRegister
    return register(
        self.comm,
        buffer_type(ptr),
        size,
        1,
    )

send

send(tensor: Tensor, dst: int, stream=None)
Source code in vllm/distributed/device_communicators/pynccl.py
def send(self, tensor: torch.Tensor, dst: int, stream=None):
    """Send ``tensor`` to group rank ``dst``.

    Does nothing when the communicator is disabled. Otherwise the call is
    issued on ``stream``, or on the current stream when ``stream`` is None.
    """
    if not self.disabled:
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        active_stream = current_stream() if stream is None else stream
        self.nccl.ncclSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            dst,
            self.comm,
            cudaStream_t(active_stream.cuda_stream),
        )

register_nccl_symmetric_ops

register_nccl_symmetric_ops(pynccl_comm)
Source code in vllm/distributed/device_communicators/pynccl.py
def register_nccl_symmetric_ops(pynccl_comm):
    """Register the ``all_reduce_symmetric_with_copy`` custom op, once.

    The op copies its input into buffers allocated inside
    ``nccl_symm_mem_context(pynccl_comm)`` and all-reduces them with
    ``pynccl_comm``. After the first successful registration, later calls
    return immediately. The registered op therefore keeps using the
    communicator passed on that first call.
    """
    global _NCCL_SYMM_OPS_REGISTERED
    if _NCCL_SYMM_OPS_REGISTERED:
        return

    from vllm.distributed.device_communicators.pynccl_allocator import (
        nccl_symm_mem_context)
    from vllm.utils import direct_register_custom_op

    def all_reduce_symmetric_with_copy_impl(
            input_tensor: torch.Tensor) -> torch.Tensor:
        # Allocate both buffers under the symmetric-memory context
        # (presumably NCCL-registered memory; see pynccl_allocator). Then
        # copy the caller's data in outside the context.
        with nccl_symm_mem_context(pynccl_comm):
            symm_input = torch.empty_like(input_tensor)
            symm_output = torch.empty_like(input_tensor)
        symm_input.copy_(input_tensor)
        # NOTE(review): all_reduce returns None if pynccl_comm is disabled.
        symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
        return symm_output

    def all_reduce_symmetric_with_copy_fake(
            input_tensor: torch.Tensor) -> torch.Tensor:
        # Shape/dtype-only implementation used during tracing.
        return torch.empty_like(input_tensor)

    direct_register_custom_op(
        op_name="all_reduce_symmetric_with_copy",
        op_func=all_reduce_symmetric_with_copy_impl,
        fake_impl=all_reduce_symmetric_with_copy_fake,
    )
    # Mark as registered only after registration succeeded. Otherwise a
    # failure above would be silently skipped on every later call.
    _NCCL_SYMM_OPS_REGISTERED = True