Skip to content

Commit db6e43b

Browse files
committed
Add Gemma 4 model family support
Implement the text-only portion of Google DeepMind Gemma 4 architecture: - Hybrid attention: alternating sliding window and full attention layers - Dual RoPE: proportional RoPE for full attention, default for sliding - Per-Layer Embeddings (PLE): per-layer token-dependent gating - KV sharing: later layers reuse KV from earlier layers of same type - Q/K/V normalization: RMS normalization on query, key, and value - Per-layer scalar: learned scaling factor per transformer block - Optional MoE: mixture-of-experts FFN blocks (26B-A4B variant) Architectures: :base (Gemma4TextModel), :for_causal_language_modeling (Gemma4ForCausalLM). Multimodal Gemma4ForConditionalGeneration is not yet supported. Uses a custom decoder loop rather than Layers.Transformer.blocks/2 because the model requires features not available in the shared infrastructure: per-layer embeddings threaded through the block loop, cross-block KV sharing state, per-layer head dimension variation, and value normalization. Includes integration test verified against Python transformers reference values (atol < 5e-5).
1 parent 0b397f6 commit db6e43b

File tree

3 files changed

+1413
-0
lines changed

3 files changed

+1413
-0
lines changed

lib/bumblebee.ex

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ defmodule Bumblebee do
136136
"DistilBertForMultipleChoice" => {Bumblebee.Text.Distilbert, :for_multiple_choice},
137137
"GemmaModel" => {Bumblebee.Text.Gemma, :base},
138138
"GemmaForCausalLM" => {Bumblebee.Text.Gemma, :for_causal_language_modeling},
139+
"Gemma4ForCausalLM" => {Bumblebee.Text.Gemma4, :for_causal_language_modeling},
140+
"Gemma4ForConditionalGeneration" => {Bumblebee.Text.Gemma4, :for_causal_language_modeling},
141+
"Gemma4TextModel" => {Bumblebee.Text.Gemma4, :base},
139142
"GemmaForSequenceClassification" => {Bumblebee.Text.Gemma, :for_sequence_classification},
140143
"Gemma3ForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling},
141144
"Gemma3TextModel" => {Bumblebee.Text.Gemma3Text, :base},
@@ -273,6 +276,8 @@ defmodule Bumblebee do
273276
"clip" => :clip,
274277
"gemma" => :gemma,
275278
"gemma3_text" => :gemma,
279+
"gemma4" => :gemma,
280+
"gemma4_text" => :gemma,
276281
"gpt_neox" => :gpt_neo_x,
277282
"gpt2" => :gpt2,
278283
"gpt_bigcode" => :gpt2,

0 commit comments

Comments
 (0)