-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmodal_train.py
More file actions
180 lines (159 loc) · 5.29 KB
/
modal_train.py
File metadata and controls
180 lines (159 loc) · 5.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import modal
import modal.experimental
import os
import subprocess
from enum import Enum
from common import (
DATASET_ID,
DATASET_VOLUME_NAME,
DATASET_MOUNT_PATH,
MODEL_MOUNT_PATH,
MODEL_CACHE_DIR,
train_image,
REMOTE_TRAIN_SCRIPT_PATH,
MODEL_VOLUME_NAME,
)
# Persistent volume holding the tokenized training dataset; created lazily on
# first deploy so a fresh workspace works without manual setup.
data_vol = modal.Volume.from_name(
    DATASET_VOLUME_NAME,
    create_if_missing=True,
)
# Persistent volume for model checkpoints and the HF weight cache.
model_vol = modal.Volume.from_name(
    MODEL_VOLUME_NAME,
    create_if_missing=True,
)
# Pre-registered Modal secrets, injected into the container as env vars
# (HF token for gated model downloads, W&B API key for logging).
hf_secret = modal.Secret.from_name("huggingface-secret")
wandb_secret = modal.Secret.from_name("wandb-secret")
# One Modal app per dataset so runs against different datasets don't collide.
app = modal.App(
    f"{DATASET_ID}-train",
)
# The number of containers (i.e. nodes) in the cluster. This can be between 1 and 4.
n_nodes = 2
# Typically this matches the number of GPUs per container.
n_proc_per_node = 8
# Port used for inter-container control communication.
main_port = 29500
class LaunchType(Enum):
    """Supported distributed-training launchers.

    The member values double as the accepted (lower-cased) CLI strings for
    the ``launch_type`` argument of the training entry point.
    """

    # PyTorch's built-in elastic launcher (torch.distributed.run).
    TORCHRUN = "torchrun"
    # Hugging Face Accelerate's CLI launcher.
    ACCELERATE = "accelerate"
@app.function(
    image=train_image,
    volumes={
        DATASET_MOUNT_PATH: data_vol,
        MODEL_MOUNT_PATH: model_vol,
    },
    secrets=[
        wandb_secret,
        hf_secret,
    ],
    gpu="H100:8",
    experimental_options={
        "efa_enabled": "True",
    },
    timeout=60 * 60 * 24,  # 24-hour cap for long training runs
)
@modal.experimental.clustered(n_nodes, rdma=True)
def train_multi_node(launch_type: str = "torchrun", profile: bool = False):
    """
    Performs multi-node training using either torchrun or Hugging Face Accelerate.

    Args:
        launch_type: Launcher to use, 'torchrun' or 'accelerate'
            (case-insensitive).
        profile: If True, forward --profile to the training script.

    Raises:
        ValueError: If launch_type is not a recognized launcher, or if the
            batch-size constants cannot yield a whole, positive gradient
            accumulation factor.
        subprocess.CalledProcessError: If the accelerate launcher exits
            non-zero.
    """
    # Parse launch_type through the enum constructor instead of a hand-written
    # if/elif chain, so the set of valid launchers lives in one place.
    try:
        parsed_launch_type = LaunchType(launch_type.lower())
    except ValueError:
        raise ValueError(
            f"Invalid launch_type: '{launch_type}'. Must be 'torchrun' or 'accelerate'."
        ) from None

    # Get Modal cluster info for inter-container communication.
    cluster_info = modal.experimental.get_cluster_info()
    container_rank: int = cluster_info.rank
    main_ip_addr: str = cluster_info.container_ips[0]
    container_id = os.environ["MODAL_TASK_ID"]

    # Configuration for batch sizes and gradient accumulation. Target a constant
    # global batch size so we can do apples-to-apples comparisons between runs.
    global_batch_size_config = 1024
    per_device_batch_size_config = 16
    world_size = n_proc_per_node * n_nodes
    grad_accum_config, remainder = divmod(
        global_batch_size_config, world_size * per_device_batch_size_config
    )
    # Guard against grad-accum silently becoming 0 (or a fractional step) if
    # the topology/batch constants above are ever edited inconsistently.
    if grad_accum_config < 1 or remainder != 0:
        raise ValueError(
            f"global batch size {global_batch_size_config} must be a positive "
            f"multiple of world_size * per-device batch "
            f"({world_size} * {per_device_batch_size_config})"
        )

    # Unified run name for output directories and W&B
    current_run_name = (
        f"starcoder-nodes_{n_nodes}-gpus_{n_proc_per_node}"
        f"-batch_{global_batch_size_config}-per_device_{per_device_batch_size_config}"
        f"-grad_accum_{grad_accum_config}"
    )
    print(
        f"Hello from {container_id}, rank {container_rank} of {n_nodes} "
        f"using {parsed_launch_type.value}. Run ID: {current_run_name}"
    )
    if container_rank == 0:
        print(f"Main container's IP address: {main_ip_addr}")
        # Setup W&B environment variables on the main node
        wandb_project_name = f"{DATASET_ID.replace('/', '-')}-training"
        os.environ["WANDB_PROJECT"] = wandb_project_name
        os.environ["WANDB_RUN_NAME"] = current_run_name
        print(
            f"Weights & Biases: Project='{wandb_project_name}', Run='{current_run_name}'"
        )

    # Arguments forwarded verbatim to the training script itself.
    script_args = [
        "--data_dir",
        DATASET_MOUNT_PATH,
        "--output_dir",
        f"{MODEL_MOUNT_PATH}/{current_run_name}",
        "--epochs",
        "2",
        "--batch_per_device",
        str(per_device_batch_size_config),
        "--grad_accum",
        str(grad_accum_config),
        "--model_cache_dir",
        MODEL_CACHE_DIR,
    ]
    if profile:
        script_args.append("--profile")

    def _train_torchrun() -> None:
        # Launch in-process via torch.distributed's programmatic API.
        from torch.distributed.run import parse_args, run

        args = [
            f"--nnodes={n_nodes}",
            f"--nproc-per-node={n_proc_per_node}",
            f"--node-rank={container_rank}",
            f"--master-addr={main_ip_addr}",
            f"--master-port={main_port}",
            REMOTE_TRAIN_SCRIPT_PATH,
            *script_args,
        ]
        print(f"Executing torchrun with args: {' '.join(args)}")
        run(parse_args(args))

    def _train_accelerate() -> None:
        # Shell out to the accelerate CLI; check=True surfaces failures.
        cmd = [
            "accelerate",
            "launch",
            # BUGFIX: accelerate defines --num_processes as the TOTAL number
            # of processes across ALL machines, not per node. The previous
            # code passed n_proc_per_node (8), which under-subscribes a
            # multi-node cluster; it must be world_size (n_nodes * 8).
            "--num_processes",
            str(n_proc_per_node * n_nodes),
            "--num_machines",
            str(n_nodes),
            "--machine_rank",
            str(container_rank),
            "--main_process_ip",
            main_ip_addr,
            "--main_process_port",
            str(main_port),
            "--mixed_precision",
            "bf16",
            REMOTE_TRAIN_SCRIPT_PATH,
            *script_args,
        ]
        print(f"Executing accelerate launch with: {' '.join(cmd)}")
        subprocess.run(cmd, check=True)

    # Dispatch on the parsed enum. Parsing above already rejected anything
    # else, so the former trailing raise-ValueError branch was unreachable
    # dead code and has been removed.
    if parsed_launch_type is LaunchType.TORCHRUN:
        _train_torchrun()
    else:
        _train_accelerate()