Skip to content

Commit 14dd40e

Browse files
Add Nomic BERT model (#440)
1 parent 07fd98e commit 14dd40e

File tree

4 files changed

+396
-0
lines changed

4 files changed

+396
-0
lines changed

lib/bumblebee.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ defmodule Bumblebee do
181181
"MPNetForTokenClassification" => {Bumblebee.Text.MpNet, :for_token_classification},
182182
"MPNetForQuestionAnswering" => {Bumblebee.Text.MpNet, :for_question_answering},
183183
"MPNetForMultipleChoice" => {Bumblebee.Text.MpNet, :for_multiple_choice},
184+
"NomicBertModel" => {Bumblebee.Text.NomicBert, :base},
184185
"PhiModel" => {Bumblebee.Text.Phi, :base},
185186
"PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling},
186187
"PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification},
@@ -272,6 +273,7 @@ defmodule Bumblebee do
272273
"mistral" => :llama,
273274
"mbart" => :mbart,
274275
"mpnet" => :mpnet,
276+
"nomic_bert" => :bert,
275277
"phi" => :code_gen,
276278
"phi3" => :llama,
277279
"qwen3" => :qwen2,

lib/bumblebee/text/nomic_bert.ex

Lines changed: 357 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,357 @@
1+
defmodule Bumblebee.Text.NomicBert do
  alias Bumblebee.Shared

  options =
    [
      vocab_size: [
        default: 30528,
        doc: """
        the vocabulary size of the token embedding. This corresponds to the number of distinct
        tokens that can be represented in model input and output
        """
      ],
      max_positions: [
        default: 8192,
        doc: """
        the maximum sequence length that this model can process. Typically this is set to a large
        value just in case, such as 512, 1024 or 2048
        """
      ],
      max_token_types: [
        default: 2,
        doc: """
        the vocabulary size of the token type embedding (also referred to as segment embedding).
        This corresponds to how many different token groups can be distinguished in the input
        """
      ],
      hidden_size: [
        default: 768,
        doc: "the dimensionality of hidden layers"
      ],
      num_blocks: [
        default: 12,
        doc: "the number of Transformer blocks in the encoder"
      ],
      num_attention_heads: [
        default: 12,
        doc: "the number of attention heads for each attention layer in the encoder"
      ],
      intermediate_size: [
        default: nil,
        doc: """
        the dimensionality of the intermediate layer in the transformer feed-forward network (FFN)
        in the encoder. When not set, defaults to `div(8 * hidden_size, 3)`, rounded up to the
        nearest multiple of 256
        """
      ],
      activation: [
        default: :silu,
        doc: "the activation function"
      ],
      rotary_embedding_base: [
        default: 1000,
        doc: "base for computing rotary embedding frequency"
      ],
      rotary_embedding_percentage: [
        default: 1.0,
        doc: "percentage of hidden size to use for rotary embeddings"
      ],
      layer_norm_epsilon: [
        default: 1.0e-5,
        doc: "the epsilon used by the layer normalization layers"
      ],
      initializer_scale: [
        default: 0.02,
        doc:
          "the standard deviation of the normal initializer used for initializing kernel parameters"
      ],
      ffn_gate_bias: [
        default: true,
        doc: "whether to use bias in the up and gate projections of the FFN"
      ],
      ffn_output_bias: [
        default: true,
        doc: "whether to use bias in the output projection of the FFN"
      ]
    ] ++ Shared.common_options([:num_labels, :id_to_label])

  @moduledoc """
  Nomic BERT model family.

  This is a variant of BERT that uses:
  - Rotary position embeddings (RoPE) instead of absolute position embeddings
  - SwiGLU activation in the feed-forward network
  - Post-normalization (like original BERT)
  - No biases in the attention projections (biases in the FFN are configurable
    via `:ffn_gate_bias` and `:ffn_output_bias` and enabled by default)

  ## Architectures

    * `:base` - plain Nomic BERT without any head on top

  ## Inputs

    * `"input_ids"` - `{batch_size, sequence_length}`

      Indices of input sequence tokens in the vocabulary.

    * `"attention_mask"` - `{batch_size, sequence_length}`

      Mask indicating which tokens to attend to. This is used to ignore
      padding tokens, which are added when processing a batch of sequences
      with different length.

    * `"token_type_ids"` - `{batch_size, sequence_length}`

      Mask distinguishing groups in the input sequence. This is used
      when the input consists of a pair of sequences.

    * `"position_ids"` - `{batch_size, sequence_length}`

      Indices of positions of each input sequence token used when applying
      rotary position embeddings (RoPE).

  ## Global layer options

  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}

  ## Configuration

  #{Shared.options_doc(options)}
  """

  defstruct [architecture: :base] ++ Shared.option_defaults(options)

  @behaviour Bumblebee.ModelSpec
  @behaviour Bumblebee.Configurable

  import Bumblebee.Utils.Model, only: [join: 2]

  alias Bumblebee.Layers

  @impl true
  def architectures(), do: [:base]

  @impl true
  def config(spec, opts) do
    spec
    |> Shared.put_config_attrs(opts)
    |> Shared.validate_label_options()
  end

  @impl true
  def input_template(_spec) do
    %{"input_ids" => Nx.template({1, 1}, :u32)}
  end

  @impl true
  def model(%__MODULE__{architecture: :base} = spec) do
    inputs = inputs(spec)

    inputs
    |> core(spec)
    |> Layers.output()
  end

  # Declares the model's Axon inputs. All inputs except "input_ids" are
  # optional and get defaults in core/2 (or are ignored when absent).
  defp inputs(spec) do
    shape = {nil, nil}
    attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads}

    Bumblebee.Utils.Model.inputs_to_map([
      Axon.input("input_ids", shape: shape),
      Axon.input("attention_mask", optional: true, shape: shape),
      Axon.input("token_type_ids", optional: true, shape: shape),
      Axon.input("position_ids", optional: true, shape: shape),
      Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape)
    ])
  end

  # Builds the base model graph: embeddings -> encoder (with RoPE) -> pooler.
  defp core(inputs, spec) do
    token_type_ids =
      Layers.default inputs["token_type_ids"] do
        Layers.default_token_type_ids(inputs["input_ids"])
      end

    embeddings = embedder(inputs["input_ids"], token_type_ids, spec, name: "embedder")

    # Position ids feed rotary embeddings inside the encoder blocks, not an
    # absolute position embedding table.
    position_ids =
      Layers.default inputs["position_ids"] do
        Layers.default_position_ids(embeddings)
      end

    encoder_outputs =
      encoder(
        embeddings,
        position_ids,
        inputs["attention_mask"],
        inputs["attention_head_mask"],
        spec,
        name: "encoder"
      )

    pooled_state = pooler(encoder_outputs.hidden_state, spec, name: "pooler")

    %{
      hidden_state: encoder_outputs.hidden_state,
      pooled_state: pooled_state,
      hidden_states: encoder_outputs.hidden_states,
      attentions: encoder_outputs.attentions
    }
  end

  # Token + token-type embeddings followed by layer norm. Note there is no
  # position embedding here - positions are encoded via RoPE in the encoder.
  defp embedder(input_ids, token_type_ids, spec, opts) do
    name = opts[:name]

    token_embeddings =
      Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size,
        kernel_initializer: kernel_initializer(spec),
        name: join(name, "token_embedding")
      )

    token_type_embeddings =
      Axon.embedding(token_type_ids, spec.max_token_types, spec.hidden_size,
        kernel_initializer: kernel_initializer(spec),
        name: join(name, "token_type_embedding")
      )

    Axon.add([token_embeddings, token_type_embeddings])
    |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "norm"))
  end

  # Stack of bidirectional transformer blocks with rotary position embeddings,
  # a gated (SwiGLU-style) FFN and no biases in the attention projections.
  defp encoder(hidden_state, position_ids, attention_mask, attention_head_mask, spec, opts) do
    name = opts[:name]

    Layers.Transformer.blocks(hidden_state,
      attention_mask: attention_mask,
      attention_head_mask: attention_head_mask,
      num_blocks: spec.num_blocks,
      num_attention_heads: spec.num_attention_heads,
      hidden_size: spec.hidden_size,
      kernel_initializer: kernel_initializer(spec),
      layer_norm: [epsilon: spec.layer_norm_epsilon],
      ffn:
        &gated_ffn(&1, intermediate_size(spec), spec.hidden_size,
          name: &2,
          activation: spec.activation,
          kernel_initializer: kernel_initializer(spec),
          gate_use_bias: spec.ffn_gate_bias,
          output_use_bias: spec.ffn_output_bias
        ),
      block_type: :standard,
      causal: false,
      rotary_embedding: [
        position_ids: position_ids,
        max_positions: spec.max_positions,
        base: spec.rotary_embedding_base,
        percentage: spec.rotary_embedding_percentage
      ],
      query_use_bias: false,
      key_use_bias: false,
      value_use_bias: false,
      output_use_bias: false,
      name: join(name, "blocks")
    )
  end

  # Standard BERT-style pooler: dense + tanh over the first ([CLS]) token.
  defp pooler(hidden_state, spec, opts) do
    name = opts[:name]

    hidden_state
    |> Layers.take_token(index: 0, axis: 1)
    |> Axon.dense(spec.hidden_size,
      kernel_initializer: kernel_initializer(spec),
      name: join(name, "output")
    )
    |> Axon.tanh()
  end

  # Nomic MLP: y = fc11(x) * activation(fc12(x)), then fc2.
  # fc11 is "up", fc12 is "gate", fc2 is "down". Note the activation is
  # applied to the gate branch, not the up branch.
  defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do
    name = opts[:name]
    activation = opts[:activation]
    kernel_initializer = opts[:kernel_initializer]
    gate_use_bias = opts[:gate_use_bias]
    output_use_bias = opts[:output_use_bias]

    up =
      Axon.dense(hidden_state, intermediate_size,
        name: join(name, "up"),
        kernel_initializer: kernel_initializer,
        use_bias: gate_use_bias
      )

    gate =
      Axon.dense(hidden_state, intermediate_size,
        name: join(name, "gate"),
        kernel_initializer: kernel_initializer,
        use_bias: gate_use_bias
      )

    hidden_state = Axon.multiply(up, Axon.activation(gate, activation))

    Axon.dense(hidden_state, output_size,
      name: join(name, "down"),
      kernel_initializer: kernel_initializer,
      use_bias: output_use_bias
    )
  end

  defp kernel_initializer(spec) do
    Axon.Initializers.normal(scale: spec.initializer_scale)
  end

  # NomicBERT defaults the FFN intermediate size to div(8 * hidden_size, 3)
  # and rounds it up to the nearest multiple of 256 for hardware efficiency.
  defp intermediate_size(spec) do
    size = spec.intermediate_size || div(8 * spec.hidden_size, 3)
    multiple_of = 256
    div(size + multiple_of - 1, multiple_of) * multiple_of
  end

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    def load(spec, data) do
      import Shared.Converters

      # NomicBertConfig uses GPT2-style attribute names (n_embd, n_layer, ...).
      opts =
        convert!(data,
          vocab_size: {"vocab_size", number()},
          max_positions: {"n_positions", number()},
          max_token_types: {"type_vocab_size", number()},
          hidden_size: {"n_embd", number()},
          num_blocks: {"n_layer", number()},
          num_attention_heads: {"n_head", number()},
          intermediate_size: {"n_inner", optional(number())},
          rotary_embedding_base: {"rotary_emb_base", number()},
          rotary_embedding_percentage: {"rotary_emb_fraction", optional(number())},
          layer_norm_epsilon: {"layer_norm_epsilon", number()},
          initializer_scale: {"initializer_range", number()},
          ffn_gate_bias: {"mlp_fc1_bias", boolean()},
          ffn_output_bias: {"mlp_fc2_bias", boolean()}
        ) ++ Shared.common_options_from_transformers(data, spec)

      @for.config(spec, opts)
    end
  end

  defimpl Bumblebee.HuggingFace.Transformers.Model do
    def params_mapping(_spec) do
      %{
        "embedder.token_embedding" => "embeddings.word_embeddings",
        "embedder.token_type_embedding" => "embeddings.token_type_embeddings",
        "embedder.norm" => "emb_ln",
        "encoder.blocks.{n}.self_attention.query" => qkv_dense("encoder.layers.{n}.attn.Wqkv", 0),
        "encoder.blocks.{n}.self_attention.key" => qkv_dense("encoder.layers.{n}.attn.Wqkv", 1),
        "encoder.blocks.{n}.self_attention.value" => qkv_dense("encoder.layers.{n}.attn.Wqkv", 2),
        "encoder.blocks.{n}.self_attention.output" => "encoder.layers.{n}.attn.out_proj",
        "encoder.blocks.{n}.self_attention_norm" => "encoder.layers.{n}.norm1",
        "encoder.blocks.{n}.ffn.up" => "encoder.layers.{n}.mlp.fc11",
        "encoder.blocks.{n}.ffn.gate" => "encoder.layers.{n}.mlp.fc12",
        "encoder.blocks.{n}.ffn.down" => "encoder.layers.{n}.mlp.fc2",
        "encoder.blocks.{n}.output_norm" => "encoder.layers.{n}.norm2",
        "pooler.output" => "pooler.dense"
      }
    end

    # Builds a param mapping that slices one of the fused Q/K/V projections
    # out of the combined Wqkv weight.
    #
    # Wqkv is [3 * hidden_size, hidden_size] in PyTorch format; after slicing
    # chunk `chunk_idx` along axis 0 we transpose to the [in, out] layout that
    # Axon dense kernels expect.
    defp qkv_dense(source_layer_name, chunk_idx) do
      %{
        "kernel" => {
          [{source_layer_name, "weight"}],
          fn [kernel] ->
            size = Nx.axis_size(kernel, 0)
            step = div(size, 3)

            kernel
            |> Nx.slice_along_axis(chunk_idx * step, step, axis: 0)
            |> Nx.transpose()
          end
        }
      }
    end
  end
end

mix.exs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ defmodule Bumblebee.MixProject do
102102
Bumblebee.Text.Mbart,
103103
Bumblebee.Text.Mistral,
104104
Bumblebee.Text.MpNet,
105+
Bumblebee.Text.NomicBert,
105106
Bumblebee.Text.Phi,
106107
Bumblebee.Text.Phi3,
107108
Bumblebee.Text.Roberta,

0 commit comments

Comments
 (0)