Skip to content

vllm.utils.deep_gemm

Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import only these wrappers.

DEFAULT_BLOCK_SIZE module-attribute

# Default ``[block_m, block_n]`` block shape; used as the default
# ``block_size`` argument of ``per_block_cast_to_fp8``.
# NOTE(review): this list is shared as a default argument, so it must not be
# mutated in place.
DEFAULT_BLOCK_SIZE = [128, 128]

__all__ module-attribute

# Public names exported by this module (what ``from ... import *`` exposes).
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "fp8_mqa_logits",
    "fp8_paged_mqa_logits",
    "get_paged_mqa_logits_metadata",
    "per_block_cast_to_fp8",
    "is_deep_gemm_e8m0_used",
    "is_deep_gemm_supported",
    "get_num_sms",
    "should_use_deepgemm_for_fp8_linear",
    "get_col_major_tma_aligned_tensor",
]

_fp8_gemm_nt_impl module-attribute

# Handle for ``deep_gemm.fp8_gemm_nt``; stays ``None`` until ``_lazy_init()``
# resolves it (or if the installed deep_gemm lacks the symbol).
_fp8_gemm_nt_impl: Callable[..., Any] | None = None

_fp8_mqa_logits_impl module-attribute

# Handle for ``deep_gemm.fp8_mqa_logits``; stays ``None`` until
# ``_lazy_init()`` resolves it (or if the installed deep_gemm lacks the symbol).
_fp8_mqa_logits_impl: Callable[..., Any] | None = None

_fp8_paged_mqa_logits_impl module-attribute

# Handle for ``deep_gemm.fp8_paged_mqa_logits``; stays ``None`` until
# ``_lazy_init()`` resolves it (or if the installed deep_gemm lacks the symbol).
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None

_get_mn_major_tma_aligned_tensor_impl module-attribute

# Handle for ``deep_gemm.get_mn_major_tma_aligned_tensor``; stays ``None``
# until ``_lazy_init()`` resolves it (or if deep_gemm lacks the symbol).
_get_mn_major_tma_aligned_tensor_impl: (
    Callable[..., Any] | None
) = None

_get_paged_mqa_logits_metadata_impl module-attribute

# Handle for ``deep_gemm.get_paged_mqa_logits_metadata``; stays ``None``
# until ``_lazy_init()`` resolves it (or if deep_gemm lacks the symbol).
_get_paged_mqa_logits_metadata_impl: (
    Callable[..., Any] | None
) = None

_grouped_impl module-attribute

# Handle for ``deep_gemm.m_grouped_fp8_gemm_nt_contiguous``; stays ``None``
# until ``_lazy_init()`` resolves it (or if deep_gemm lacks the symbol).
_grouped_impl: Callable[..., Any] | None = None

_grouped_masked_impl module-attribute

# Handle for ``deep_gemm.fp8_m_grouped_gemm_nt_masked``; stays ``None``
# until ``_lazy_init()`` resolves it (or if deep_gemm lacks the symbol).
_grouped_masked_impl: Callable[..., Any] | None = None

_align

_align(x: int, y: int) -> int
Source code in vllm/utils/deep_gemm.py
def _align(x: int, y: int) -> int:
    """Return ``cdiv(x, y) * y``.

    Presumably ``x`` rounded up to the next multiple of ``y``; ``cdiv`` is
    defined outside this module -- confirm its rounding semantics there.
    """
    num_chunks = cdiv(x, y)
    return num_chunks * y

_ceil_to_ue8m0

_ceil_to_ue8m0(x: Tensor)
Source code in vllm/utils/deep_gemm.py
def _ceil_to_ue8m0(x: torch.Tensor):
    """Round each element's magnitude up to a power of two.

    Computes ``2 ** ceil(log2(|x|))`` elementwise; the sign of ``x`` is
    discarded by the absolute value.
    """
    magnitude = x.abs()
    exponent = torch.ceil(torch.log2(magnitude))
    return torch.pow(2.0, exponent)

_lazy_init

_lazy_init() -> None

Import deep_gemm and resolve symbols on first use.

