Onboard ragged sorting kernels from tpu inference to maxtext #3814

Draft: NuojCheng wants to merge 8 commits into main from chengnuojin-ragged-gather-nightly
@NuojCheng (Collaborator) commented May 5, 2026

Description

This PR onboards two critical kernels: ragged_gather and ragged_gather_reduce.

The key difference lies in their routing and output shapes:

  • ragged_gather: Performs simultaneous permutation and fan-out (num_tokens x emb_dim → num_tokens x top_k x emb_dim).
  • ragged_gather_reduce: Adds an accumulation step for fan-in (num_tokens x top_k x emb_dim → num_tokens x emb_dim).
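
The shape contracts above can be illustrated with a dense NumPy reference. This is only a sketch of the semantics as described in this PR, not the kernels' actual signatures: the function names `dispatch_reference` and `combine_reference`, the `expert_ids` routing table, and the choice to flatten `num_tokens x top_k` into a single leading axis sorted by expert are all assumptions made for illustration.

```python
import numpy as np


def dispatch_reference(x, expert_ids):
    """Fan-out plus permutation (hypothetical reference, not the kernel API).

    x:          (num_tokens, emb_dim)
    expert_ids: (num_tokens, top_k) expert chosen for each (token, k) slot
    returns:    (num_tokens * top_k, emb_dim) rows grouped by expert, and the
                permutation that produced them
    """
    num_tokens, top_k = expert_ids.shape
    # Stable sort of the flattened (token, k) slots by expert id.
    order = np.argsort(expert_ids.reshape(-1), kind="stable")
    # Flattened slot s belongs to token s // top_k.
    src_token = order // top_k
    return x[src_token], order


def combine_reference(y_sorted, order, num_tokens, top_k):
    """Fan-in with accumulation (hypothetical reference, not the kernel API).

    y_sorted: (num_tokens * top_k, emb_dim) rows in expert-sorted order
    returns:  (num_tokens, emb_dim), each row the sum of that token's top_k rows
    """
    out = np.zeros((num_tokens, y_sorted.shape[1]), dtype=y_sorted.dtype)
    # Unbuffered scatter-add so repeated token indices accumulate correctly.
    np.add.at(out, order // top_k, y_sorted)
    return out


x = np.arange(8, dtype=np.float32).reshape(4, 2)
expert_ids = np.array([[1, 0], [0, 2], [2, 1], [0, 1]])
dispatched, order = dispatch_reference(x, expert_ids)
combined = combine_reference(dispatched, order, num_tokens=4, top_k=2)
```

With no expert computation and no router weights in between, combining the dispatched rows simply returns `top_k * x`, which makes a convenient sanity check for the index bookkeeping.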

During the forward pass, we use ragged_gather for dispatch and ragged_gather_reduce for combine. In the backward pass, these roles are swapped.
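
One way to see why the roles swap: a gather is a linear map, so its backward pass is its transpose, which is a scatter-add (a gather followed by a reduction over the fan-out axis). The NumPy sketch below checks that adjoint identity numerically. The `src_token` index vector and the `gather` / `gather_reduce` helpers are hypothetical stand-ins, and router weights are ignored.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, top_k, emb_dim = 4, 2, 3

# Hypothetical routing: src_token[i] is the token copied into sorted row i.
# Every token appears exactly top_k times.
src_token = rng.permutation(np.repeat(np.arange(num_tokens), top_k))


def gather(x):
    # (num_tokens, emb_dim) -> (num_tokens * top_k, emb_dim)
    return x[src_token]


def gather_reduce(y):
    # (num_tokens * top_k, emb_dim) -> (num_tokens, emb_dim)
    out = np.zeros((num_tokens, emb_dim))
    np.add.at(out, src_token, y)  # accumulate rows that share a token
    return out


x = rng.standard_normal((num_tokens, emb_dim))
y = rng.standard_normal((num_tokens * top_k, emb_dim))

# Adjoint identity: <gather(x), y> == <x, gather_reduce(y)>.
lhs = np.sum(gather(x) * y)
rhs = np.sum(x * gather_reduce(y))
```

Because the identity holds for arbitrary `x` and `y`, the cotangent of the dispatch step is produced by the reduce-style operation, and the cotangent of the combine step by the plain gather.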

FIXES: b/496676734

Tests

We observe a performance advantage when ragged sort is enabled, but we also notice a regression between JAX==0.10.0 and 0.10.1.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NuojCheng added the pull ready and draft (Draft PR) labels and removed the pull ready label May 5, 2026
codecov Bot commented May 5, 2026

@NuojCheng NuojCheng changed the title [Draft] Onboard ragged sorting kernels from tpu inference to maxtext May 11, 2026
@NuojCheng NuojCheng force-pushed the chengnuojin-ragged-gather-nightly branch from 803c1de to 4ac0bcf Compare May 11, 2026 18:25
@NuojCheng NuojCheng force-pushed the chengnuojin-ragged-gather-nightly branch from 4ac0bcf to b1ce84c Compare May 11, 2026 18:40