vllm.distributed.eplb.eplb_communicator

EPLB communicator implementations and factory.

EplbCommunicator

Bases: ABC

Abstract EPLB communicator for expert weight transfers.

Source code in vllm/distributed/eplb/eplb_communicator.py
class EplbCommunicator(ABC):
    """Abstract EPLB communicator for expert weight transfers."""

    @abstractmethod
    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        pass

    @abstractmethod
    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        pass

    @abstractmethod
    def execute(self) -> None:
        pass

    @property
    def needs_profile_buffer_reservation(self) -> bool:
        """Whether the profile path must run a dummy collective operation to reserve
        communication buffers."""
        return True

    def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None:
        self._cuda_stream = cuda_stream

    def _log_initialized(self) -> None:
        if is_local_first_rank():
            logger.info("Initialized EPLB communicator: %s.", self.__class__.__name__)

needs_profile_buffer_reservation property

needs_profile_buffer_reservation: bool

Whether the profile path must run a dummy collective operation to reserve communication buffers.
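
All concrete backends below follow the same enqueue-then-execute pattern. A minimal usage sketch (the bookkeeping names are placeholders, not part of this module): a caller queues paired sends and receives and then runs them as one batch.

comm: EplbCommunicator = ...  # any concrete backend below
for weight, dst in experts_to_send:   # hypothetical bookkeeping of outgoing expert tensors
    comm.add_send(weight, dst)
for buffer, src in experts_to_recv:   # hypothetical bookkeeping of incoming expert tensors
    comm.add_recv(buffer, src)
comm.execute()  # launches and/or completes every queued transfer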

NixlEplbCommunicator

Bases: EplbCommunicator

EPLB communicator backed by NIXL READ transfers.

Source code in vllm/distributed/eplb/eplb_communicator.py
class NixlEplbCommunicator(EplbCommunicator):
    """EPLB communicator backed by NIXL READ transfers."""

    def __init__(
        self,
        cpu_group: ProcessGroup,
        expert_weights: Sequence[torch.Tensor],
        cuda_stream: torch.cuda.Stream | None = None,
    ) -> None:
        assert expert_weights, "NixlEplbCommunicator requires non-empty expert_weights."
        if NixlWrapper is None:
            raise RuntimeError("NIXL/ RIXL is unavailable.")
        self._cpu_group = cpu_group
        self._cuda_stream = cuda_stream
        self._world_size = cpu_group.size()
        self._rank = cpu_group.rank()
        self._send_tensors: dict[torch.dtype, list[list[torch.Tensor]]] = {}
        self._recv_tensors: dict[torch.dtype, list[list[torch.Tensor]]] = {}
        self._dtypes: list[torch.dtype] = []
        self._device = expert_weights[0].device
        for tensor in expert_weights:
            assert tensor.device == self._device, (
                "All local EPLB tensors are expected to be on the same device: "
                f"expected={self._device}, got={tensor.device}"
            )
            if tensor.dtype not in self._dtypes:
                self._dtypes.append(tensor.dtype)

        config = (
            nixl_agent_config(capture_telemetry=False)
            if nixl_agent_config is not None
            else None
        )
        self._nixl_wrapper = NixlWrapper(self._make_agent_name(), config)
        self._nixl_memory_type = "VRAM"
        self._registered_desc: object | None = None
        self._remote_agents: dict[int, str] = {}
        self._remote_send_meta: dict[int, tuple[int, int, int]] = {}
        self._send_buffer: torch.Tensor = torch.empty(0)
        self._recv_buffer: torch.Tensor = torch.empty(0)
        self._peer_partition_bytes: int = 0
        self._dtype_max_bytes: dict[torch.dtype, int] = {}
        self._cuda_device_id = int(self._device.index or 0)
        self._xfer_cache: dict[tuple[int, int, int], tuple[int, int, int]] = {}
        self._init_step("buffers", self._init_registered_buffers, expert_weights)
        self._init_step("agents", self._init_remote_agents)
        self._init_step("send meta", self._exchange_remote_send_meta)
        self._log_initialized()

    @property
    def needs_profile_buffer_reservation(self) -> bool:
        return False

    @staticmethod
    def _init_step(name: str, fn: object, *args: object, **kwargs: object) -> None:
        try:
            fn(*args, **kwargs)  # type: ignore[operator]
        except Exception as exc:
            raise RuntimeError(f"NIXL EPLB init failed: {name}") from exc

    def _make_agent_name(self) -> str:
        """Build a deployment-unique nixl agent name."""
        pp_size = get_pp_group().world_size
        pp_suffix = f"-pp{get_pp_group().rank_in_group}" if pp_size > 1 else ""
        uid = uuid.uuid4().hex[:8]
        return f"eplb-{self._rank}{pp_suffix}-{uid}"

    def _get_peer_buckets(
        self,
        bucket_map: dict[torch.dtype, list[list[torch.Tensor]]],
        dtype: torch.dtype,
    ) -> list[list[torch.Tensor]]:
        peer_buckets = bucket_map.get(dtype)
        if peer_buckets is None:
            peer_buckets = [[] for _ in range(self._world_size)]
            bucket_map[dtype] = peer_buckets
        return peer_buckets

    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        assert dst_rank != self._rank, (
            "EPLB communicator should not enqueue same-rank sends: "
            f"rank={self._rank}, dst_rank={dst_rank}"
        )
        self._get_peer_buckets(self._send_tensors, tensor.dtype)[dst_rank].append(
            tensor
        )

    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        assert src_rank != self._rank, (
            "EPLB communicator should not enqueue same-rank recvs: "
            f"rank={self._rank}, src_rank={src_rank}"
        )
        self._get_peer_buckets(self._recv_tensors, tensor.dtype)[src_rank].append(
            tensor
        )

    def _init_remote_agents(self) -> None:
        local_metadata = self._nixl_wrapper.get_agent_metadata()
        gathered_metadata: list[bytes | None] = [None] * self._world_size
        torch.distributed.all_gather_object(
            gathered_metadata, local_metadata, group=self._cpu_group
        )
        for peer in range(self._world_size):
            if peer == self._rank:
                continue
            peer_metadata = gathered_metadata[peer]
            assert peer_metadata is not None
            self._remote_agents[peer] = self._nixl_wrapper.add_remote_agent(
                peer_metadata
            )

    def _init_registered_buffers(self, expert_weights: Sequence[torch.Tensor]) -> None:
        total_max_bytes = 0
        for dtype in self._dtypes:
            max_numel = max(
                sum(t.numel() for t in expert_weights if t.dtype == dtype), 1
            )
            max_bytes = max_numel * dtype.itemsize
            self._dtype_max_bytes[dtype] = max_bytes
            total_max_bytes += max_bytes

        self._peer_partition_bytes = total_max_bytes

        # The send buffer needs world_size partitions because remote peers
        # READ from fixed offsets (rank * partition_bytes).
        # This allocates world_size * partition_bytes
        # which can cause OOM on large models.
        # TODO(ilmarkov): shrink to const * partition_bytes and execute
        # communication in multiple steps dealing with the worst case.
        send_total_bytes = self._peer_partition_bytes * self._world_size

        self._send_buffer = torch.empty(
            send_total_bytes, device=self._device, dtype=torch.uint8
        )
        self._recv_buffer = torch.empty(
            self._peer_partition_bytes, device=self._device, dtype=torch.uint8
        )

        descs = self._nixl_wrapper.get_reg_descs([self._send_buffer, self._recv_buffer])
        self._nixl_wrapper.register_memory(descs)
        self._registered_desc = descs

    def _exchange_remote_send_meta(self) -> None:
        """Exchange send-buffer metadata so each rank can build dynamic
        descriptors at execute time."""
        local_meta: tuple[int, int, int] = (
            self._send_buffer.data_ptr(),
            self._peer_partition_bytes,
            self._cuda_device_id,
        )
        gathered_meta: list[tuple[int, int, int] | None] = [None] * self._world_size
        torch.distributed.all_gather_object(
            gathered_meta, local_meta, group=self._cpu_group
        )

        for peer in self._remote_agents:
            peer_meta = gathered_meta[peer]
            assert peer_meta is not None
            self._remote_send_meta[peer] = peer_meta

    @staticmethod
    def _pack_send_buffer(
        peer_tensors: list[torch.Tensor],
        send_buffer: torch.Tensor,
        byte_offset: int,
    ) -> int:
        """
        Returns the byte offset after the last written byte.
        """
        for tensor in peer_tensors:
            raw = tensor.reshape(-1).view(torch.uint8)
            if raw.numel() == 0:
                continue
            send_buffer[byte_offset : byte_offset + raw.numel()].copy_(
                raw, non_blocking=True
            )
            byte_offset += raw.numel()
        return byte_offset

    @staticmethod
    def _unpack_recv_buffer(
        recv_buffer: torch.Tensor,
        peer_tensors: list[torch.Tensor],
        byte_offset: int,
    ) -> int:
        """
        Returns the byte offset after the last read byte.
        """
        for tensor in peer_tensors:
            num_bytes = tensor.numel() * tensor.element_size()
            if num_bytes == 0:
                continue
            tensor.reshape(-1).view(torch.uint8).copy_(
                recv_buffer[byte_offset : byte_offset + num_bytes],
                non_blocking=True,
            )
            byte_offset += num_bytes
        return byte_offset

    def _release_all_cached_handles(self) -> None:
        """Best-effort release of every cached dlist and xfer handle."""
        for local_dlist, remote_dlist, xfer in self._xfer_cache.values():
            for release_fn, handle in (
                (self._nixl_wrapper.release_xfer_handle, xfer),
                (self._nixl_wrapper.release_dlist_handle, local_dlist),
                (self._nixl_wrapper.release_dlist_handle, remote_dlist),
            ):
                with contextlib.suppress(Exception):
                    release_fn(handle)
        self._xfer_cache.clear()

    def _wait_for_all_transfers(self, handles: list[int]) -> None:
        pending = set(handles)
        while pending:
            completed: list[int] = []
            for handle in pending:
                state = self._nixl_wrapper.check_xfer_state(handle)
                if state == "DONE":
                    completed.append(handle)
                    continue
                if state != "PROC":
                    raise RuntimeError(f"NIXL transfer failed with state={state}")
            for handle in completed:
                pending.remove(handle)
            if pending:
                time.sleep(0.0005)

    def _get_or_create_xfer(self, src: int, total_bytes: int, recv_offset: int) -> int:
        """Return a cached xfer handle or create and cache a new one."""
        key = (src, total_bytes, recv_offset)
        cached = self._xfer_cache.get(key)
        if cached is not None:
            return cached[2]

        recv_base = self._recv_buffer.data_ptr()
        local_desc = self._nixl_wrapper.get_xfer_descs(
            [
                (
                    recv_base + recv_offset,
                    total_bytes,
                    self._cuda_device_id,
                )
            ],
            self._nixl_memory_type,
        )
        local_handle = self._nixl_wrapper.prep_xfer_dlist(
            "NIXL_INIT_AGENT",
            local_desc,
        )

        remote_base, remote_part_bytes, remote_dev = self._remote_send_meta[src]
        agent_name = self._remote_agents[src]
        remote_desc = self._nixl_wrapper.get_xfer_descs(
            [
                (
                    remote_base + self._rank * remote_part_bytes,
                    total_bytes,
                    remote_dev,
                )
            ],
            self._nixl_memory_type,
        )
        remote_handle = self._nixl_wrapper.prep_xfer_dlist(
            agent_name,
            remote_desc,
        )

        xfer_handle = self._nixl_wrapper.make_prepped_xfer(
            "READ",
            local_handle,
            [0],
            remote_handle,
            [0],
        )
        self._xfer_cache[key] = (local_handle, remote_handle, xfer_handle)
        return xfer_handle

    def execute(self) -> None:
        xfer_handles: list[int] = []
        try:
            # Phase 1: pack send buffers.
            with torch.cuda.stream(self._cuda_stream):
                for dst in range(self._world_size):
                    byte_offset = dst * self._peer_partition_bytes
                    for dtype in self._dtypes:
                        peer_tensors = self._send_tensors.get(
                            dtype, [[] for _ in range(self._world_size)]
                        )[dst]
                        actual_bytes = sum(
                            t.numel() * t.element_size() for t in peer_tensors
                        )
                        if actual_bytes > self._dtype_max_bytes[dtype]:
                            raise RuntimeError(
                                "NIXL EPLB send overflow for dtype "
                                f"{dtype}: peer={dst}, "
                                f"required={actual_bytes}, "
                                f"capacity={self._dtype_max_bytes[dtype]}"
                            )
                        byte_offset = self._pack_send_buffer(
                            peer_tensors,
                            self._send_buffer,
                            byte_offset,
                        )

            # Ensure all packed data is visible in device memory before pulls.
            if self._cuda_stream is not None:
                self._cuda_stream.synchronize()
            else:
                torch.cuda.current_stream().synchronize()
            # READ is receiver-initiated; synchronize all ranks before transfer.
            # We use monitored_barrier so a rank that crashes or exits early
            # produces a diagnostic timeout instead of a silent hang.
            torch.distributed.monitored_barrier(
                group=self._cpu_group,
                timeout=timedelta(minutes=5),
            )

            # Phase 2: look up or create descriptors and issue all READs.
            # Data from all peers is packed sequentially into the single
            # partition-sized recv buffer at running offsets.
            recv_offsets: dict[int, int] = {}
            recv_offset = 0
            for src in range(self._world_size):
                if src == self._rank:
                    continue
                actual_total_bytes = 0
                for dtype in self._dtypes:
                    peer_tensors = self._recv_tensors.get(
                        dtype, [[] for _ in range(self._world_size)]
                    )[src]
                    actual_total_bytes += sum(
                        t.numel() * t.element_size() for t in peer_tensors
                    )
                if actual_total_bytes == 0:
                    continue

                recv_offsets[src] = recv_offset
                xfer_handle = self._get_or_create_xfer(
                    src, actual_total_bytes, recv_offset
                )
                self._nixl_wrapper.transfer(xfer_handle)
                xfer_handles.append(xfer_handle)
                recv_offset += actual_total_bytes

            # Phase 3: single wait for all in-flight transfers, then unpack.
            self._wait_for_all_transfers(xfer_handles)

            with torch.cuda.stream(self._cuda_stream):
                for src, offset in recv_offsets.items():
                    byte_offset = offset
                    for dtype in self._dtypes:
                        peer_tensors = self._recv_tensors.get(
                            dtype, [[] for _ in range(self._world_size)]
                        )[src]
                        byte_offset = self._unpack_recv_buffer(
                            self._recv_buffer,
                            peer_tensors,
                            byte_offset,
                        )
        except Exception:
            self._release_all_cached_handles()
            raise
        finally:
            self._send_tensors.clear()
            self._recv_tensors.clear()

    def __del__(self) -> None:
        try:
            self._release_all_cached_handles()
            if self._registered_desc is not None:
                self._nixl_wrapper.deregister_memory(self._registered_desc)
                self._registered_desc = None
            for agent_name in self._remote_agents.values():
                self._nixl_wrapper.remove_remote_agent(agent_name)
            self._remote_agents.clear()
        except Exception as e:
            logger.warning("Error during NixlEplbCommunicator cleanup: %s", e)

_exchange_remote_send_meta

_exchange_remote_send_meta() -> None

Exchange send-buffer metadata so each rank can build dynamic descriptors at execute time.

Source code in vllm/distributed/eplb/eplb_communicator.py
def _exchange_remote_send_meta(self) -> None:
    """Exchange send-buffer metadata so each rank can build dynamic
    descriptors at execute time."""
    local_meta: tuple[int, int, int] = (
        self._send_buffer.data_ptr(),
        self._peer_partition_bytes,
        self._cuda_device_id,
    )
    gathered_meta: list[tuple[int, int, int] | None] = [None] * self._world_size
    torch.distributed.all_gather_object(
        gathered_meta, local_meta, group=self._cpu_group
    )

    for peer in self._remote_agents:
        peer_meta = gathered_meta[peer]
        assert peer_meta is not None
        self._remote_send_meta[peer] = peer_meta
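
For concreteness, a made-up 3-rank run: after the all_gather each rank holds one (data_ptr, partition_bytes, device_id) tuple per rank, and only the peers are kept. On rank 1 this might look like the following (pointers are hypothetical).

gathered_meta = [
    (0x7F0000000000, 1 << 20, 0),   # rank 0's send buffer
    (0x7F4000000000, 1 << 20, 0),   # rank 1 (self, skipped below)
    (0x7F8000000000, 1 << 20, 0),   # rank 2's send buffer
]
# _remote_send_meta keeps only the peers: {0: gathered_meta[0], 2: gathered_meta[2]}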

_get_or_create_xfer

_get_or_create_xfer(
    src: int, total_bytes: int, recv_offset: int
) -> int

Return a cached xfer handle or create and cache a new one.

Source code in vllm/distributed/eplb/eplb_communicator.py
def _get_or_create_xfer(self, src: int, total_bytes: int, recv_offset: int) -> int:
    """Return a cached xfer handle or create and cache a new one."""
    key = (src, total_bytes, recv_offset)
    cached = self._xfer_cache.get(key)
    if cached is not None:
        return cached[2]

    recv_base = self._recv_buffer.data_ptr()
    local_desc = self._nixl_wrapper.get_xfer_descs(
        [
            (
                recv_base + recv_offset,
                total_bytes,
                self._cuda_device_id,
            )
        ],
        self._nixl_memory_type,
    )
    local_handle = self._nixl_wrapper.prep_xfer_dlist(
        "NIXL_INIT_AGENT",
        local_desc,
    )

    remote_base, remote_part_bytes, remote_dev = self._remote_send_meta[src]
    agent_name = self._remote_agents[src]
    remote_desc = self._nixl_wrapper.get_xfer_descs(
        [
            (
                remote_base + self._rank * remote_part_bytes,
                total_bytes,
                remote_dev,
            )
        ],
        self._nixl_memory_type,
    )
    remote_handle = self._nixl_wrapper.prep_xfer_dlist(
        agent_name,
        remote_desc,
    )

    xfer_handle = self._nixl_wrapper.make_prepped_xfer(
        "READ",
        local_handle,
        [0],
        remote_handle,
        [0],
    )
    self._xfer_cache[key] = (local_handle, remote_handle, xfer_handle)
    return xfer_handle
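
A worked example of the descriptor arithmetic (illustrative numbers only): every rank advertised (send_buffer.data_ptr(), peer_partition_bytes, device_id) via _exchange_remote_send_meta, and the reader selects its own partition inside the remote send buffer using its rank.

world_size = 4
partition_bytes = 1 << 20            # assume each rank reserved 1 MiB per peer
remote_base = 0x7F0000000000         # hypothetical data_ptr() of rank 0's send buffer
my_rank = 2
total_bytes = 24_576                 # bytes rank 0 actually packed for rank 2

# Rank 2 READs its slice of rank 0's send buffer starting here ...
remote_addr = remote_base + my_rank * partition_bytes
# ... and the bytes land locally at recv_buffer.data_ptr() + recv_offset.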

_make_agent_name

_make_agent_name() -> str

Build a deployment-unique nixl agent name.

Source code in vllm/distributed/eplb/eplb_communicator.py
def _make_agent_name(self) -> str:
    """Build a deployment-unique nixl agent name."""
    pp_size = get_pp_group().world_size
    pp_suffix = f"-pp{get_pp_group().rank_in_group}" if pp_size > 1 else ""
    uid = uuid.uuid4().hex[:8]
    return f"eplb-{self._rank}{pp_suffix}-{uid}"

_pack_send_buffer staticmethod

_pack_send_buffer(
    peer_tensors: list[Tensor],
    send_buffer: Tensor,
    byte_offset: int,
) -> int

Returns the byte offset after the last written byte.

Source code in vllm/distributed/eplb/eplb_communicator.py
@staticmethod
def _pack_send_buffer(
    peer_tensors: list[torch.Tensor],
    send_buffer: torch.Tensor,
    byte_offset: int,
) -> int:
    """
    Returns the byte offset after the last written byte.
    """
    for tensor in peer_tensors:
        raw = tensor.reshape(-1).view(torch.uint8)
        if raw.numel() == 0:
            continue
        send_buffer[byte_offset : byte_offset + raw.numel()].copy_(
            raw, non_blocking=True
        )
        byte_offset += raw.numel()
    return byte_offset

_release_all_cached_handles

_release_all_cached_handles() -> None

Best-effort release of every cached dlist and xfer handle.

Source code in vllm/distributed/eplb/eplb_communicator.py
def _release_all_cached_handles(self) -> None:
    """Best-effort release of every cached dlist and xfer handle."""
    for local_dlist, remote_dlist, xfer in self._xfer_cache.values():
        for release_fn, handle in (
            (self._nixl_wrapper.release_xfer_handle, xfer),
            (self._nixl_wrapper.release_dlist_handle, local_dlist),
            (self._nixl_wrapper.release_dlist_handle, remote_dlist),
        ):
            with contextlib.suppress(Exception):
                release_fn(handle)
    self._xfer_cache.clear()

_unpack_recv_buffer staticmethod

_unpack_recv_buffer(
    recv_buffer: Tensor,
    peer_tensors: list[Tensor],
    byte_offset: int,
) -> int

Returns the byte offset after the last read byte.

Source code in vllm/distributed/eplb/eplb_communicator.py
@staticmethod
def _unpack_recv_buffer(
    recv_buffer: torch.Tensor,
    peer_tensors: list[torch.Tensor],
    byte_offset: int,
) -> int:
    """
    Returns the byte offset after the last read byte.
    """
    for tensor in peer_tensors:
        num_bytes = tensor.numel() * tensor.element_size()
        if num_bytes == 0:
            continue
        tensor.reshape(-1).view(torch.uint8).copy_(
            recv_buffer[byte_offset : byte_offset + num_bytes],
            non_blocking=True,
        )
        byte_offset += num_bytes
    return byte_offset
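
A minimal standalone round trip of the byte packing used by the two helpers above (tensor shapes, dtypes, and buffer size are arbitrary, chosen for illustration):

import torch

a = torch.randn(2, 3, dtype=torch.float32)
b = torch.randn(4, dtype=torch.float32)
buf = torch.empty(64, dtype=torch.uint8)   # flat byte buffer (oversized)

# Pack: reinterpret each tensor as raw bytes and append it to the buffer.
off = 0
for t in (a, b):
    raw = t.reshape(-1).view(torch.uint8)
    buf[off : off + raw.numel()].copy_(raw)
    off += raw.numel()

# Unpack: copy the bytes back into freshly allocated tensors of the same shapes.
out_a, out_b = torch.empty_like(a), torch.empty_like(b)
off = 0
for t in (out_a, out_b):
    n = t.numel() * t.element_size()
    t.reshape(-1).view(torch.uint8).copy_(buf[off : off + n])
    off += n

assert torch.equal(out_a, a) and torch.equal(out_b, b)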

PyNcclEplbCommunicator

Bases: EplbCommunicator

EPLB communicator backed by PyNcclCommunicator using ncclSend/ncclRecv.

Source code in vllm/distributed/eplb/eplb_communicator.py
class PyNcclEplbCommunicator(EplbCommunicator):
    """EPLB communicator backed by PyNcclCommunicator using ncclSend/ncclRecv."""

    def __init__(
        self,
        pynccl_comm: PyNcclCommunicator,
        cuda_stream: torch.cuda.Stream | None = None,
    ) -> None:
        self._pynccl_comm = pynccl_comm
        self._cuda_stream = cuda_stream
        self._group_started = False
        self._log_initialized()

    def _ensure_group_started(self) -> None:
        if not self._group_started:
            self._pynccl_comm.group_start()
            self._group_started = True

    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        self._ensure_group_started()
        self._pynccl_comm.send(tensor, dst_rank, stream=self._cuda_stream)

    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        self._ensure_group_started()
        self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream)

    def execute(self) -> None:
        if self._group_started:
            self._pynccl_comm.group_end()
            self._group_started = False
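
A minimal usage sketch (the tensors and ranks are placeholders): the first add_send or add_recv call opens an NCCL group, and execute() closes it, so every queued operation is launched as one grouped call.

comm = PyNcclEplbCommunicator(pynccl_comm)   # assumes an initialized PyNcclCommunicator
comm.add_send(expert_out, dst_rank=1)        # first call triggers group_start()
comm.add_recv(expert_in, src_rank=3)         # queued inside the same group
comm.execute()                               # group_end() launches all queued ops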

TorchDistGlooStagedEplbCommunicator

Bases: EplbCommunicator

EPLB communicator using gloo P2P with CPU staging.

Source code in vllm/distributed/eplb/eplb_communicator.py
class TorchDistGlooStagedEplbCommunicator(EplbCommunicator):
    """EPLB communicator using gloo P2P with CPU staging."""

    def __init__(
        self,
        cpu_group: ProcessGroup,
        cuda_stream: torch.cuda.Stream | None = None,
    ) -> None:
        self._cpu_group = cpu_group
        self._cuda_stream = cuda_stream
        self._ops: list[tuple[str, torch.Tensor, int]] = []
        self._log_initialized()

    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        self._ops.append(("send", tensor, dst_rank))

    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        self._ops.append(("recv", tensor, src_rank))

    def execute(self) -> None:
        if not self._ops:
            return

        p2p_ops: list[P2POp] = []
        recv_staging: list[tuple[torch.Tensor, torch.Tensor]] = []

        def build_ops() -> None:
            for op, tensor, peer_rank in self._ops:
                if op == "send":
                    cpu_tensor = tensor.to(device="cpu", non_blocking=True)
                    p2p_ops.append(
                        P2POp(
                            torch.distributed.isend,
                            cpu_tensor,
                            peer_rank,
                            self._cpu_group,
                        )
                    )
                    continue
                cpu_tensor = torch.empty_like(tensor, device="cpu")
                p2p_ops.append(
                    P2POp(
                        torch.distributed.irecv,
                        cpu_tensor,
                        peer_rank,
                        self._cpu_group,
                    )
                )
                recv_staging.append((tensor, cpu_tensor))

        try:
            with torch.cuda.stream(self._cuda_stream):
                build_ops()
        finally:
            self._ops.clear()

        # Wait for all D2H copies to finish
        # before issuing gloo batch_isend_irecv operations.
        if self._cuda_stream is not None:
            self._cuda_stream.synchronize()
        else:
            torch.cuda.current_stream().synchronize()

        reqs = batch_isend_irecv(p2p_ops)
        for req in reqs:
            req.wait()

        if not recv_staging:
            return
        with torch.cuda.stream(self._cuda_stream):
            for dst_tensor, cpu_tensor in recv_staging:
                dst_tensor.copy_(cpu_tensor, non_blocking=True)

TorchDistNcclEplbCommunicator

Bases: EplbCommunicator

EPLB communicator backed by torch.distributed isend/irecv.

Source code in vllm/distributed/eplb/eplb_communicator.py
class TorchDistNcclEplbCommunicator(EplbCommunicator):
    """EPLB communicator backed by torch.distributed isend/irecv."""

    def __init__(
        self,
        ep_group: ProcessGroup,
        cuda_stream: torch.cuda.Stream | None = None,
    ) -> None:
        self._ep_group = ep_group
        self._cuda_stream = cuda_stream
        self._p2p_ops: list[P2POp] = []
        self._log_initialized()

    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        self._p2p_ops.append(
            P2POp(
                torch.distributed.isend,
                tensor,
                dst_rank,
                self._ep_group,
            )
        )

    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        self._p2p_ops.append(
            P2POp(
                torch.distributed.irecv,
                tensor,
                src_rank,
                self._ep_group,
            )
        )

    def execute(self) -> None:
        if not self._p2p_ops:
            return
        try:
            with torch.cuda.stream(self._cuda_stream):
                reqs = batch_isend_irecv(self._p2p_ops)
                for req in reqs:
                    req.wait()
        finally:
            self._p2p_ops.clear()

create_eplb_communicator

create_eplb_communicator(
    group_coordinator: GroupCoordinator,
    backend: str | None,
    expert_weights: Sequence[Tensor],
) -> EplbCommunicator

Create an EPLB communicator for the given backend.

Parameters:

group_coordinator (GroupCoordinator, required)
    Process-group coordinator that provides the device and CPU communication groups.

backend (str | None, required)
    Communicator backend name ("torch_nccl", "torch_gloo", "pynccl", or "nixl"). Falls back to "torch_nccl" when None. Stateless (elastic EP) groups only support "torch_nccl" and "pynccl"; "torch_nccl" is silently promoted to "pynccl" in that case. When tensors reside on CPU, "torch_gloo" or "torch_nccl" are used via the CPU process group.

expert_weights (Sequence[Tensor], required)
    Expert weight tensors from one MoE layer. NixlEplbCommunicator pre-allocates send/recv buffers sized to this layer, so all other MoE layers must have the same tensor count, shapes, and dtypes.
Source code in vllm/distributed/eplb/eplb_communicator.py
def create_eplb_communicator(
    group_coordinator: GroupCoordinator,
    backend: str | None,
    expert_weights: Sequence[torch.Tensor],
) -> EplbCommunicator:
    """Create an EPLB communicator for the given backend.

    Args:
        group_coordinator: Process-group coordinator that provides the
            device and CPU communication groups.
        backend: Communicator backend name (``"torch_nccl"``,
            ``"torch_gloo"``, ``"pynccl"``, or ``"nixl"``).
            Falls back to ``"torch_nccl"`` when *None*.
            Stateless (elastic EP) groups only support ``"torch_nccl"``
            and ``"pynccl"``; ``"torch_nccl"`` is silently promoted to
            ``"pynccl"`` in that case.  When tensors reside on CPU,
            ``"torch_gloo"`` or ``"torch_nccl"`` are used via the CPU
            process group.
        expert_weights: Expert weight tensors from *one* MoE layer.
            NixlEplbCommunicator pre-allocates send/recv buffers sized
            to this layer, so all other MoE layers must have the same
            tensor count, shapes, and dtypes.
    """
    # Keep a safe default for callers that have not resolved communicator yet.
    if backend is None:
        backend = "torch_nccl"

    tensor_device_type = expert_weights[0].device.type if expert_weights else "cpu"
    torch_group = (
        group_coordinator.cpu_group
        if tensor_device_type == "cpu"
        else group_coordinator.device_group
    )

    def _create_pynccl() -> EplbCommunicator:
        if tensor_device_type == "cpu":
            raise RuntimeError(
                "EPLB communicator 'pynccl' supports only cuda-like devices "
                f"(got {tensor_device_type})."
            )
        unsupported_dtypes = sorted(
            {
                tensor.dtype
                for tensor in expert_weights
                if not ncclDataTypeEnum.supports_torch_dtype(tensor.dtype)
            },
            key=str,
        )
        if unsupported_dtypes:
            raise RuntimeError(
                "EPLB communicator 'pynccl' requested but expert weights contain "
                "unsupported dtypes: "
                f"({', '.join(str(dtype) for dtype in unsupported_dtypes)})."
            )

        device_comm = group_coordinator.device_communicator
        pynccl_comm = (
            getattr(device_comm, "pynccl_comm", None)
            if device_comm is not None
            else None
        )
        if pynccl_comm is None or pynccl_comm.disabled or not pynccl_comm.available:
            raise RuntimeError("EPLB communicator 'pynccl' requested but unavailable.")
        try:
            return PyNcclEplbCommunicator(pynccl_comm=pynccl_comm)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to initialize PyNcclEplbCommunicator ({exc})."
            ) from exc

    is_stateless = isinstance(group_coordinator, StatelessGroupCoordinator)
    if is_stateless:
        if backend not in ("torch_nccl", "pynccl"):
            raise ValueError(
                f"Elastic EP requires 'torch_nccl' or 'pynccl' EPLB communicator "
                f"(got '{backend}')."
            )
        if backend == "torch_nccl":
            logger.warning(
                "Stateless elastic EP requires PyNCCL backend. "
                "Forcing EPLB communicator to 'pynccl'."
            )
            backend = "pynccl"
        return _create_pynccl()

    if backend == "nixl":
        if not has_nixl():
            raise RuntimeError(
                "EPLB communicator 'nixl' requested but NIXL is unavailable."
            )
        if not (current_platform.is_cuda_alike() and tensor_device_type != "cpu"):
            raise RuntimeError(
                "EPLB communicator 'nixl' supports only cuda-like devices "
                f"(got {tensor_device_type})."
            )
        try:
            return NixlEplbCommunicator(
                cpu_group=group_coordinator.cpu_group,
                expert_weights=expert_weights,
            )
        except Exception as exc:
            raise RuntimeError(
                f"Failed to initialize NixlEplbCommunicator ({exc})."
            ) from exc
    elif backend == "torch_gloo":
        return TorchDistGlooStagedEplbCommunicator(
            cpu_group=group_coordinator.cpu_group,
        )
    elif backend == "torch_nccl":
        return TorchDistNcclEplbCommunicator(ep_group=torch_group)
    elif backend == "pynccl":
        return _create_pynccl()
    raise ValueError(f"Unknown EPLB communicator backend: {backend}")

has_nixl

has_nixl() -> bool

Whether the optional NIXL / RIXL package is available.

Source code in vllm/distributed/eplb/eplb_communicator.py
def has_nixl() -> bool:
    """Whether the optional NIXL / RIXL package is available."""
    return NixlWrapper is not None