Source code in vllm/utils/deep_gemm.py
def _lazy_init() -> None:
    """Import deep_gemm and resolve symbols on first use.

    Populates the module-level ``_*_impl`` handles with the matching
    attributes of the ``deep_gemm`` module; each handle is left as ``None``
    when the installed package does not expose that symbol.

    The function is a no-op when any handle has already been resolved, or
    when ``has_deep_gemm()`` reports that the package is unavailable.

    Side effect: if the ``DG_JIT_CACHE_DIR`` environment variable is unset or
    empty, it is pointed at ``<VLLM_CACHE_ROOT>/deep_gemm`` before the import.
    """
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
    global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
    global _get_paged_mqa_logits_metadata_impl
    global _get_mn_major_tma_aligned_tensor_impl

    # Fast path: any resolved handle means a previous call already imported
    # deep_gemm. Every handle assigned below is checked, including
    # `_get_mn_major_tma_aligned_tensor_impl`, so that a deep_gemm build
    # exposing only that symbol does not redo the setup on every call.
    if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
            or _grouped_masked_impl is not None
            or _fp8_mqa_logits_impl is not None
            or _fp8_paged_mqa_logits_impl is not None
            or _get_paged_mqa_logits_metadata_impl is not None
            or _get_mn_major_tma_aligned_tensor_impl is not None):
        return

    if not has_deep_gemm():
        return

    # Set up deep_gemm cache path; a user-provided value takes precedence.
    DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR'
    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
            envs.VLLM_CACHE_ROOT, "deep_gemm")

    _dg = importlib.import_module("deep_gemm")

    # getattr(..., None) keeps missing symbols as None so the public wrappers
    # can report an unavailable/outdated backend instead of failing here.
    _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
    _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
    _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
    _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
    _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
    _get_paged_mqa_logits_metadata_impl = getattr(
        _dg, "get_paged_mqa_logits_metadata", None)
    _get_mn_major_tma_aligned_tensor_impl = getattr(
        _dg, "get_mn_major_tma_aligned_tensor", None)

_missing

_missing(*_: Any, **__: Any) -> NoReturn

Placeholder for unavailable DeepGEMM backend.

Source code in vllm/utils/deep_gemm.py
def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in for a DeepGEMM symbol that could not be resolved.

    Accepts and ignores any positional or keyword arguments.

    Raises:
        RuntimeError: Always.
    """
    message = (
        "DeepGEMM backend is not available or outdated. Please install or "
        "update the `deep_gemm` to a newer version to enable FP8 kernels.")
    raise RuntimeError(message)

calc_diff

calc_diff(x: Tensor, y: Tensor)

Return a global difference metric for unit tests.

DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element error, causing torch.testing.assert_close to fail. Instead of checking every element, we compute a cosine-style similarity over the whole tensor and report 1 - sim. Once kernel accuracy improves this helper can be removed.

Source code in vllm/utils/deep_gemm.py
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing ``torch.testing.assert_close`` to fail.  Instead of checking
    every element, we compute a cosine-style similarity over the whole tensor
    and report ``1 - sim``.  Once kernel accuracy improves this helper can be
    removed.

    Args:
        x: First tensor; converted to float64 before the computation.
        y: Second tensor; converted to float64 before the computation.

    Returns:
        A 0-dim float64 tensor holding
        ``1 - 2 * sum(x * y) / sum(x * x + y * y)``. When the denominator is
        exactly zero (both inputs all-zero or empty) the inputs are identical,
        so ``0`` is returned instead of the ``NaN`` a 0/0 division would give.
    """

    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    # torch.where keeps the result a tensor and avoids data-dependent Python
    # control flow; the NaN in the unselected branch is simply discarded.
    return torch.where(denominator == 0, torch.zeros_like(denominator),
                       1 - sim)

fp8_gemm_nt

fp8_gemm_nt(*args, **kwargs)
Source code in vllm/utils/deep_gemm.py
def fp8_gemm_nt(*args, **kwargs):
    """Forward to DeepGEMM's ``fp8_gemm_nt`` with ``disable_ue8m0_cast`` set.

    An optional ``is_deep_gemm_e8m0_used`` keyword argument overrides the
    global ``is_deep_gemm_e8m0_used()`` setting for this call; it is removed
    from ``kwargs`` before forwarding. All other arguments pass through
    unchanged.

    Raises:
        RuntimeError: If the DeepGEMM symbol could not be resolved.
    """
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)

    override_key = "is_deep_gemm_e8m0_used"
    # Only consult the global configuration when no per-call override is given.
    use_ue8m0 = (kwargs.pop(override_key)
                 if override_key in kwargs else is_deep_gemm_e8m0_used())
    return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)

fp8_m_grouped_gemm_nt_masked

