Skip to content

Commit 4cf82e0

Browse files
committed
Add Gemma 4 model family support
Implement the text-only portion of the Google DeepMind Gemma 4 architecture:

- Hybrid attention: alternating sliding window and full attention layers
- Dual RoPE: proportional RoPE for full attention, default for sliding
- Per-Layer Embeddings (PLE): per-layer token-dependent gating
- KV sharing: later layers reuse KV from earlier layers of the same type
- Q/K/V normalization: RMS normalization on query, key, and value
- Per-layer scalar: learned scaling factor per transformer block
- Optional MoE: mixture-of-experts FFN blocks (26B-A4B variant)

Architectures: :base (Gemma4TextModel), :for_causal_language_modeling (Gemma4ForCausalLM). Multimodal Gemma4ForConditionalGeneration is not yet supported.

Uses a custom decoder loop rather than Layers.Transformer.blocks/2 because the model requires features not available in the shared infrastructure: per-layer embeddings threaded through the block loop, cross-block KV sharing state, per-layer head dimension variation, and value normalization.

Includes an integration test verified against Python transformers reference values (atol < 5e-5).
1 parent 0b397f6 commit 4cf82e0

File tree

4 files changed

+1507
-0
lines changed

4 files changed

+1507
-0
lines changed

generate_tiny_gemma4_checkpoint.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
#!/usr/bin/env python3
"""Generate tiny-random Gemma4 checkpoints for Bumblebee integration tests.

Creates:
    /tmp/bumblebee-testing/tiny-random-Gemma4ForCausalLM/

Then prints Python reference values for the Elixir test.

Usage:
    python3 generate_tiny_gemma4_checkpoint.py
"""

import os

import torch
from transformers import Gemma4ForCausalLM, Gemma4TextConfig

# Where the tiny checkpoint is written; the Elixir integration test loads it
# from this path.
OUT_DIR = "/tmp/bumblebee-testing/tiny-random-Gemma4ForCausalLM"


def _build_config():
    """Return a tiny Gemma4 text config matching the Elixir test structure."""
    return Gemma4TextConfig(
        vocab_size=1024,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=8,
        global_head_dim=16,
        intermediate_size=64,
        hidden_activation="gelu_pytorch_tanh",
        max_position_embeddings=128,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        pad_token_id=0,
        sliding_window=32,
        enable_moe_block=False,
        hidden_size_per_layer_input=0,
        num_kv_shared_layers=0,
        attention_k_eq_v=False,
        tie_word_embeddings=True,
        final_logit_softcapping=30.0,
        num_experts=None,
        top_k_experts=None,
        # One sliding + one full attention layer exercises both RoPE configs.
        layer_types=["sliding_attention", "full_attention"],
        rope_parameters={
            "sliding_attention": {"rope_theta": 10000.0, "rope_type": "default"},
            "full_attention": {
                "rope_theta": 1000000.0,
                "rope_type": "proportional",
                "partial_rotary_factor": 0.25,
            },
        },
    )


def _create_and_save_model(config):
    """Create a randomly initialized model, save it to OUT_DIR, return it."""
    model = Gemma4ForCausalLM(config)
    model.eval()
    print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")

    os.makedirs(OUT_DIR, exist_ok=True)
    model.save_pretrained(OUT_DIR)
    print(f"Saved to {OUT_DIR}")
    print(f"Files: {os.listdir(OUT_DIR)}")
    return model


def _print_reference_values(model):
    """Run a fixed input through the model and print Elixir test snippets."""
    inputs = {
        "input_ids": torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
        "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    with torch.no_grad():
        logits = model(**inputs).logits

    print(f"\nlogits shape: {logits.shape}")
    print("\nReference values for Elixir test:")
    print("outputs.logits[[.., 1..3, 1..3]]:")
    # Python slice 1:4 corresponds to the Elixir range 1..3 (both select
    # indices 1, 2, 3 inclusive).
    ref = logits[:, 1:4, 1:4]
    print(ref)

    print("\n--- Copy this into gemma4_test.exs ---")
    print("assert Nx.shape(outputs.logits) == {1, 10, 1024}")
    print()
    print("assert_all_close(")
    print("  outputs.logits[[.., 1..3, 1..3]],")
    print("  Nx.tensor([")
    rows = [
        "    [{}]".format(", ".join(f"{ref[0, i, j].item():.4f}" for j in range(3)))
        for i in range(3)
    ]
    print(",\n".join(rows))
    print("  ])")
    print(")")


def main():
    # Seed the RNG so the saved checkpoint and the printed reference values
    # are reproducible across runs; without this every invocation produces a
    # different checkpoint, and previously printed reference values would no
    # longer match the saved weights.
    torch.manual_seed(0)

    config = _build_config()
    print(f"Config model_type: {config.model_type}")

    model = _create_and_save_model(config)
    _print_reference_values(model)


if __name__ == "__main__":
    main()

lib/bumblebee.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ defmodule Bumblebee do
136136
"DistilBertForMultipleChoice" => {Bumblebee.Text.Distilbert, :for_multiple_choice},
137137
"GemmaModel" => {Bumblebee.Text.Gemma, :base},
138138
"GemmaForCausalLM" => {Bumblebee.Text.Gemma, :for_causal_language_modeling},
139+
"Gemma4ForCausalLM" => {Bumblebee.Text.Gemma4, :for_causal_language_modeling},
140+
"Gemma4TextModel" => {Bumblebee.Text.Gemma4, :base},
139141
"GemmaForSequenceClassification" => {Bumblebee.Text.Gemma, :for_sequence_classification},
140142
"Gemma3ForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling},
141143
"Gemma3TextModel" => {Bumblebee.Text.Gemma3Text, :base},
@@ -273,6 +275,8 @@ defmodule Bumblebee do
273275
"clip" => :clip,
274276
"gemma" => :gemma,
275277
"gemma3_text" => :gemma,
278+
"gemma4" => :gemma,
279+
"gemma4_text" => :gemma,
276280
"gpt_neox" => :gpt_neo_x,
277281
"gpt2" => :gpt2,
278282
"gpt_bigcode" => :gpt2,

0 commit comments

Comments
 (0)