From ae548bf628493f6342466d56c19f383efd254a4e Mon Sep 17 00:00:00 2001
From: aminediro
Date: Tue, 21 Apr 2026 10:52:26 +0000
Subject: [PATCH] Fix EP + DeepSpeed ZeRO-3 loading via accelerate launch

Route EP through the standard (non-zero3) loading path when both EP and
is_deepspeed_zero3_enabled() are active, then let deepspeed.initialize()
wrap the EP-sharded model afterwards.

- Add PreTrainedModel.has_ep property; use it in tp_plan
- get_init_context: meta device for EP+DS (not zero.Init)
- from_pretrained: clear device_map for EP+DS
- _load_pretrained_model: skip zero3 path for EP+DS, pass model.tp_plan
- _move_missing_keys_from_meta_to_device: do not early-return for EP+DS
- _initialize_missing_keys: standard init (no GatheredParameters) for EP+DS
- configuration_utils: strip distributed_config from serialized config
---
 src/transformers/configuration_utils.py |  1 +
 src/transformers/modeling_utils.py      | 46 +++++++++++++++++++++----
 2 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 4f58a230e352..4ac0a179c008 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -1154,6 +1154,7 @@ def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
             "ignore_keys_at_rope_validation",
             "base_model_tp_plan",
             "base_model_pp_plan",
+            "distributed_config",
         ]:
             d.pop(key_to_remove, None)
 
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index db2ef1b3323a..53295a5927f6 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1330,12 +1330,18 @@ def post_init(self):
             self.init_weights()
         self._backward_compatibility_gradient_checkpointing()
 
+    @property
+    def has_ep(self) -> bool:
+        """Whether expert parallelism is enabled for this model."""
+        distributed_config = getattr(getattr(self, "config", None), "distributed_config", None)
+        return distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)
+
     @property
     def tp_plan(self) -> dict[str, str]:
         """
         The full tp plan for the model's modules
         """
-        if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel:
+        if self.has_ep:
             return self._ep_plan
         return self._tp_plan
 
@@ -3599,14 +3605,27 @@ def float(self, *args):
 
     @classmethod
     def get_init_context(
-        cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool, allow_all_kernels: bool | None
+        cls,
+        dtype: torch.dtype,
+        is_quantized: bool,
+        _is_ds_init_called: bool,
+        allow_all_kernels: bool | None,
+        distributed_config=None,
     ):
         # Need to instantiate with correct dtype
         init_contexts = [local_torch_dtype(dtype, cls.__name__), init.no_tie_weights(), apply_patches()]
         # Needed as we cannot forward the `allow_all_kernels` arg in the model's __init__
         if allow_all_kernels:
             init_contexts.append(allow_all_hub_kernels())
-        if is_deepspeed_zero3_enabled():
+        _has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)
+        if _has_ep and is_deepspeed_zero3_enabled():
+            # EP + DeepSpeed: use meta device (same as the normal non-DS path).
+            # zero.Init is skipped because EP needs to shard experts via distribute_model()
+            # hooks, which are incompatible with ZeRO-3 lazy parameters.
+            # The standard weight loading path (not zero3) handles EP sharding via
+            # shard_and_distribute_module. deepspeed.initialize() wraps the result later.
+            init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()])
+        elif is_deepspeed_zero3_enabled():
             import deepspeed
 
             # We cannot initialize the model on meta device with deepspeed when not quantized
@@ -4007,6 +4026,12 @@ def from_pretrained(
                 download_kwargs_with_commit,
                 **adapter_kwargs,
             )
+        # EP + DeepSpeed: clear device_map (set by initialize_tensor_parallelism) so the model
+        # loads on CPU first. distribute_model() handles GPU placement during EP sharding.
+ # Without this, device_map triggers accelerate's dispatch path which breaks shard loading. + _has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False) + if _has_ep and is_deepspeed_zero3_enabled(): + device_map = None device_map = check_and_set_device_map(device_map) # warn, error and fix the device map user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} @@ -4110,7 +4135,9 @@ def from_pretrained( register_fusion_patches(cls, config, fusion_config) - model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels) + model_init_context = cls.get_init_context( + dtype, is_quantized, _is_ds_init_called, allow_all_kernels, distributed_config + ) config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. with ContextManagers(model_init_context): @@ -4241,7 +4268,11 @@ def _load_pretrained_model( error_msgs = [] - if is_deepspeed_zero3_enabled() and not is_quantized: + # EP + DeepSpeed: skip zero3 loading path. The model was created on meta device + # (not via zero.Init), so params are not zero3-partitioned. The standard loading + # path handles EP sharding via shard_and_distribute_module using the EP plan hooks + # registered by distribute_model(). deepspeed.initialize() wraps the result later. + if is_deepspeed_zero3_enabled() and not is_quantized and not model.has_ep: if state_dict is None: merged_state_dict = {} for ckpt_file in checkpoint_files: @@ -4551,7 +4582,8 @@ def _move_missing_keys_from_meta_to_device( """ is_quantized = hf_quantizer is not None # This is the only case where we do not initialize the model on meta device, so we don't have to do anything here - if is_deepspeed_zero3_enabled() and not is_quantized: + # Exception: EP + DeepSpeed uses meta device (not zero.Init), so it needs the standard move path. 
+ if is_deepspeed_zero3_enabled() and not is_quantized and not self.has_ep: return # In this case we need to move everything back @@ -4609,7 +4641,7 @@ def _initialize_missing_keys(self, is_quantized: bool) -> None: self._is_hf_initialized = True # This will only initialize submodules that are not marked as initialized by the line above. - if is_deepspeed_zero3_enabled() and not is_quantized: + if is_deepspeed_zero3_enabled() and not is_quantized and not self.has_ep: import deepspeed # keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them