fp8_m_grouped_gemm_nt_masked(*args, **kwargs)
Source code in vllm/utils/deep_gemm.py
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward to DeepGEMM's ``fp8_m_grouped_gemm_nt_masked``.

    ``disable_ue8m0_cast`` is set to the negation of
    ``is_deep_gemm_e8m0_used()``; all other arguments pass through unchanged.

    Raises:
        RuntimeError: If the DeepGEMM symbol could not be resolved.
    """
    _lazy_init()
    impl = _grouped_masked_impl
    if impl is None:
        return _missing(*args, **kwargs)
    disable_cast = not is_deep_gemm_e8m0_used()
    return impl(*args, disable_ue8m0_cast=disable_cast, **kwargs)

fp8_mqa_logits

fp8_mqa_logits(
    q: Tensor,
    kv: tuple[Tensor, Tensor],
    weights: Tensor,
    cu_seqlen_ks: Tensor,
    cu_seqlen_ke: Tensor,
) -> Tensor

Compute FP8 MQA logits for a single sequence without KV paging.

Parameters:

Name Type Description Default
q Tensor

Query tensor of shape [M, H, D]. Cast to torch.float8_e4m3fn by the caller.

required
kv tuple[Tensor, Tensor]

Tuple (k_fp8, k_scales) where k_fp8 has shape [N, D] with dtype torch.float8_e4m3fn and k_scales has shape [N] (or [N, 1]) with dtype torch.float32.

required
weights Tensor

weights of shape [M, H], dtype torch.float32.

required
cu_seqlen_ks Tensor

Start indices (inclusive) for valid K per query position, shape [M], dtype int32.

required
cu_seqlen_ke Tensor

End indices (exclusive) for valid K per query position, shape [M], dtype int32.

required

Returns:

Type Description
Tensor

Logits tensor of shape [M, N], dtype torch.float32.

Source code in vllm/utils/deep_gemm.py
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Thin wrapper that forwards all arguments, in order, to DeepGEMM's
    ``fp8_mqa_logits``.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) of the valid K range for each
            query position, shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) of the valid K range for each
            query position, shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.

    Raises:
        RuntimeError: If the DeepGEMM symbol could not be resolved.
    """
    _lazy_init()
    impl = _fp8_mqa_logits_impl
    if impl is None:
        return _missing()
    return impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)

fp8_paged_mqa_logits

fp8_paged_mqa_logits(
    q_fp8: Tensor,
    kv_cache_fp8: Tensor,
    weights: Tensor,
    context_lens: Tensor,
    block_tables: Tensor,
    schedule_metadata: Tensor,
    max_model_len: int,
) -> Tensor

Compute FP8 MQA logits using paged KV-cache.

Parameters:

Name Type Description Default
q_fp8 Tensor

Query tensor of shape [B, next_n, H, D]. Cast to torch.float8_e4m3fn by the caller.

required
kv_cache_fp8 Tensor

Paged KV-cache in packed FP8+scale layout with shape [num_blocks, block_size, 1, D+4], dtype torch.uint8. The last 4 bytes per (block,pos) store the float dequant scale.

required
weights Tensor

Tensor of shape [B * next_n, H], dtype torch.float32.

required
context_lens Tensor

Tensor of shape [B], dtype int32; effective context length for each batch element.

required
block_tables Tensor

Tensor of shape [B, max_blocks], dtype int32; maps logical block indices to physical blocks in the paged cache.

required
schedule_metadata Tensor

Returned by get_paged_mqa_logits_metadata; used to distribute work across SMs.

required
max_model_len int

Maximum sequence length used to size the logits output.

required

Returns:

Type Description
Tensor

Logits tensor of shape [B * next_n, max_model_len], dtype torch.float32.

Source code in vllm/utils/deep_gemm.py
def fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using paged KV-cache.

    Thin wrapper that forwards all arguments, in order, to DeepGEMM's
    ``fp8_paged_mqa_logits`` and always passes ``clean_logits=True``.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block, pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context
            length for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps
            logical block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.

    Raises:
        RuntimeError: If the DeepGEMM symbol could not be resolved.
    """
    _lazy_init()
    impl = _fp8_paged_mqa_logits_impl
    if impl is None:
        return _missing()
    return impl(
        q_fp8,
        kv_cache_fp8,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
        clean_logits=True,
    )

get_col_major_tma_aligned_tensor

get_col_major_tma_aligned_tensor(x: Tensor) -> Tensor

Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor

Source code in vllm/utils/deep_gemm.py
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Wrapper for DeepGEMM's ``get_mn_major_tma_aligned_tensor``.

    Args:
        x: Tensor passed through unchanged to the DeepGEMM function.

    Returns:
        Whatever DeepGEMM's ``get_mn_major_tma_aligned_tensor(x)`` returns.

    Raises:
        RuntimeError: If the DeepGEMM symbol could not be resolved.
    """
    _lazy_init()
    impl = _get_mn_major_tma_aligned_tensor_impl
    if impl is None:
        return _missing()
    return impl(x)

