[nemotron_h] respect _no_reinit flag on dt_bias and out_proj.weight#45591
Open
vai-minzhou wants to merge 1 commit into huggingface:main
Conversation
`_init_weights()` on `NemotronHPreTrainedModel` unconditionally overwrites
`dt_bias` (random `inv_softplus(dt)`) and `out_proj.weight` (kaiming-uniform
scaled by `1/sqrt(n_layer)`) every time it is invoked on a Mamba block.
It sets `module.dt_bias._no_reinit = True` after the copy, but the flag is
never checked by either code path (only the Linear-bias branch reads it).
On transformers>=5.0, `_init_weights` is triggered a second time after
`from_pretrained()` has loaded the checkpoint (the post-load safety pass
that initializes tensors staying on `meta`). For `NemotronHForCausalLM`
that silently overwrites the checkpoint values for `dt_bias` and
`out_proj.weight` with fresh random draws. The model then outputs
repetitive stop-word streams like ` and and and and ,` for any input.
Minimal repro with any Nemotron-H checkpoint:

```python
from transformers import AutoConfig, AutoModelForCausalLM
from safetensors.torch import load_file
import json, pathlib

path = ".../NVIDIA-Nemotron-Cascade-2-30B-A3B-BF16"  # or Nano
cfg = AutoConfig.from_pretrained(path)
cfg._attn_implementation = 'eager'
m = AutoModelForCausalLM.from_pretrained(path, config=cfg, torch_dtype='bfloat16')

idx = json.loads((pathlib.Path(path) / 'model.safetensors.index.json').read_text())['weight_map']
k = 'backbone.layers.0.mixer.dt_bias'
on_disk = load_file(f'{path}/{idx[k]}')[k]
in_mem = m.backbone.layers[0].mixer.dt_bias
print((on_disk.float() - in_mem.float().cpu()).abs().max())  # ~26.8
```
This patch makes `_init_weights` honour `_no_reinit` on both `dt_bias` and
`out_proj.weight` (the only two params that re-init unconditionally), and
sets `_no_reinit = True` on `out_proj.weight` after the initial kaiming
scale so a second pass is a no-op. Ordinary fresh-init training is
unaffected; only the second invocation becomes idempotent.
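For reference, the guard described above can be sketched as follows. This is a minimal, hedged sketch, not the exact Hugging Face implementation: the function name `init_mixer_weights`, the `time_step_min`/`time_step_max` range, and the toy module attributes are illustrative assumptions; only the `_no_reinit` flag convention comes from the PR itself.

```python
import math
import torch
import torch.nn as nn

def init_mixer_weights(module: nn.Module, num_hidden_layers: int,
                       time_step_min: float = 0.001,
                       time_step_max: float = 0.1) -> None:
    # Guarded re-init: both branches check `_no_reinit`, so a second
    # invocation (e.g. the post-load pass on transformers>=5.0) is a no-op.
    if hasattr(module, "dt_bias") and not getattr(module.dt_bias, "_no_reinit", False):
        # Draw dt log-uniformly, then store inv_softplus(dt) so that
        # softplus(dt_bias) recovers dt at runtime.
        dt = torch.exp(
            torch.rand_like(module.dt_bias)
            * (math.log(time_step_max) - math.log(time_step_min))
            + math.log(time_step_min)
        )
        inv_dt = dt + torch.log(-torch.expm1(-dt))  # inverse of softplus
        with torch.no_grad():
            module.dt_bias.copy_(inv_dt)
        module.dt_bias._no_reinit = True  # second pass skips this branch

    if hasattr(module, "out_proj") and not getattr(module.out_proj.weight, "_no_reinit", False):
        w = module.out_proj.weight
        nn.init.kaiming_uniform_(w, a=math.sqrt(5))
        with torch.no_grad():
            w.mul_(1.0 / math.sqrt(num_hidden_layers))  # depth-scaled residual output
        w._no_reinit = True  # make the second invocation a no-op here too
```

Calling this twice on the same module leaves the parameters bit-identical after the first call, which is the idempotence property the patch is after.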
Signed-off-by: Min Zhou <minzhou@virtueai.com>
Contributor
[For maintainers] Suggested jobs to run (before merge): run-slow: nemotron_h
Member
Hey, I'm not sure about this PR! We already have the …
Summary
`NemotronHPreTrainedModel._init_weights` unconditionally overwrites two trained parameters every time it is invoked:

- `NemotronHMamba2Mixer.dt_bias` — reset to a fresh `inv_softplus(random dt)` draw {…}.
- `out_proj.weight` — reset to a kaiming-uniform scaled by `1/sqrt(num_hidden_layers)`.

It sets `module.dt_bias._no_reinit = True` after the copy, but that flag is only checked in the `nn.Linear.bias` branch of the same function — never for `dt_bias` itself — and `out_proj.weight` doesn't set the flag at all.

On transformers>=5.0, `_init_weights` runs a second time after `from_pretrained` has finished loading the checkpoint (the post-load pass that initialises tensors still on `meta`). For `NemotronHForCausalLM` that silently overwrites the on-disk values for `dt_bias` and `out_proj.weight` with fresh random ones, while all other tensors keep their trained values.

The resulting model outputs repetitive filler streams like ` and and and , and and ,` for any input — sanity is preserved only when loading through vLLM (which bypasses `_init_weights`) or via an older transformers release.

Reproduction
Prompting `"Hello, how are you? I am"` on an unpatched load returns `' and' ' in' ' the' ' first' ','` as the top-5 next tokens — a symptom of Mamba2 with randomised `dt_bias` and mis-scaled `out_proj`. After the patch, trained values are preserved and the model generates normally.

The fix
Both changes live in `NemotronHPreTrainedModel._init_weights`:

- `dt_bias` branch: early-return if `dt_bias._no_reinit` is already set (the flag is set at the end of the current branch, so the first pass initialises normally and the second pass becomes a no-op).
- `out_proj.weight` branch: skip when `p._no_reinit` is set, and set `p._no_reinit = True` after the initial kaiming scale so a second invocation is a no-op.

Fresh-init training is unaffected — only the second (post-load) invocation is made idempotent. The same edit is mirrored into `modular_nemotron_h.py` and `modeling_nemotron_h.py`.

Test plan
- Before the patch: `|on_disk - in_mem|.max()` for the layer-0 `dt_bias` is ≈ 26.8 and next-token logits return stop-word garbage.
- `tests/models/nemotron_h/` — no behaviour change for fresh-init; only the idempotence of the re-init pass changes.

Please let me know if you'd like the fix to take a different shape (e.g. short-circuit `_init_weights` entirely when the module's parameters are all materialised, or move the guard to a shared utility in `modeling_utils`). Happy to adjust.
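As a side note on the `inv_softplus(dt)` initialisation mentioned above: storing the inverse softplus means that applying `softplus` to `dt_bias` at runtime recovers the drawn `dt`. A quick stdlib round-trip check of that identity — the stable form `y + log(-expm1(-y))` is an assumption matching common Mamba-style code, not quoted from this PR:

```python
import math

def softplus(x: float) -> float:
    return math.log1p(math.exp(x))

def inv_softplus(y: float) -> float:
    # log(exp(y) - 1), written in a numerically stabler form for small y
    return y + math.log(-math.expm1(-y))

# softplus(inv_softplus(dt)) should round-trip across a typical dt range
for dt in (0.001, 0.01, 0.1, 1.0):
    assert abs(softplus(inv_softplus(dt)) - dt) < 1e-9
```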