
Add unified Cache-layer management for GLM-5 DSA Indexer keys #45595

Open

louzongzhi wants to merge 1 commit into huggingface:main from louzongzhi:glm5

Conversation

@louzongzhi
Contributor

What does this PR do?

This PR migrates the GLM-5 DSA Indexer's IndexCache key cache from a self-managed register_buffer (introduced in #45424) into the standard per-layer past_key_values (Cache) infrastructure. Indexer keys now share the same lifecycle as the attention KV caches, giving transparent support for beam-search reordering, cache cropping, batch selection/repetition, and offloading.

Background & Motivation

GLM-5 integrates the DeepSeek Sparse Attention (DSA) Indexer. In #45424, we added support for the IndexCache mechanism (THUDM/IndexCache, arXiv:2603.12201) to accelerate inference by caching indexer keys and allowing Shared (S) layers to reuse indices from Full (F) layers.

However, the implementation in #45424 managed these cached keys inside GlmMoeDsaIndexer via self.register_buffer("_cached_keys", None, persistent=False). While this works for basic generation, it creates architectural friction when building downstream models on top of GLM-5:

  • Lifecycle drift: Cache.reset, reorder_cache, crop, batch_repeat_interleave, and batch_select_indices do not propagate to the Indexer's isolated buffer.
  • Offloading blind spot: The indexer keys are invisible to the Cache offloading / prefetching system.
  • Duplicated logic: Prefill-vs-decode concatenation is manually implemented inside the Indexer, duplicating what DynamicCache already does.

To simplify our own downstream model code and benefit the broader GLM-5 ecosystem, we are upstreaming this unification first.

Changes

src/transformers/cache_utils.py

  • CacheLayerMixin: Added indexer_keys attribute and abstract method update_cached_keys().
  • DynamicLayer: Implemented update_cached_keys() by concatenating along the sequence dimension (dim=1); synchronized crop, batch_repeat_interleave, batch_select_indices, reset, and reorder_cache to also operate on indexer_keys.
  • StaticLayer: Added passthrough update_cached_keys() returning the input as-is.
  • Cache: Added update_cached_keys(cached_keys, layer_idx) and reset_cached_keys(layer_idx) to dispatch per layer, mirroring the existing update() / reset() API.
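To make the layer-level changes above concrete, here is a hypothetical sketch of the per-layer API. The names (`indexer_keys`, `update_cached_keys`, `crop`, `reset`) follow the PR description, but `DynamicLayerSketch` and its bodies are illustrative stand-ins, not the actual diff:

```python
# Illustrative sketch only; not the actual transformers DynamicLayer code.
import torch


class DynamicLayerSketch:
    """Minimal stand-in for DynamicLayer with indexer-key support."""

    def __init__(self):
        self.indexer_keys = None  # new attribute introduced by this PR

    def update_cached_keys(self, cached_keys: torch.Tensor) -> torch.Tensor:
        # Concatenate along the sequence dimension (dim=1), mirroring how
        # DynamicLayer grows the attention KV cache during decoding.
        if self.indexer_keys is None:
            self.indexer_keys = cached_keys
        else:
            self.indexer_keys = torch.cat([self.indexer_keys, cached_keys], dim=1)
        return self.indexer_keys

    def crop(self, max_length: int) -> None:
        # Kept in sync with the KV crop so indexer keys never outlive the KV cache.
        if self.indexer_keys is not None:
            self.indexer_keys = self.indexer_keys[:, :max_length]

    def reset(self) -> None:
        self.indexer_keys = None


layer = DynamicLayerSketch()
layer.update_cached_keys(torch.zeros(2, 4, 8))  # prefill: 4 positions
layer.update_cached_keys(torch.zeros(2, 1, 8))  # decode: +1 position
assert layer.indexer_keys.shape == (2, 5, 8)
layer.crop(3)
assert layer.indexer_keys.shape == (2, 3, 8)
```

The key design point is that every lifecycle operation the KV cache supports gets a mirrored operation on `indexer_keys`, so downstream code never has to special-case the Indexer.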

src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py

  • GlmMoeDsaIndexer:
    • Removed self.register_buffer("_cached_keys", None, persistent=False).
    • forward() now accepts past_key_values: Cache | None instead of use_cache: bool.
    • Prefill (seq_len > 1) triggers past_key_values.reset_cached_keys(layer_idx) before computing scores.
    • Decode (seq_len == 1) appends keys via past_key_values.update_cached_keys(k, layer_idx).
  • GlmMoeDsaAttention: Updated self.indexer(...) call to pass past_key_values=past_key_values directly.
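The prefill/decode dispatch described above can be sketched as follows. The two method names come from the PR description; `FakeCache` and `gather_indexer_keys` are hypothetical stand-ins, not the actual transformers classes:

```python
# Illustrative sketch only; not the actual GlmMoeDsaIndexer code.
import torch


class FakeCache:
    """Minimal stand-in for a Cache exposing the two new per-layer methods."""

    def __init__(self):
        self._keys = {}

    def reset_cached_keys(self, layer_idx: int) -> None:
        self._keys.pop(layer_idx, None)

    def update_cached_keys(self, k: torch.Tensor, layer_idx: int) -> torch.Tensor:
        if layer_idx in self._keys:
            self._keys[layer_idx] = torch.cat([self._keys[layer_idx], k], dim=1)
        else:
            self._keys[layer_idx] = k
        return self._keys[layer_idx]


def gather_indexer_keys(k, past_key_values, layer_idx):
    if past_key_values is None:
        return k  # no-cache fallback: score against the fresh keys only
    if k.shape[1] > 1:
        # Prefill: drop any stale keys before seeding the cache with this chunk.
        past_key_values.reset_cached_keys(layer_idx)
    # Decode (and prefill right after reset): append and read back the history.
    return past_key_values.update_cached_keys(k, layer_idx)


cache = FakeCache()
assert gather_indexer_keys(torch.zeros(1, 4, 8), cache, 0).shape == (1, 4, 8)  # prefill
assert gather_indexer_keys(torch.zeros(1, 1, 8), cache, 0).shape == (1, 5, 8)  # decode
```

Note that after a prefill reset, `update_cached_keys` simply seeds the cache with the full chunk, so one code path covers both phases.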

Behavior equivalence

| #45424 (self-managed IndexCache) | This PR (unified Cache layer) |
| --- | --- |
| `if seq_len > 1: self._cached_keys = None` | `if seq_len > 1: past_key_values.reset_cached_keys(layer_idx)` |
| `if use_cache: cat([self._cached_keys, k])` | `if past_key_values is not None: past_key_values.update_cached_keys(k, layer_idx)` |
| `else: k_cached = k` | `else: k_cached = k` |

Local verification

Verified locally with a tiny hidden_size=256 config:

  • Prefill + decode cache growth
  • Multi-layer isolation
  • Beam-search reorder_cache
  • batch_select_indices / batch_repeat_interleave
  • crop truncation
  • Shared indexer (skip_topk) reuse path
  • No-cache fallback (past_key_values=None)
  • Cache.reset() clears indexer_keys
  • End-to-end multi-turn chat: indexer_keys accumulate correctly across turns
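As an illustration of the beam-search check above, reordering must apply the same beam indices to `indexer_keys` as to the K/V tensors. A minimal sketch (the helper name is made up; the real `reorder_cache` operates on the layer's stored tensors):

```python
# Illustrative sketch of beam reordering applied to indexer keys.
import torch


def reorder_indexer_keys(indexer_keys: torch.Tensor, beam_idx: torch.Tensor) -> torch.Tensor:
    # Select beams along the batch dimension (dim=0), exactly as
    # reorder_cache does for the attention K/V tensors.
    return indexer_keys.index_select(0, beam_idx)


keys = torch.arange(6.0).reshape(3, 2, 1)  # 3 beams, 2 cached positions
reordered = reorder_indexer_keys(keys, torch.tensor([2, 2, 0]))
assert torch.equal(reordered[0], keys[2])
assert torch.equal(reordered[1], keys[2])
assert torch.equal(reordered[2], keys[0])
```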

Backward compatibility

  • No public API changes; generate() behavior is unchanged.
  • The removed _cached_keys buffer was never part of saved state (persistent=False), so existing checkpoints remain fully compatible.

Follow-up context

We are actively building a new LLM architecture on top of the GLM-5 backbone, leveraging the DSA Indexer and IndexCache design. Unifying the Indexer cache into the standard Cache layer is prerequisite infrastructure work for open-sourcing that model. We will share more details once training converges. Stay tuned!


Fixes # (N/A)

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of who to tag.
Please tag fewer than 3 people.

Models:

Library:

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: glm_moe_dsa

@Rocketknight1
Member

cc @ArthurZucker @vasqu who reviewed the last issue! #45424

@louzongzhi
Contributor Author

> cc @ArthurZucker @vasqu who reviewed the last issue! #45424

Hi @Rocketknight1 @ArthurZucker @vasqu, thanks for taking a look!

I'd like to clarify the background of IndexCache and the context of this PR.

IndexCache is a mechanism explicitly described in the GLM-5 papers and official code. It allows Shared (S) layers to reuse top-k indices from Full (F) layers to accelerate inference. You can find the details in arXiv:2602.15763 and arXiv:2603.12201, and the official implementation is at THUDM/IndexCache.

The GLM-5 implementation in transformers did not include IndexCache until #45424. Without it, Shared layers were unable to reuse indices, which deviated from the official behavior described in the papers. In #45424, I added IndexCache support following the official implementation, so the functional behavior now aligns with the official repo.

A side note on self._cached_keys: The change from a plain Python attribute (self._cached_keys = None) to register_buffer(..., persistent=False) in #45424 was an incidental fix, not part of the IndexCache feature itself. The original plain attribute had clear infrastructure issues: it doesn't sync with the model device (.to(device) / .cuda() won't move it), it can lead to inconsistent states across ranks in distributed training (DDP/FSDP), and it isn't tracked by state_dict. The register_buffer change simply fixed those issues.
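The device-sync and state_dict points above can be demonstrated with a small torch sketch contrasting a plain tensor attribute with a non-persistent buffer (the class names here are made up for illustration):

```python
# Illustrative contrast: plain attribute vs register_buffer(persistent=False).
import torch
from torch import nn


class WithPlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute: invisible to nn.Module machinery (.to, state_dict, DDP).
        self._cached_keys = torch.zeros(1, 2)


class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffer: follows module moves, excluded from checkpoints.
        self.register_buffer("_cached_keys", None, persistent=False)


plain = WithPlainAttr().to(torch.float64)
assert plain._cached_keys.dtype == torch.float32  # .to() did not convert it

buf = WithBuffer()
buf._cached_keys = torch.zeros(1, 2)              # assignment updates the buffer
buf = buf.to(torch.float64)
assert buf._cached_keys.dtype == torch.float64    # buffers follow module-wide moves
assert "_cached_keys" not in buf.state_dict()     # persistent=False: never saved
```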

This PR (#45595) does not change the functional logic of IndexCache. That logic remains the same as in the THUDM official implementation (i.e., the behavior before my #45424 submission). What this PR does is migrate the cache management from the internal register_buffer in GlmMoeDsaIndexer into the standard past_key_values (DynamicCache/StaticCache) lifecycle. This lets indexer keys automatically participate in Cache.reset, reorder_cache, crop, batch_repeat_interleave, and offloading, rather than requiring downstream models to handle these operations manually.

If anyone has questions about the IndexCache behavior or the migration approach, I'm happy to explain further. Thanks!