get_num_sms

get_num_sms() -> int
Source code in vllm/utils/deep_gemm.py
def get_num_sms() -> int:
    """Return ``deep_gemm.get_num_sms()`` converted to ``int``.

    NOTE(review): unlike the other wrappers, this imports ``deep_gemm``
    unconditionally, so a missing package surfaces as an import error rather
    than the ``RuntimeError`` raised by ``_missing``.
    """
    _lazy_init()
    deep_gemm_module = importlib.import_module("deep_gemm")
    num_sms = deep_gemm_module.get_num_sms()
    return int(num_sms)

get_paged_mqa_logits_metadata

get_paged_mqa_logits_metadata(
    context_lens: Tensor, block_size: int, num_sms: int
) -> Tensor

Build scheduling metadata for paged MQA logits.

Parameters:

Name Type Description Default
context_lens Tensor

Tensor of shape [B], dtype int32; effective context length per batch element.

required
block_size int

KV-cache block size in tokens (e.g., 64).

required
num_sms int

Number of SMs available (e.g., 132 on Hopper).

required

Returns:

Type Description
Tensor

Backend-specific tensor consumed by fp8_paged_mqa_logits to schedule work across SMs.

Source code in vllm/utils/deep_gemm.py
def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int,
                                  num_sms: int) -> torch.Tensor:
    """Build scheduling metadata for paged MQA logits.

    Thin wrapper that forwards all arguments, in order, to DeepGEMM's
    ``get_paged_mqa_logits_metadata``.

    Args:
        context_lens: Tensor of shape [B], dtype int32; effective context
            length per batch element.
        block_size: KV-cache block size in tokens (e.g., 64).
        num_sms: Number of SMs available (e.g., 132 on Hopper).

    Returns:
        Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
        schedule work across SMs.

    Raises:
        RuntimeError: If the DeepGEMM symbol could not be resolved.
    """
    _lazy_init()
    impl = _get_paged_mqa_logits_metadata_impl
    if impl is None:
        return _missing()
    return impl(context_lens, block_size, num_sms)

is_deep_gemm_e8m0_used cached

is_deep_gemm_e8m0_used() -> bool

Return True if vLLM is configured to use DeepGEMM E8M0 scale on a Hopper or Blackwell-class GPU.

Source code in vllm/utils/deep_gemm.py
@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM E8M0 scale on a
    Hopper or Blackwell-class GPU.

    The result is cached for the lifetime of the process, so later changes to
    the environment flags are not observed. The reason for the decision is
    logged once.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.")
        return False

    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    # NOTE(review): is_deep_gemm_supported() also rejects this flag; the
    # re-check presumably only matters if the value differs from when that
    # cached result was computed.
    if envs.VLLM_USE_FLASHINFER_MOE_FP8:
        logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
        return False

    # (device capability, lazy env-flag getter, log message), checked in
    # order. The flag is wrapped in a lambda so it is only read when the
    # capability check passes.
    enabling_rules = (
        (100, lambda: envs.VLLM_USE_DEEP_GEMM_E8M0,
         "DeepGEMM E8M0 enabled on Blackwell GPU."),
        (90, lambda: envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER,
         "DeepGEMM E8M0 enabled on Hopper GPU."),
    )
    for capability, flag_enabled, message in enabling_rules:
        if (current_platform.is_device_capability(capability)
                and flag_enabled()):
            logger.info_once(message)
            return True

    logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
    return False

is_deep_gemm_supported cached

is_deep_gemm_supported() -> bool

Return True if DeepGEMM is supported on the current platform. Currently, only Hopper and Blackwell GPUs are supported.

