1 change: 1 addition & 0 deletions src/maxtext/configs/base.yml
@@ -217,6 +217,7 @@ load_balance_loss_weight: 0.0 # weight for the load balance loss
use_random_routing: false # whether to use random routing for debug/test purpose
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
use_ragged_sort: false # whether to use ragged sorting in the forward pass; only works with ring of experts
use_gather_mosaic_kernel: false # whether to use a custom mosaic kernel for token gather ops
# tunable tiling dimensions used for mlp gmm
# megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
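The new flag is only meaningful together with `use_ring_of_experts` (enforced by the validator added in `configs/types.py` below). A minimal sketch of the dependency in plain Python, not MaxText's actual config loader; the flag names are taken from this diff:

```python
# Sketch only: the real flags live in base.yml and configs/types.py.
moe_flags = {
    "use_custom_sort_vjp": True,  # custom VJP sort for the sparse matmul backward pass
    "use_ring_of_experts": True,  # must be enabled for ragged sorting to apply
    "use_ragged_sort": True,      # new flag: ragged sorting in the forward pass
}

if moe_flags["use_ragged_sort"] and not moe_flags["use_ring_of_experts"]:
  raise ValueError("use_ragged_sort requires use_ring_of_experts=True")
```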
10 changes: 10 additions & 0 deletions src/maxtext/configs/types.py
@@ -690,6 +690,9 @@ class MoEGeneral(BaseModel):
False,
description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
)
use_ragged_sort: bool = Field(
      False, description="Whether to use a ragged kernel for sorting; only valid when use_ring_of_experts is enabled."
)
use_gather_mosaic_kernel: bool = Field(
False,
description="Whether to use a custom mosaic kernel for token gather ops.",
@@ -2256,6 +2259,12 @@ def validate_ragged_buffer_factor(self):
if self.use_ring_of_experts:
raise ValueError("Currently we only support ragged buffer factor with ragged a2a approach.")

  def validate_ragged_sort(self):
    """Ensure ragged sorting is only enabled together with Ring of Experts."""
    if not self.use_ragged_sort:
      return
    if not self.use_ring_of_experts:
      raise ValueError("Ragged sorting kernels only work when use_ring_of_experts=True.")

@model_validator(mode="after")
def set_derived_and_validate_values(self) -> "MaxTextConfig":
"""
@@ -2801,6 +2810,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK:
raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.")
self.validate_ragged_buffer_factor()
self.validate_ragged_sort()
if self.use_multimodal:
valid_mm_models = (
"gemma3-4b",
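To illustrate the behavior of the new validator, here is a self-contained sketch with a stripped-down pydantic model; the two-field `MoEFlags` class is an assumption for illustration, while the real check lives in `MaxTextConfig`:

```python
from pydantic import BaseModel, ValidationError, model_validator

class MoEFlags(BaseModel):
  # Mirrors the two fields added in configs/types.py.
  use_ring_of_experts: bool = False
  use_ragged_sort: bool = False

  @model_validator(mode="after")
  def validate_ragged_sort(self) -> "MoEFlags":
    if self.use_ragged_sort and not self.use_ring_of_experts:
      raise ValueError("Ragged sorting kernels only work when use_ring_of_experts=True.")
    return self

MoEFlags(use_ring_of_experts=True, use_ragged_sort=True)  # accepted
try:
  MoEFlags(use_ragged_sort=True)  # rejected: ring of experts is off
except ValidationError as e:
  print(e)
```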
298 changes: 298 additions & 0 deletions src/maxtext/kernels/ragged/ragged_gather.py
@@ -0,0 +1,298 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Ragged gather kernel implementation from tpu-inference."""
# Sourced from https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/sparse_core/ragged_gather.py

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas import tpu_sc as plsc
from packaging.version import Version


# JAX <= 0.10.0 used `out_shape`/`scratch_shapes` kwargs for `pl.kernel`; later
# versions renamed them to `out_type`/`scratch_types`.
if Version(jax.__version__) <= Version("0.10.0"):
_OUT_KW = "out_shape"
_SCRATCH_KW = "scratch_shapes"
else:
_OUT_KW = "out_type"
_SCRATCH_KW = "scratch_types"


def main_kernel(
# Inputs.
start_ref: jax.Ref,
end_ref: jax.Ref,
in_hbm_ref: jax.Ref,
indices_hbm_ref: jax.Ref,
# Outputs.
out_hbm_ref: jax.Ref,
# Scratch.
start_vmem_ref: jax.Ref,
end_vmem_ref: jax.Ref,
out_vmem_ref: jax.Ref,
indices_vmem_ref: jax.Ref,
sem_ref: jax.Ref,
*,
core_axis_name: str,
subcore_axis_name: str,
):
"""Core ragger gather operation"""
tpu_info = pltpu.get_tpu_info()
sc_info = tpu_info.sparse_core
assert sc_info is not None
num_simd_lanes = sc_info.num_lanes
num_lanes = tpu_info.num_lanes
hidden_size = in_hbm_ref.shape[-1]
col_size = out_vmem_ref.shape[-1]
num_cores = jax.lax.axis_size((core_axis_name, subcore_axis_name))
block_size = num_simd_lanes * num_cores
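  # Each pipeline step covers block_size rows of indices: num_simd_lanes
  # contiguous rows per (sub)core, offset by core_id * num_simd_lanes below.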

recv_sem = sem_ref.at[0]
send_sem = sem_ref.at[1]

# Read start and end tensor values.
dma_list = []
dma = pltpu.make_async_copy(start_ref, start_vmem_ref.at[:1], recv_sem)
dma_list.append(dma)
dma = pltpu.make_async_copy(end_ref, end_vmem_ref.at[:1], recv_sem)
dma_list.append(dma)

jax.tree.map(lambda x: x.start(), dma_list)
jax.tree.map(lambda x: x.wait(), dma_list)

# Calculate number of tiles to visit using start and end arrays.
start = start_vmem_ref[...][0]
end = end_vmem_ref[...][0]

block_start = start // block_size
block_end = pl.cdiv(end, block_size)
num_blocks = block_end - block_start
num_blocks = jnp.where(end == start, 0, num_blocks)
aligned_start = block_start * block_size

num_cols = pl.cdiv(hidden_size, col_size)

@functools.partial(
pltpu.emit_pipeline,
grid=(num_blocks, num_cores, num_cols),
core_axis_name=(core_axis_name, subcore_axis_name),
dimension_semantics=(pltpu.ARBITRARY, pltpu.PARALLEL, pltpu.ARBITRARY),
)
def inner_kernel():
block_id = pl.program_id(0)
core_id = pl.program_id(1)
col_id = pl.program_id(2)

row_tile_start = aligned_start + block_id * block_size + core_id * num_simd_lanes
col_tile_start = col_id * col_size

@pl.when(col_id == 0)
def _():
pltpu.sync_copy(
indices_hbm_ref.at[pl.ds(row_tile_start, num_simd_lanes)],
indices_vmem_ref,
)

# HBM to VMEM transfer.
indices = indices_vmem_ref[...]

dtype = out_hbm_ref.dtype
dtype_bits = jax.dtypes.itemsize_bits(dtype)
packing = 32 // dtype_bits

    # To fetch only one sublane at a time we would need a (packing, 128)
    # layout, but the inputs use an (8, 128) layout, so a relayout is needed.
    # For 32-bit types this is a simple reinterpretation; for other bitwidths
    # it is not possible. We therefore bitcast the data to 32 bits, fetch
    # `packing` rows per DMA, and later perform bitwise unpacking/packing to
    # obtain the desired result.
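    # Worked example (illustrative assumption): for bfloat16, dtype_bits is
    # 16 and packing is 2, so logical rows 2k and 2k+1 share one 32-bit word;
    # indices[i] // packing selects the word row and indices[i] % packing the
    # half-word within it.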
in_32b_hbm_ref = in_hbm_ref.bitcast(jnp.uint32)
out_32b_hbm_ref = out_hbm_ref.bitcast(jnp.uint32)

for col_vmem_start in range(0, col_size, num_lanes):
col_hbm_start = col_tile_start + col_vmem_start
for row_vmem in range(num_simd_lanes):
row_hbm = indices[row_vmem] // packing
pltpu.make_async_copy(
in_32b_hbm_ref.at[row_hbm, pl.ds(col_hbm_start, num_lanes)],
out_vmem_ref.at[row_vmem, pl.ds(col_vmem_start, num_lanes)],
recv_sem,
).start()

# VMEM to HBM transfer.
# Use dynamic loop to minimize register spills.
@pl.loop(0, col_size, step=num_lanes)
@jax.named_scope("dma_write_loop")
def dma_write_loop(col_vmem_start):
col_hbm_start = col_tile_start + col_vmem_start

      # Wait for the data to arrive.
      # NOTE: A single semaphore is shared by all DMA calls, so waits must be
      # issued in the same order as starts; otherwise one DMA's completion
      # could release the wrong wait.
for _ in range(num_simd_lanes):
pltpu.make_async_copy(
in_32b_hbm_ref.at[0, :num_lanes],
out_vmem_ref.at[0, :num_lanes],
recv_sem,
).wait()

      # If multiple elements are packed into a single 32-bit word, extract
      # the desired elements and reorder them.
if packing > 1:
for col_compute_offset in range(0, num_lanes, num_simd_lanes):
col_slice = pl.ds(col_vmem_start + col_compute_offset, num_simd_lanes)

out = None
for row_src in range(num_simd_lanes):
row_src_pack = indices[row_src] % packing
row_dst_pack = row_src % packing

rightshift_bits = row_src_pack * dtype_bits
leftshift_bits = row_dst_pack * dtype_bits

# Load data from vmem.
data = out_vmem_ref[row_src, col_slice]

            # Right-shift so the low dtype_bits bits hold the target data.
data = jnp.bitwise_right_shift(data, rightshift_bits)
# Mask out unwanted bits.
data = jnp.bitwise_and(data, 2**dtype_bits - 1)
# Left shift data into the target bit location.
data = jnp.bitwise_left_shift(data, leftshift_bits)

if row_dst_pack == 0:
out = data
else:
assert out is not None
out = jnp.bitwise_or(out, data)

if row_dst_pack == packing - 1:
              # Store the packed word at its destination row.
row_dst = row_src // packing
out_vmem_ref[row_dst, col_slice] = out

# Start dma write.
for row_vmem in range(num_simd_lanes // packing):
row_hbm = row_tile_start // packing + row_vmem
pltpu.make_async_copy(
out_vmem_ref.at[row_vmem, pl.ds(col_vmem_start, num_lanes)],
out_32b_hbm_ref.at[row_hbm, pl.ds(col_hbm_start, num_lanes)],
send_sem,
).start()

# Wait for dma write to finish.
for _ in range(0, col_size, num_lanes):
for _ in range(num_simd_lanes // packing):
pltpu.make_async_copy(
out_vmem_ref.at[0, :num_lanes],
out_32b_hbm_ref.at[0, :num_lanes],
send_sem,
).wait()

inner_kernel()


def calculate_col_size(hidden_size: int) -> int:
"""Calculate col size for ragged gather kernel."""
tpu_info = pltpu.get_tpu_info()
sc_info = tpu_info.sparse_core
assert sc_info is not None
num_lanes = tpu_info.num_lanes
num_simd_lanes = sc_info.num_lanes

match tpu_info.chip_version:
case 6:
target_bytes = (256 * 1024) * 0.9
case 7:
target_bytes = (512 * 1024) * 0.9
case _:
target_bytes = (128 * 1024) * 0.9
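  # The targets above leave ~10% headroom in the VMEM scratch budget, sized
  # per chip generation (v6, v7, older).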

base_bytes = num_simd_lanes * hidden_size * (32 // 8)
num_cols = 1

while pl.cdiv(base_bytes, num_cols * num_lanes) * num_lanes > target_bytes:
num_cols += 1
return pl.cdiv(hidden_size, (num_cols * num_lanes)) * num_lanes


@jax.jit
def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.Array) -> jax.Array:
"""Perform gather on indices within dynamic array start and end."""

assert x.ndim == 2, "Ragged gather only supports 2d inputs."
assert indices.ndim == 1, "Ragged gather only supports 1d indices."

  # Promote scalar or 0-d start/end to shape (1,) so they can be copied as
  # one-element arrays.
  start = jnp.asarray(start)
  end = jnp.asarray(end)
  if start.ndim == 0:
    start = start[None]
  if end.ndim == 0:
    end = end[None]

dtype = x.dtype

sc_info = pltpu.get_tpu_info().sparse_core
if sc_info is None:
# Sparse core is not available. Fallback to regular gather.
return x[indices]

hidden_size = x.shape[-1]
out_size = indices.size

num_simd_lanes = sc_info.num_lanes
num_cores = sc_info.num_cores * sc_info.num_subcores
block_size = num_simd_lanes * num_cores
col_size = calculate_col_size(hidden_size)

# Pad to align to the block size.
out_pad_size = pl.cdiv(out_size, block_size) * block_size - out_size
indices = jnp.pad(indices, ((0, out_pad_size)))

aligned_hidden_size = pl.cdiv(hidden_size, col_size) * col_size
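  # The kernel writes whole tiles, so both output dimensions are
  # over-allocated here and sliced back to (out_size, hidden_size) on return.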

vector_mesh = plsc.VectorSubcoreMesh(
num_cores=sc_info.num_cores,
num_subcores=sc_info.num_subcores,
core_axis_name="core",
subcore_axis_name="subcore",
)
return pl.kernel(
functools.partial(
main_kernel,
core_axis_name=vector_mesh.core_axis_name,
subcore_axis_name=vector_mesh.subcore_axis_name,
),
compiler_params=pltpu.CompilerParams(
use_tc_tiling_on_sc=True,
disable_bounds_checks=True,
),
mesh=vector_mesh,
name="sc_ragged_gather",
**{
_OUT_KW: jax.ShapeDtypeStruct((out_size + out_pad_size, aligned_hidden_size), dtype),
_SCRATCH_KW: [
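        # Positional order must match main_kernel's scratch parameters:
        # start_vmem_ref, end_vmem_ref, out_vmem_ref, indices_vmem_ref, sem_ref.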
pltpu.VMEM((num_simd_lanes,), jnp.int32),
pltpu.VMEM((num_simd_lanes,), jnp.int32),
pltpu.VMEM((num_simd_lanes, col_size), jnp.uint32),
pltpu.VMEM((num_simd_lanes,), jnp.int32),
pltpu.SemaphoreType.DMA((2,)),
],
},
)(start, end, x, indices)[:out_size, :hidden_size]
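For reference, a minimal usage sketch of the entry point; the shapes and index pattern are illustrative assumptions. On TPU generations without a SparseCore the function falls back to a plain `x[indices]` gather, as the code above shows:

```python
import jax.numpy as jnp

# 64 rows of hidden size 256; gather 16 of them. start/end restrict which
# positions of `indices` are materialized by the kernel.
x = jnp.arange(64 * 256, dtype=jnp.float32).reshape(64, 256)
indices = (jnp.arange(16, dtype=jnp.int32) * 3) % 64
start = jnp.array(0, dtype=jnp.int32)
end = jnp.array(16, dtype=jnp.int32)

out = ragged_gather(x, indices, start, end)
assert out.shape == (16, 256)
# With the full [0, 16) range active, this matches a dense gather.
assert jnp.array_equal(out, x[indices])
```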