Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,7 @@ def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
"ignore_keys_at_rope_validation",
"base_model_tp_plan",
"base_model_pp_plan",
"distributed_config",
]:
d.pop(key_to_remove, None)

Expand Down
46 changes: 39 additions & 7 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,12 +1330,18 @@ def post_init(self):
self.init_weights()
self._backward_compatibility_gradient_checkpointing()

@property
def has_ep(self) -> bool:
    """Return whether expert parallelism (EP) is enabled for this model.

    Reads `config.distributed_config.enable_expert_parallel`, defaulting to
    `False` whenever the config, its distributed section, or the flag itself
    is absent.
    """
    config = getattr(self, "config", None)
    distributed_config = getattr(config, "distributed_config", None)
    if distributed_config is None:
        return False
    return getattr(distributed_config, "enable_expert_parallel", False)

@property
def tp_plan(self) -> dict[str, str]:
    """The full tensor-parallel plan for the model's modules.

    Returns the expert-parallel plan (`_ep_plan`) when expert parallelism
    is enabled (see `has_ep`), otherwise the regular tensor-parallel plan
    (`_tp_plan`).
    """
    # NOTE: `has_ep` already guards against a missing `distributed_config`,
    # so no extra `hasattr` check is needed here.
    if self.has_ep:
        return self._ep_plan
    return self._tp_plan

Expand Down Expand Up @@ -3599,14 +3605,27 @@ def float(self, *args):

@classmethod
def get_init_context(
cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool, allow_all_kernels: bool | None
cls,
dtype: torch.dtype,
is_quantized: bool,
_is_ds_init_called: bool,
allow_all_kernels: bool | None,
distributed_config=None,
):
# Need to instantiate with correct dtype
init_contexts = [local_torch_dtype(dtype, cls.__name__), init.no_tie_weights(), apply_patches()]
# Needed as we cannot forward the `allow_all_kernels` arg in the model's __init__
if allow_all_kernels:
init_contexts.append(allow_all_hub_kernels())
if is_deepspeed_zero3_enabled():
_has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)
if _has_ep and is_deepspeed_zero3_enabled():
# EP + DeepSpeed: use meta device (same as the normal non-DS path).
# zero.Init is skipped because EP needs to shard experts via distribute_model()
# hooks, which are incompatible with ZeRO-3 lazy parameters.
# The standard weight loading path (not zero3) handles EP sharding via
# shard_and_distribute_module. deepspeed.initialize() wraps the result later.
init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()])
elif is_deepspeed_zero3_enabled():
Comment on lines +3667 to +3675
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, can the fix probably be applied to all of DS3, not just this case?
DS3 does not work with MoE, nor with any "conversion", no?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, but I'd keep it narrow to EP+DS3 for now. Two reasons:

  1. I think dense + DS3 still benefits from zero.Init. From what I understand the whole purpose of zero.Init is to avoid a peak-memory spike by partitioning params across ranks during creation. If we switched dense DS3 to meta + standard loader, every rank would temporarily hold a full-precision copy of the weights before deepspeed.initialize partitions them 🤔

nor with any "conversion"no ?

  1. Yes, you're right — I just realized that. `_load_state_dict_into_zero3_model` doesn't run `WeightConverter`, so any model that depends on a conversion mapping will silently load the wrong keys under DS3. I think, though, that this might need a separate issue. Fixing it probably means routing DS3 through `convert_and_load_state_dict_in_model` with a gather around the per-parameter load? I'm happy to open a follow-up issue.

wdyt ?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

partitioning params across ranks during creation.

We do shard on read, so I think our peak memory is equivalent in the same way (though I might be wrong).

for 2. yeah that can be nice

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, OK — I missed that part; I'll read more carefully! I'll open an issue with a suggested fix 👍🏼

import deepspeed

# We cannot initialize the model on meta device with deepspeed when not quantized
Expand Down Expand Up @@ -4007,6 +4026,12 @@ def from_pretrained(
download_kwargs_with_commit,
**adapter_kwargs,
)
# EP + DeepSpeed: clear device_map (set by initialize_tensor_parallelism) so the model
# loads on CPU first. distribute_model() handles GPU placement during EP sharding.
# Without this, device_map triggers accelerate's dispatch path which breaks shard loading.
_has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)
if _has_ep and is_deepspeed_zero3_enabled():
device_map = None
device_map = check_and_set_device_map(device_map) # warn, error and fix the device map

user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
Expand Down Expand Up @@ -4110,7 +4135,9 @@ def from_pretrained(

register_fusion_patches(cls, config, fusion_config)

model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels)
model_init_context = cls.get_init_context(
dtype, is_quantized, _is_ds_init_called, allow_all_kernels, distributed_config
)

config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
with ContextManagers(model_init_context):
Expand Down Expand Up @@ -4241,7 +4268,11 @@ def _load_pretrained_model(

error_msgs = []

if is_deepspeed_zero3_enabled() and not is_quantized:
# EP + DeepSpeed: skip zero3 loading path. The model was created on meta device
# (not via zero.Init), so params are not zero3-partitioned. The standard loading
# path handles EP sharding via shard_and_distribute_module using the EP plan hooks
# registered by distribute_model(). deepspeed.initialize() wraps the result later.
if is_deepspeed_zero3_enabled() and not is_quantized and not model.has_ep:
if state_dict is None:
merged_state_dict = {}
for ckpt_file in checkpoint_files:
Expand Down Expand Up @@ -4551,7 +4582,8 @@ def _move_missing_keys_from_meta_to_device(
"""
is_quantized = hf_quantizer is not None
# This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
if is_deepspeed_zero3_enabled() and not is_quantized:
# Exception: EP + DeepSpeed uses meta device (not zero.Init), so it needs the standard move path.
if is_deepspeed_zero3_enabled() and not is_quantized and not self.has_ep:
return

# In this case we need to move everything back
Expand Down Expand Up @@ -4609,7 +4641,7 @@ def _initialize_missing_keys(self, is_quantized: bool) -> None:
self._is_hf_initialized = True

# This will only initialize submodules that are not marked as initialized by the line above.
if is_deepspeed_zero3_enabled() and not is_quantized:
if is_deepspeed_zero3_enabled() and not is_quantized and not self.has_ep:
import deepspeed

# keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them
Expand Down
Loading