Source code in vllm/utils/deep_gemm.py
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Return ``True`` if DeepGEMM is supported on the current platform.

    Requires all of: ``VLLM_USE_DEEP_GEMM`` set, the ``deep_gemm`` package
    available, a CUDA device of capability 90 (Hopper) or 100 (Blackwell),
    and ``VLLM_USE_FLASHINFER_MOE_FP8`` unset. The result is cached for the
    lifetime of the process.
    """
    # The architecture probe runs first, before any environment flag is read.
    arch_ok = current_platform.is_cuda() and (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability(100))
    return (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and arch_ok
            and not envs.VLLM_USE_FLASHINFER_MOE_FP8)

m_grouped_fp8_gemm_nt_contiguous

m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs)
Source code in vllm/utils/deep_gemm.py
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Forward to DeepGEMM's ``m_grouped_fp8_gemm_nt_contiguous``.

    ``disable_ue8m0_cast`` is set to the negation of
    ``is_deep_gemm_e8m0_used()``; all other arguments pass through unchanged.

    Raises:
        RuntimeError: If the DeepGEMM symbol could not be resolved.
    """
    _lazy_init()
    impl = _grouped_impl
    if impl is None:
        return _missing(*args, **kwargs)
    disable_cast = not is_deep_gemm_e8m0_used()
    return impl(*args, disable_ue8m0_cast=disable_cast, **kwargs)

per_block_cast_to_fp8

per_block_cast_to_fp8(
    x: Tensor,
    block_size: list[int] = DEFAULT_BLOCK_SIZE,
    use_ue8m0: bool = False,
) -> tuple[Tensor, Tensor]
Source code in vllm/utils/deep_gemm.py
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size: list[int] = DEFAULT_BLOCK_SIZE,
        use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to ``float8_e4m3fn`` with one scale per block.

    Args:
        x: 2-D input tensor of shape [m, n].
        block_size: ``[block_m, block_n]`` block shape used for scaling.
        use_ue8m0: If true, each scale is rounded up to a power of two via
            ``_ceil_to_ue8m0``.

    Returns:
        Tuple ``(x_fp8, scales)``: ``x_fp8`` is a contiguous
        ``torch.float8_e4m3fn`` tensor of shape [m, n]; ``scales`` is a
        float32 tensor with one entry per block, shaped
        [padded_m / block_m, padded_n / block_n].
    """
    assert x.dim() == 2
    rows, cols = x.shape
    block_m, block_n = block_size

    # Zero-pad up to whole blocks so the tensor can be viewed block-wise.
    padded_shape = (_align(rows, block_m), _align(cols, block_n))
    padded = torch.zeros(padded_shape, dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x

    # Shape: [num_row_blocks, block_m, num_col_blocks, block_n].
    blocks = padded.view(-1, block_m, padded.size(1) // block_n, block_n)

    # Per-block absolute max, floored at 1e-4 so the scale is never zero.
    block_amax = blocks.abs().float().amax(dim=(1, 3),
                                           keepdim=True).clamp(1e-4)
    # 448.0 is the largest finite float8_e4m3fn value.
    scale = block_amax / 448.0
    if use_ue8m0:
        scale = _ceil_to_ue8m0(scale)

    quantized = (blocks * (1.0 / scale)).to(torch.float8_e4m3fn)
    x_fp8 = quantized.view_as(padded)[:rows, :cols].contiguous()
    return x_fp8, scale.view(blocks.size(0), blocks.size(2))

should_use_deepgemm_for_fp8_linear

should_use_deepgemm_for_fp8_linear(
    output_dtype: dtype,
    weight: Tensor,
    supports_deep_gemm: Optional[bool] = None,
)
Source code in vllm/utils/deep_gemm.py
def should_use_deepgemm_for_fp8_linear(
        output_dtype: torch.dtype,
        weight: torch.Tensor,
        supports_deep_gemm: Optional[bool] = None):
    """Decide whether an FP8 linear layer should use DeepGEMM.

    Args:
        output_dtype: Output dtype of the layer; must be ``torch.bfloat16``.
        weight: Weight tensor; its first two dimensions must each be a
            multiple of 128.
        supports_deep_gemm: Precomputed platform support flag. When ``None``,
            ``is_deep_gemm_supported()`` is consulted.

    Returns:
        Truthy only when DeepGEMM is supported and both the dtype and the
        weight-shape requirements hold.
    """
    if supports_deep_gemm is None:
        supports_deep_gemm = is_deep_gemm_supported()
    if not supports_deep_gemm:
        return supports_deep_gemm

    dtype_ok = output_dtype == torch.bfloat16
    if not dtype_ok:
        return dtype_ok

    # The second dimension is only inspected once the first one is aligned.
    rows_aligned = weight.shape[0] % 128 == 0
    if not rows_aligned:
        return rows_aligned
    return weight.shape[1] % 128 == 0