
Fix EP + DeepSpeed ZeRO-3 loading via accelerate launch#45548

Open
AmineDiro wants to merge 3 commits into huggingface:main from AmineDiro:fix-deepspeed-ep-init

Conversation

@AmineDiro
Member

@AmineDiro AmineDiro commented Apr 21, 2026

Issue

Expert Parallelism (DistributedConfig(enable_expert_parallel=True)) hangs during model loading when launched through accelerate launch with a DeepSpeed ZeRO-3 config. EP works on its own (via torchrun) and ZeRO-3 works on its own — but the two conflict inside from_pretrained because every ZeRO-3 code path is gated on a single env-driven flag (is_deepspeed_zero3_enabled()), and EP needs the non-ZeRO-3 path at every one of those gates.

When you run EP through accelerate launch with DeepSpeed ZeRO-3, the environment variable makes is_deepspeed_zero3_enabled() return True everywhere. Every gate takes the ZeRO-3 path. But EP is fundamentally incompatible with ZeRO-3's initialization, as they shard weights in completely different ways:

  • ZeRO-3: creates lazy partitioned params via deepspeed.zero.Init(), then loads weights through GatheredParameters (all-gather before writing, re-partition after).
  • EP: creates a model on the meta device, registers sharding hooks via distribute_model(), then loads weights through the standard path where shard_and_distribute_module slices each expert tensor by EP rank.

These two can't coexist in the same loading flow. ZeRO-3's lazy params break EP's sharding hooks. EP's meta tensors break ZeRO-3's GatheredParameters (which expects ds_id, ds_shape attributes).
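The incompatibility can be sketched in plain Python (hypothetical stand-in classes, not the transformers or DeepSpeed code): ZeRO-3's GatheredParameters only works on params that zero.Init has tagged with its bookkeeping attributes, while EP's hooks expect plain meta tensors they can materialize and slice themselves.

```python
# Hypothetical sketch of the attribute mismatch described above.

class Zero3Param:
    """Stand-in for a param created under deepspeed.zero.Init()."""
    def __init__(self, ds_id, ds_shape):
        self.ds_id = ds_id        # ZeRO-3 partition bookkeeping
        self.ds_shape = ds_shape  # full (unpartitioned) shape

class MetaParam:
    """Stand-in for a tensor created on torch's meta device."""
    is_meta = True

def can_gather(param) -> bool:
    # GatheredParameters-style check: needs ZeRO-3 bookkeeping attributes.
    return hasattr(param, "ds_id") and hasattr(param, "ds_shape")

assert can_gather(Zero3Param(0, (32, 2880, 5760)))
assert not can_gather(MetaParam())  # EP's meta params break the ZeRO-3 path
```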

Fix

This PR routes EP through the standard (non-zero3) path inside from_pretrained, lets distribute_model() shard experts as usual, and then lets deepspeed.initialize() wrap the already-loaded, already-sharded model afterward.

  1. get_init_context accepts distributed_config; when EP+DS, use meta device (not zero.Init, not real tensors). Meta allocation is free, and init_weights() is skipped; checkpoint weights overwrite everything anyway.
  2. from_pretrained clears device_map set by initialize_tensor_parallelism when EP+DS. EP needs all ranks to read all shard files for the hooks, so we skip the accelerate dispatch split.
  3. _load_pretrained_model when EP+DS, skips the zero3 loading branch and uses the standard convert_and_load_state_dict_in_model path, passing model.tp_plan (the property, which returns the EP plan when EP is on) instead of model._tp_plan.
  4. _move_missing_keys_from_meta_to_device when EP+DS, does not early-return; runs the standard path to move meta buffers (inv_freq, etc.) to CPU.
  5. _initialize_missing_keys when EP+DS, uses standard initialize_weights() (no GatheredParameters, since params are real/empty, not ZeRO-3-partitioned).
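All five call sites hinge on the same predicate. A minimal sketch of that gate (names mirror the PR description, not exact transformers signatures):

```python
# Hypothetical sketch of the EP+DS gate threaded through the call sites above.

def is_ep_with_deepspeed(distributed_config, zero3_enabled: bool) -> bool:
    has_ep = distributed_config is not None and getattr(
        distributed_config, "enable_expert_parallel", False
    )
    return has_ep and zero3_enabled

class Cfg:  # stand-in for DistributedConfig(enable_expert_parallel=True)
    enable_expert_parallel = True

assert is_ep_with_deepspeed(Cfg(), True)        # EP + ZeRO-3 env: standard path
assert not is_ep_with_deepspeed(None, True)     # dense ZeRO-3 keeps zero.Init
assert not is_ep_with_deepspeed(Cfg(), False)   # plain EP unchanged
```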

Test

Minimal EP + DeepSpeed ZeRO-3 verification. It simulates accelerate launch by constructing HfDeepSpeedConfig so that is_deepspeed_zero3_enabled() returns True, which is the signal that made from_pretrained hang before the fix.

Run on 4xH100:

import os
import torch
import torch.distributed as dist
import deepspeed
from deepspeed import comm as ds_comm

from transformers import AutoModelForCausalLM
from transformers.integrations.deepspeed import HfDeepSpeedConfig
from transformers.distributed.configuration_utils import DistributedConfig

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

dist.init_process_group("nccl")
torch.cuda.set_device(local_rank)

ds_config = {
    "train_batch_size": world_size,
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 3, "overlap_comm": True, "contiguous_gradients": True},
}
_dschf = HfDeepSpeedConfig(ds_config)  # strong ref; the global weakref dies if GC'd

mesh = dist.init_device_mesh("cuda", (world_size,))
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    dtype=torch.bfloat16,
    distributed_config=DistributedConfig(enable_expert_parallel=True),
    device_mesh=mesh,
    attn_implementation="eager",
)
model = model.to(f"cuda:{local_rank}")

if dist.get_rank() == 0:
    w = model.model.layers[0].mlp.experts.gate_up_proj
    print(f"expert shape per rank (post-EP, pre-DS): {tuple(w.shape)}  (32 experts / EP=4 = 8)")

ds_comm.init_distributed("nccl")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
engine, optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config)

x = torch.randint(0, 1000, (1, 8), device=f"cuda:{local_rank}")
out = engine(input_ids=x, labels=x.clone(), use_cache=False)
engine.backward(out.loss)
engine.step()

if dist.get_rank() == 0:
    print(f"loss={out.loss.item():.4f}")

dist.destroy_process_group()

Before submitting

  • I confirm that this is not a pure code agent PR.
  • Did you read the contributor guideline, Pull Request section?

Who can review?

@3outeille @ArthurZucker (distributed / TP / EP implementation)

Route EP through the standard (non-zero3) loading path when both EP
and is_deepspeed_zero3_enabled() are active, then let deepspeed.initialize()
wrap the EP-sharded model afterwards.

- Add PreTrainedModel.has_ep property; use it in tp_plan
- get_init_context: meta device for EP+DS (not zero.Init)
- from_pretrained: clear device_map for EP+DS
- _load_pretrained_model: skip zero3 path for EP+DS, pass model.tp_plan
- _move_missing_keys_from_meta_to_device: do not early-return for EP+DS
- _initialize_missing_keys: standard init (no GatheredParameters) for EP+DS
- configuration_utils: strip distributed_config from serialized config

Collaborator

@ArthurZucker ArthurZucker left a comment


I do want this, and I did not know we could bypass zero's "sharding" of the weights. This would be quite ideal if we do the sharding / device placement ourselves.

I am not seeing where

lets distribute_model() shard experts as usual, and then lets deepspeed.initialize() wrap the already-loaded, already-sharded model afterward.

happens?
What I mean by this is before did you already have to call deepspeed.initialize() or not and why don't we return deepspeed.initialize(self)?
( I could be lacking context so cc @SunMarc as well)

Comment on lines +3620 to +3628
_has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)
if _has_ep and is_deepspeed_zero3_enabled():
    # EP + DeepSpeed: use meta device (same as the normal non-DS path).
    # zero.Init is skipped because EP needs to shard experts via distribute_model()
    # hooks, which are incompatible with ZeRO-3 lazy parameters.
    # The standard weight loading path (not zero3) handles EP sharding via
    # shard_and_distribute_module. deepspeed.initialize() wraps the result later.
    init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()])
elif is_deepspeed_zero3_enabled():
Collaborator


btw, the fix can probably be applied to all DS3, no? Not just this?
DS3 does not work with MoE, nor with any "conversion", no?

Member Author


Good point, but I'd keep it narrow to EP+DS3 for now. Two reasons:

  1. I think dense + DS3 still benefits from zero.Init. From what I understand the whole purpose of zero.Init is to avoid a peak-memory spike by partitioning params across ranks during creation. If we switched dense DS3 to meta + standard loader, every rank would temporarily hold a full-precision copy of the weights before deepspeed.initialize partitions them 🤔

nor with any "conversion"no ?

  2. Yes, you're right, I just realized that. The _load_state_dict_into_zero3_model path doesn't run WeightConverter, so any model that depends on a conversion mapping will silently load the wrong keys under DS3. I think, though, this might need a separate issue. Fixing it probably means routing DS3 through convert_and_load_state_dict_in_model with a gather around each per-param write? I'm happy to open a follow-up issue.

wdyt?
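The peak-memory concern in point 1 above can be put in back-of-envelope numbers (illustrative, assuming no shard-on-read optimization; 20B params in bf16 across 4 ranks):

```python
# Illustrative arithmetic for the zero.Init peak-memory argument
# (assumes no shard-on-read; numbers are not measurements).
params = 20e9
bytes_per_param = 2  # bf16
world_size = 4

# meta device + standard loader: each rank briefly holds the full weights
full_copy_per_rank = params * bytes_per_param
# zero.Init: params are partitioned across ranks at creation time
zero_init_per_rank = params * bytes_per_param / world_size

assert full_copy_per_rank == 40e9   # 40 GB per rank
assert zero_init_per_rank == 10e9   # 10 GB per rank
```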

Collaborator


partitioning params across ranks during creation.

we do shard on read, so in the same way I think our peak memory is equivalent (tho I might be wrong)

for 2. yeah that can be nice

Member Author


Oh ok, I missed that part, I'll read more carefully! Then I can open an issue with a suggested fix 👍🏼

@AmineDiro
Member Author

I am not seeing where

lets distribute_model() shard experts as usual, and then lets deepspeed.initialize() wrap the already-loaded, already-sharded model afterward.

happens? What I mean by this is before did you already have to call deepspeed.initialize() or not and why don't we return deepspeed.initialize(self)? ( I could be lacking context so cc @SunMarc as well)

The distribute_model() and deepspeed.initialize() sequence happens across two files, and it's not all in this PR. Inside from_pretrained we have:

# modeling_utils.py ~L4164 
if _torch_distributed_available and device_mesh is not None:
   model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)

distribute_model() runs on the meta-device model before weight loading. It only registers hooks on the EP-relevant modules and nothing else. Then _load_pretrained_model runs, and when it hits an expert key that matches the EP plan, shard_and_distribute_module calls GroupedGemmParallel.shard_tensor(), which slices the expert dim by EP rank during the write. At the end, we get back a regular PreTrainedModel whose expert params are already the per-rank shards (for example [8, 2880, 5760] instead of [32, 2880, 5760] for gpt-oss-20b with EP=4).
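The per-rank slicing described above can be sketched as a toy helper (indices are illustrative; the real slicing lives in GroupedGemmParallel.shard_tensor):

```python
# Toy sketch of per-rank expert slicing: 32 experts over EP=4 ranks.
num_experts, ep_size = 32, 4
per_rank = num_experts // ep_size  # 8 experts per rank

def expert_slice(rank: int) -> slice:
    # each rank keeps a contiguous block of the expert dimension
    return slice(rank * per_rank, (rank + 1) * per_rank)

# full tensor [32, 2880, 5760] -> per-rank shard [8, 2880, 5760]
assert expert_slice(0) == slice(0, 8)
assert expert_slice(3) == slice(24, 32)
```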

Outside from_pretrained: accelerator.prepare(model) calls deepspeed.initialize(model, optimizer, config) afterwards, which ZeRO-3-wraps the already-sharded model. This part is entirely the caller's responsibility from what I can understand 🤔

why not return deepspeed.initialize(self)?
Do you mean return it from from_pretrained directly? If so, I think the API would need to change: deepspeed.initialize needs an optimizer and a full DS config. Also, in accelerate we call from_pretrained, construct an optimizer, then call accelerator.prepare(model, optimizer), which internally does the deepspeed initialize. If from_pretrained returned a DeepSpeedEngine, accelerator.prepare would re-wrap it and probably fail.

@AmineDiro
Member Author

I do want and did not know we could bypass zero's "sharding" of the weights. This would be quite ideal if we do the sharding / device placement.

Yeah, I was surprised it worked too 😅. deepspeed.initialize itself will ZeRO-3-partition whatever it's handed, but zero.Init is probably required to avoid holding the full model on each rank during creation, which is why I think it's a cleaner pattern. Skipping zero.Init works for models where we have TP/EP, but for dense models with no TP sharding it would probably OOM :/

@ArthurZucker
Collaborator

For dense models the whole purpose of our weight loader is to avoid holding all tensors: instead of reading on 1 rank and then sharding (super slow), we leverage shard on read. But yeah, otherwise fine to keep it scoped, though the weight converter is important!
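The "shard on read" idea mentioned above can be sketched with a hypothetical helper (not the transformers loader): each rank reads the checkpoint tensor but materializes only its own slice, so no rank ever holds the full model.

```python
# Hypothetical sketch of shard-on-read: materialize only the local slice.
def load_local_shard(full_tensor_rows, rank, world_size):
    # in the real loader the slice would be taken while reading from disk,
    # so the full tensor is never resident on any single rank
    per_rank = len(full_tensor_rows) // world_size
    return full_tensor_rows[rank * per_rank:(rank + 1) * per_rank]

rows = list(range(32))  # stand-in for a 32-row weight tensor
assert load_local_shard(rows, 1, 4) == list(range(8, 16))
```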

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45548&sha=d98bad
