Add Gemma 4 model family support#452

Closed

chgeuer wants to merge 1 commit into elixir-nx:main from chgeuer:gemma4-support

Conversation


@chgeuer chgeuer commented Apr 6, 2026

Adds Bumblebee.Text.Gemma4 implementing the text-only portion of the Gemma 4 architecture from Google DeepMind, supporting both the E4B (4.5B dense) and the 26B-A4B (Mixture-of-Experts) variants.

Architecture features

  • Hybrid attention: alternating sliding window and full attention layers (5:1 ratio)
  • Dual RoPE: per-layer-type rotary embeddings (theta=10K default for sliding, theta=1M proportional for full)
  • Grouped Query Attention with per-layer head dimensions (attention_head_size / global_attention_head_size)
  • Q/K/V RMS normalization
  • Gated FFN (SwiGLU with gelu_approx_tanh)
  • Per-Layer Embeddings (PLE) with gating and projection
  • KV sharing: later layers reuse K/V from earlier layers of the same attention type
  • Mixture-of-Experts: combined router + expert dispatch (26B-A4B variant)
  • Per-layer scalar: learned scaling factor per transformer block
  • Logit softcapping
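Logit softcapping squashes logits smoothly into a bounded range with a scaled tanh, as in earlier Gemma releases. A minimal sketch (the exact placement inside Gemma 4 is an assumption here; the cap value matches the `final_logit_softcapping=30.0` setting in the test config below):

```python
import numpy as np

def softcap(logits, cap=30.0):
    """Smoothly bound logits to (-cap, cap) via tanh.

    Near zero this is approximately the identity; large-magnitude
    logits saturate toward +/- cap instead of growing unboundedly.
    """
    return cap * np.tanh(logits / cap)

x = np.array([-100.0, 0.0, 5.0, 100.0])
print(softcap(x))
```

Note that for small logits `softcap` is nearly a no-op (`softcap(5.0)` is about 4.95), so the transform only meaningfully affects outliers.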

Architectures

  • :base (Gemma4TextModel)
  • :for_causal_language_modeling (Gemma4ForCausalLM)

The multimodal Gemma4ForConditionalGeneration is not yet supported.

Custom decoder loop

Uses a custom decoder loop rather than Layers.Transformer.blocks/2 because the model requires features not available in the shared infrastructure: per-layer embeddings (PLE) threaded through the block loop, cross-block KV sharing state, per-layer head dimension variation, per-layer scalar, and value normalization.
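The cross-block KV-sharing state is the main reason the shared block loop doesn't fit. A schematic Python sketch of the sharing rule as described above (the indexing convention and the 5:1 layer pattern are assumptions drawn from this PR description, not Bumblebee's actual implementation):

```python
def kv_sources(layer_types, num_shared):
    """For each layer, return the index of the layer whose K/V it uses.

    The last `num_shared` layers reuse K/V from the most recent earlier
    layer of the same attention type ("sliding" vs. "full"); all other
    layers compute and record their own.
    """
    n = len(layer_types)
    last_own = {}   # attention type -> index of last layer that computed K/V
    sources = []
    for i, kind in enumerate(layer_types):
        if i >= n - num_shared and kind in last_own:
            sources.append(last_own[kind])   # reuse earlier layer's K/V
        else:
            last_own[kind] = i               # compute own K/V
            sources.append(i)
    return sources

# 5:1 sliding:full pattern over 12 layers, last 4 layers share K/V:
types = (["sliding"] * 5 + ["full"]) * 2
print(kv_sources(types, 4))
# -> [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 5]
```

This also shows why the state must thread through the whole loop: a sharing layer's source depends on which earlier layers, of which attention type, already ran.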

Registry entries

  • Gemma4ForCausalLM → {Bumblebee.Text.Gemma4, :for_causal_language_modeling}
  • Gemma4TextModel → {Bumblebee.Text.Gemma4, :base}
  • gemma4 / gemma4_text → :gemma tokenizer type

Testing

Integration test verified against Python transformers reference values (atol < 5e-5) using a tiny-random checkpoint.

Note for maintainers: The integration test references {:hf, "bumblebee-testing/tiny-random-Gemma4ForCausalLM"}. The checkpoint can be generated using this script:

from transformers import Gemma4TextConfig, Gemma4ForCausalLM

config = Gemma4TextConfig(
    vocab_size=1024, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=2,
    head_dim=8, global_head_dim=16, intermediate_size=64,
    hidden_activation="gelu_pytorch_tanh", max_position_embeddings=128,
    initializer_range=0.02, rms_norm_eps=1e-6, pad_token_id=0,
    sliding_window=32, enable_moe_block=False,
    hidden_size_per_layer_input=0, num_kv_shared_layers=0,
    attention_k_eq_v=False, tie_word_embeddings=True,
    final_logit_softcapping=30.0,
    layer_types=["sliding_attention", "full_attention"],
    rope_parameters={
        "sliding_attention": {"rope_theta": 10000.0, "rope_type": "default"},
        "full_attention": {"rope_theta": 1000000.0, "rope_type": "proportional",
                          "partial_rotary_factor": 0.25},
    },
)

model = Gemma4ForCausalLM(config)
model.save_pretrained("bumblebee-testing/tiny-random-Gemma4ForCausalLM")

Unit tests cover config loading (E4B + 26B MoE), forward pass (sliding/full attention, partial rotary, MoE, masking, softcapping).
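The `atol < 5e-5` criterion amounts to an elementwise absolute-tolerance comparison of the Bumblebee logits against the transformers reference. A sketch of an equivalent check (the helper name and the sample values are illustrative, not from the test suite):

```python
import numpy as np

def close_to_reference(actual, reference, atol=5e-5):
    # Pure absolute-tolerance check (rtol=0), mirroring an
    # "every element within atol of the reference" criterion.
    return np.allclose(actual, reference, rtol=0.0, atol=atol)

ref = np.array([0.123456, -0.654321])
print(close_to_reference(ref + 4e-5, ref))  # True: within tolerance
print(close_to_reference(ref + 1e-3, ref))  # False: exceeds atol
```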

@chgeuer chgeuer force-pushed the gemma4-support branch 2 times, most recently from 4cf82e0 to 9a4cdd1 on April 6, 2026 at 12:51
Implement the text-only portion of Google DeepMind Gemma 4 architecture:

- Hybrid attention: alternating sliding window and full attention layers
- Dual RoPE: proportional RoPE for full attention, default for sliding
- Per-Layer Embeddings (PLE): per-layer token-dependent gating
- KV sharing: later layers reuse KV from earlier layers of same type
- Q/K/V normalization: RMS normalization on query, key, and value
- Per-layer scalar: learned scaling factor per transformer block
- Optional MoE: mixture-of-experts FFN blocks (26B-A4B variant)

Architectures: :base (Gemma4TextModel), :for_causal_language_modeling
(Gemma4ForCausalLM). Multimodal Gemma4ForConditionalGeneration is not
yet supported.

Uses a custom decoder loop rather than Layers.Transformer.blocks/2
because the model requires features not available in the shared
infrastructure: per-layer embeddings threaded through the block loop,
cross-block KV sharing state, per-layer head dimension variation,
and value normalization.

Includes integration test verified against Python transformers reference
values (atol < 5e-5).
@chgeuer chgeuer closed this Apr 6, 2026