
[nemotron_h] respect _no_reinit flag on dt_bias and out_proj.weight#45591

Open
vai-minzhou wants to merge 1 commit into huggingface:main from vai-minzhou:fix-nemotronh-init-overwrite

Conversation

@vai-minzhou

Summary

NemotronHPreTrainedModel._init_weights unconditionally overwrites two trained parameters every time it is invoked:

  • NemotronHMamba2Mixer.dt_bias — reset to a fresh inv_softplus(random dt) draw
  • {…}.out_proj.weight — reset to a kaiming-uniform scaled by 1/sqrt(num_hidden_layers)

It does set module.dt_bias._no_reinit = True afterwards, but that flag is only checked in the nn.Linear bias branch of the same function — never for dt_bias itself — and out_proj.weight never sets the flag at all.
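As a framework-free sketch of that control flow (ToyParam and the scalar values are illustrative stand-ins, not the actual transformers code), the flag is set on the way out but never consulted on the way in:

```python
import math
import random

class ToyParam:
    """Stand-in for an nn.Parameter: a value plus arbitrary attributes."""
    def __init__(self, value):
        self.value = value

def init_weights(dt_bias):
    # Mirrors the reported flow: dt_bias is re-drawn unconditionally;
    # _no_reinit is *set* afterwards but never *checked* beforehand.
    dt = random.uniform(0.001, 0.1)
    dt_bias.value = math.log(math.expm1(dt))  # inv_softplus(dt)
    dt_bias._no_reinit = True  # only the nn.Linear bias branch reads this

p = ToyParam(0.5)      # pretend 0.5 is the trained checkpoint value
init_weights(p)        # the post-load pass on transformers>=5.0
print(p.value == 0.5)  # False — the trained value was overwritten
```

Run this twice-initialised flow and the checkpoint value is gone even though the flag was dutifully set on the first call.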

On transformers>=5.0, _init_weights runs a second time after from_pretrained has finished loading the checkpoint (the post-load pass that initialises tensors still on meta). For NemotronHForCausalLM that silently overwrites the on-disk values for dt_bias and out_proj.weight with fresh random ones, while all other tensors keep their trained values.

The resulting model outputs repetitive filler streams like "and and and , and and ," for any input — outputs stay sane only when loading through vLLM (which bypasses _init_weights) or with an older transformers release.

Reproduction

import json, pathlib, torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForCausalLM

path = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"   # any Nemotron-H ckpt; use a local snapshot dir so the index lookup below resolves
cfg = AutoConfig.from_pretrained(path, trust_remote_code=True)
cfg._attn_implementation = "eager"
m = AutoModelForCausalLM.from_pretrained(path, config=cfg, torch_dtype=torch.bfloat16)

idx = json.load(open(pathlib.Path(path) / "model.safetensors.index.json"))["weight_map"]
k = "backbone.layers.0.mixer.dt_bias"
on_disk = load_file(f"{path}/{idx[k]}")[k]
in_mem  = m.backbone.layers[0].mixer.dt_bias
print((on_disk.float() - in_mem.float().cpu()).abs().max().item())
# → ~26.8 before this patch, 0 after

Prompting "Hello, how are you? I am" on an unpatched load returns ' and' ' in' ' the' ' first' ',' as the top-5 next tokens — a symptom of Mamba2 with a randomised dt_bias and mis-scaled out_proj. After the patch, trained values are preserved and the model generates normally.

The fix

Both changes live in NemotronHPreTrainedModel._init_weights:

  1. dt_bias branch: early-return if dt_bias._no_reinit is already set (the flag is set at the end of the current branch, so the first pass initialises normally, the second pass becomes a no-op).
  2. out_proj.weight branch: skip when p._no_reinit is set, and set p._no_reinit = True after the initial kaiming scale so a second invocation is a no-op.

Fresh-init training is unaffected — only the second (post-load) invocation is made idempotent. The same edit is mirrored in both modular_nemotron_h.py and modeling_nemotron_h.py.
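As a minimal, torch-free sketch of the guard (the helper name and ToyParam are illustrative; the real change lives inside _init_weights itself):

```python
class ToyParam:
    """Stand-in for an nn.Parameter that can carry the _no_reinit flag."""
    def __init__(self, value):
        self.value = value

def reinit_dt_bias(dt_bias, fresh_value):
    # Guarded branch: the first invocation initialises and sets the flag;
    # every later invocation (e.g. the post-load pass) is a no-op.
    if getattr(dt_bias, "_no_reinit", False):
        return
    dt_bias.value = fresh_value
    dt_bias._no_reinit = True

p = ToyParam(0.0)
reinit_dt_bias(p, -4.6)  # fresh-init pass: writes the random draw
p.value = 0.5            # from_pretrained copies the checkpoint in-place
reinit_dt_bias(p, -9.9)  # post-load pass: skipped, trained value survives
print(p.value)           # 0.5
```

The out_proj.weight branch follows the same pattern, with the flag set right after the initial kaiming scale. The flag survives checkpoint loading because load_state_dict copies into the existing parameter object in place, leaving Python-level attributes intact.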

Test plan

  • Unpatched load: |on_disk - in_mem|.max() for layer-0 dt_bias ≈ 26.8, next-token logits return stop-word garbage.
  • Patched load: diff is 0, next-token logits look sane, eval on our NemotronH-based classifier no longer collapses to 1000/1000 parse failures.
  • CI: run tests/models/nemotron_h/ — no behaviour change for fresh-init, only the idempotence of the re-init pass changes.

Please let me know if you'd like the fix to take a different shape (e.g. short-circuit _init_weights entirely when the module's parameters are all materialised, or move the guard to a shared utility in modeling_utils). Happy to adjust.

_init_weights() on `NemotronHPreTrainedModel` unconditionally overwrites
`dt_bias` (random `inv_softplus(dt)`) and `out_proj.weight` (kaiming_uniform
scaled by 1/sqrt(n_layer)) every time it is invoked on a mamba block.
It sets `module.dt_bias._no_reinit = True` after the copy, but the flag is
never checked by either code path (only the Linear-bias branch reads it).

On transformers>=5.0, `_init_weights` is triggered a second time after
`from_pretrained()` has loaded the checkpoint (the post-load safety pass
that initializes tensors still on `meta`). For `NemotronHForCausalLM`
that silently overwrites the checkpoint values for `dt_bias` and
`out_proj.weight` with fresh random draws. The model then outputs
repetitive stop-word streams like ` and and and and ,` for any input.

Minimal repro with any Nemotron-H checkpoint:

    from transformers import AutoConfig, AutoModelForCausalLM
    from safetensors.torch import load_file
    import json, pathlib

    path = ".../NVIDIA-Nemotron-Cascade-2-30B-A3B-BF16"  # or Nano
    cfg = AutoConfig.from_pretrained(path); cfg._attn_implementation='eager'
    m = AutoModelForCausalLM.from_pretrained(path, config=cfg, torch_dtype='bfloat16')
    idx = json.loads((pathlib.Path(path) / 'model.safetensors.index.json').read_text())['weight_map']
    k = 'backbone.layers.0.mixer.dt_bias'
    on_disk = load_file(f'{path}/{idx[k]}')[k]
    in_mem  = m.backbone.layers[0].mixer.dt_bias
    print((on_disk.float() - in_mem.float().cpu()).abs().max())   # ~26.8

This patch makes `_init_weights` honour `_no_reinit` on both `dt_bias` and
`out_proj.weight` (the only two params that re-init unconditionally), and
sets `_no_reinit = True` on `out_proj.weight` after the initial kaiming
scale so a second pass is a no-op. Ordinary fresh-init training is
unaffected; only the second invocation becomes idempotent.

Signed-off-by: Min Zhou <minzhou@virtueai.com>
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: nemotron_h

@Rocketknight1
Member

Hey, I'm not sure about this PR! We already have the _is_hf_initialized attribute, so I'm wary of a _no_reinit flag that does the same thing, even if it's already in the codebase. Can you dig a little deeper and figure out why we don't just use _is_hf_initialized here?
