Fix EP + DeepSpeed ZeRO-3 loading via accelerate launch #45548
AmineDiro wants to merge 3 commits into huggingface:main from
Conversation
Route EP through the standard (non-zero3) loading path when both EP and `is_deepspeed_zero3_enabled()` are active, then let `deepspeed.initialize()` wrap the EP-sharded model afterwards.

- Add `PreTrainedModel.has_ep` property; use it in `tp_plan`
- `get_init_context`: meta device for EP+DS (not `zero.Init`)
- `from_pretrained`: clear `device_map` for EP+DS
- `_load_pretrained_model`: skip zero3 path for EP+DS, pass `model.tp_plan`
- `_move_missing_keys_from_meta_to_device`: do not early-return for EP+DS
- `_initialize_missing_keys`: standard init (no `GatheredParameters`) for EP+DS
- `configuration_utils`: strip `distributed_config` from serialized config
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker left a comment
I did want this and did not know we could bypass ZeRO's "sharding" of the weights. This would be quite ideal if we do the sharding / device placement.
I am not seeing where

> lets `distribute_model()` shard experts as usual, and then lets `deepspeed.initialize()` wrap the already-loaded, already-sharded model afterward.

happens?
What I mean by this is: before, did you already have to call `deepspeed.initialize()` or not, and why don't we return `deepspeed.initialize(self)`?
(I could be lacking context, so cc @SunMarc as well)
```python
_has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)
if _has_ep and is_deepspeed_zero3_enabled():
    # EP + DeepSpeed: use meta device (same as the normal non-DS path).
    # zero.Init is skipped because EP needs to shard experts via distribute_model()
    # hooks, which are incompatible with ZeRO-3 lazy parameters.
    # The standard weight loading path (not zero3) handles EP sharding via
    # shard_and_distribute_module. deepspeed.initialize() wraps the result later.
    init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()])
elif is_deepspeed_zero3_enabled():
```
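For context, `init_contexts` in `from_pretrained` is a list of context managers entered together around model construction. A minimal stdlib sketch of that pattern, with hypothetical stand-in contexts in place of the real `torch.device("meta")` and `init.meta_device_safe_creation_ops()` objects:

```python
from contextlib import ExitStack, contextmanager

# Records which (hypothetical) contexts are currently active.
entered = []

@contextmanager
def fake_context(name):
    # Stand-in for a real init context; only the enter/exit pattern matters.
    entered.append(name)
    try:
        yield
    finally:
        entered.remove(name)

init_contexts = [fake_context("meta_device"), fake_context("safe_creation_ops")]

# The loader enters every context in the list before instantiating
# the model, and exits them all afterwards.
with ExitStack() as stack:
    for ctx in init_contexts:
        stack.enter_context(ctx)
    assert entered == ["meta_device", "safe_creation_ops"]

assert entered == []  # all contexts exited on the way out
```

This is why appending `torch.device("meta")` to the list is enough to make all parameter creation land on the meta device for the EP+DS case.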
btw the fix can probably be applied to all DS3, no? Not just this?
DS3 does not work with MoE, nor with any "conversion", no?
Good point, but I'd keep it narrow to EP+DS3 for now. Two reasons:
- I think dense + DS3 still benefits from `zero.Init`. From what I understand, the whole purpose of `zero.Init` is to avoid a peak-memory spike by partitioning params across ranks during creation. If we switched dense DS3 to meta + standard loader, every rank would temporarily hold a full-precision copy of the weights before `deepspeed.initialize` partitions them 🤔
> nor with any "conversion", no?

- Yes, you're right, I just realized that. `_load_state_dict_into_zero3_model` doesn't run `WeightConverter`, so any model that depends on a conversion mapping will silently load the wrong keys under DS3. I think, though, this might need a separate issue. Fixing it probably means routing DS3 through `convert_and_load_state_dict_in_model` with a per-param gather?? I'm happy to open a follow-up issue.
wdyt ?
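To make the peak-memory trade-off above concrete, a back-of-the-envelope sketch (illustrative numbers only; it deliberately ignores the shard-on-read streaming that the loader actually does, which is exactly the counterpoint raised below):

```python
def peak_param_bytes(total_param_bytes: int, world_size: int, use_zero_init: bool) -> int:
    """Rough per-rank peak parameter memory during loading.

    Toy model only: ignores optimizer states, gradients, and
    shard-on-read streaming of checkpoint files.
    """
    if use_zero_init:
        # zero.Init partitions params across ranks at creation time.
        return total_param_bytes // world_size
    # Meta init + standard loader: a full copy per rank until
    # deepspeed.initialize() partitions it.
    return total_param_bytes

# Example: 70B params in bf16 (2 bytes each) on 8 ranks.
total = 70_000_000_000 * 2
with_zero_init = peak_param_bytes(total, 8, use_zero_init=True)
without = peak_param_bytes(total, 8, use_zero_init=False)

assert with_zero_init == total // 8  # ~17.5 GB per rank
assert without == total              # ~140 GB per rank, naively
```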
> partitioning params across ranks during creation.

We do shard on read, so I think our peak memory is equivalent in the same way (tho I might be wrong).
for 2. yeah that can be nice
Oh ok, I missed that part, I'll read more carefully! So I can open up an issue with a suggested fix 👍🏼

The `distribute_model()` call:

```python
# modeling_utils.py ~L4164
if _torch_distributed_available and device_mesh is not None:
    model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
```

Outside
Yeah, I was surprised it worked too 😅. But
For dense models, the whole purpose of our weight loader is to avoid holding all tensors; instead of reading on 1 rank and then sharding (super slow), we leverage shard-on-read. But yeah, otherwise fine to keep it scoped, but the weight converter is important!
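The shard-on-read idea can be sketched in plain Python. This is a toy stand-in (the `load_slice` helper and list-of-lists "checkpoint" are hypothetical, not the real loader API): each rank materializes only its slice of a checkpoint tensor, instead of one rank reading everything and scattering it.

```python
def load_slice(full_tensor, rank, world_size):
    """Toy shard-on-read: return only this rank's rows of a checkpoint
    tensor, so no rank ever holds the full tensor in memory."""
    n = len(full_tensor)
    per_rank = n // world_size          # assume even divisibility
    start = rank * per_rank
    return full_tensor[start:start + per_rank]

# A fake 8-row checkpoint tensor, sharded across 4 ranks.
ckpt_rows = [[float(i)] * 4 for i in range(8)]
shards = [load_slice(ckpt_rows, r, 4) for r in range(4)]

assert all(len(s) == 2 for s in shards)   # each rank holds 2 rows
assert sum(shards, []) == ckpt_rows       # shards tile the tensor exactly
```

The real loader does the slicing at file-read time (e.g. when reading safetensors shards), which is what keeps per-rank peak memory low even without `zero.Init`.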
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45548&sha=d98bad
Issue
Expert Parallelism (`DistributedConfig(enable_expert_parallel=True)`) hangs during model loading when launched through `accelerate launch` with a DeepSpeed ZeRO-3 config. EP works on its own (via `torchrun`) and ZeRO-3 works on its own, but the two conflict inside `from_pretrained` because every ZeRO-3 code path is gated on a single env-driven flag (`is_deepspeed_zero3_enabled()`), and EP needs the non-ZeRO-3 path at every one of those gates.

When you run EP through `accelerate launch` with DeepSpeed ZeRO-3, the environment variable makes `is_deepspeed_zero3_enabled()` return `True` everywhere, so every gate takes the ZeRO-3 path. But EP is fundamentally incompatible with ZeRO-3's initialization, as they shard weights in completely different ways:

- ZeRO-3 constructs the model inside `deepspeed.zero.Init()`, then loads weights through `GatheredParameters` (all-gather before writing, re-partition after).
- EP constructs the model via `distribute_model()`, then loads weights through the standard path, where `shard_and_distribute_module` slices each expert tensor by EP rank.

These two can't coexist in the same loading flow. ZeRO-3's lazy params break EP's sharding hooks. EP's meta tensors break ZeRO-3's `GatheredParameters` (which expects `ds_id`, `ds_shape` attributes).

Fix
This PR routes EP through the standard (non-zero3) path inside `from_pretrained`, lets `distribute_model()` shard experts as usual, and then lets `deepspeed.initialize()` wrap the already-loaded, already-sharded model afterward.

- `get_init_context` accepts `distributed_config`; when EP+DS, use the meta device (not `zero.Init`, not real tensors). Meta allocation is free, and `init_weights()` is skipped; checkpoint weights overwrite everything anyway.
- `from_pretrained` clears the `device_map` set by `initialize_tensor_parallelism` when EP+DS. EP needs all ranks to read all shard files for the hooks, so we skip the `accelerate` dispatch split.
- `_load_pretrained_model`: when EP+DS, skips the zero3 loading branch and uses the standard `convert_and_load_state_dict_in_model` path, passing `model.tp_plan` (the property, which returns the EP plan when EP is on) instead of `model._tp_plan`.
- `_move_missing_keys_from_meta_to_device`: when EP+DS, does not early-return; runs the standard path to move meta buffers (`inv_freq`, etc.) to CPU.
- `_initialize_missing_keys`: when EP+DS, uses the standard `initialize_weights()` (no `GatheredParameters`, since params are real/empty, not ZeRO-3-partitioned).
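The EP side of the ZeRO-3/EP contrast can be sketched as well: under expert parallelism each rank owns a contiguous subset of whole experts, rather than a partition of every tensor. A toy layout (the `experts_for_rank` helper is hypothetical; the real slicing lives in `shard_and_distribute_module`):

```python
def experts_for_rank(num_experts: int, ep_rank: int, ep_size: int) -> list:
    """Toy EP layout: a contiguous block of whole experts per rank.

    Contrast with ZeRO-3, which partitions every parameter tensor
    across all ranks regardless of model structure.
    """
    assert num_experts % ep_size == 0, "assume experts divide evenly across ranks"
    per_rank = num_experts // ep_size
    return list(range(ep_rank * per_rank, (ep_rank + 1) * per_rank))

# 8 experts on 4 EP ranks: each rank holds 2 full experts.
layout = [experts_for_rank(8, r, 4) for r in range(4)]
assert layout == [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Because a rank's experts are whole, ordinary tensors, `deepspeed.initialize()` can still wrap them afterwards; there is nothing ZeRO-3-lazy about them at load time.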
Minimal EP + DeepSpeed ZeRO-3 verification. Simulates `accelerate launch` by setting `HfDeepSpeedConfig` so that `is_deepspeed_zero3_enabled()` returns `True`, the signal that made `from_pretrained` hang before the fix.

Run on 4xH100:
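A minimal ZeRO-3 DeepSpeed config of the kind such a run points at (illustrative values following DeepSpeed's config schema, not the PR's actual config; `train_batch_size` assumes 4 GPUs with micro-batch 1 and no gradient accumulation):

```json
{
  "zero_optimization": {
    "stage": 3
  },
  "bf16": {
    "enabled": true
  },
  "train_micro_batch_size_per_gpu": 1,
  "train_batch_size": 4
}
```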
Who can review?
@3outeille @ArthurZucker (distributed / TP / EP implementation)