Commit 9059c75

fix: Bump sglang version from 0.5.9 to 0.5.10
1 parent d5fb617 commit 9059c75

9 files changed: 48 additions & 19 deletions

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@ dependencies = [
     "torch==2.9.1",
     "torchaudio==2.9.1",
     "torchvision==0.24.1",
-    "transformers==4.57.1",
+    "transformers==5.3.0",
     "qwen-vl-utils==0.0.11",
     "datasets",
     "setuptools",
@@ -25,7 +25,7 @@ dependencies = [
     "numpy",
     "accelerate",
     "pydantic",
-    "sglang==0.5.9",
+    "sglang==0.5.10",
     "openai-harmony",
     "ninja",
     "packaging",

requirements-rocm.txt

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@ pre-commit
 torch==2.8.0+rocm6.3
 torchaudio==2.8.0+rocm6.3
 torchvision==0.23.0+rocm6.3
-transformers==4.57.1
+transformers==5.3.0
 qwen-vl-utils==0.0.11
 datasets
 setuptools
@@ -15,6 +15,6 @@ psutil
 numpy
 accelerate
 pydantic
-sglang[all]==0.5.4
+sglang[all]==0.5.10
 openai-harmony
 tensorboard

specforge/args.py

Lines changed: 5 additions & 5 deletions
@@ -96,7 +96,7 @@ class SGLangBackendArgs:
     sglang_enable_torch_compile: bool = True
     sglang_enable_dp_attention: bool = False
     sglang_enable_dp_lm_head: bool = False
-    sglang_enable_piecewise_cuda_graph: bool = False
+    sglang_enforce_piecewise_cuda_graph: bool = False
     sglang_piecewise_cuda_graph_max_tokens: int = 4096
     sglang_piecewise_cuda_graph_tokens: List[int] = None
     sglang_ep_size: int = 1
@@ -151,9 +151,9 @@ def add_args(parser: argparse.ArgumentParser) -> None:
         help="Enable piecewise CUDA graph for SGLang backend",
     )
     parser.add_argument(
-        "--sglang-enable-piecewise-cuda-graph",
+        "--sglang-enforce-piecewise-cuda-graph",
         action="store_true",
-        help="Enable piecewise CUDA graph for SGLang backend's prefill",
+        help="Enforce piecewise CUDA graph for SGLang backend's prefill",
     )
     parser.add_argument(
         "--sglang-piecewise-cuda-graph-max-tokens",
@@ -186,7 +186,7 @@ def from_args(args: argparse.Namespace) -> "SGLangBackendArgs":
         sglang_enable_torch_compile=args.sglang_enable_torch_compile,
         sglang_enable_dp_attention=args.sglang_enable_dp_attention,
         sglang_enable_dp_lm_head=args.sglang_enable_dp_lm_head,
-        sglang_enable_piecewise_cuda_graph=args.sglang_enable_piecewise_cuda_graph,
+        sglang_enforce_piecewise_cuda_graph=args.sglang_enforce_piecewise_cuda_graph,
         sglang_piecewise_cuda_graph_max_tokens=args.sglang_piecewise_cuda_graph_max_tokens,
         sglang_piecewise_cuda_graph_tokens=args.sglang_piecewise_cuda_graph_tokens,
         sglang_ep_size=args.sglang_ep_size,
@@ -210,7 +210,7 @@ def to_kwargs(self) -> Dict[str, Any]:
         enable_torch_compile=self.sglang_enable_torch_compile,
         enable_dp_attention=self.sglang_enable_dp_attention,
         enable_dp_lm_head=self.sglang_enable_dp_lm_head,
-        enable_piecewise_cuda_graph=self.sglang_enable_piecewise_cuda_graph,
+        enforce_piecewise_cuda_graph=self.sglang_enforce_piecewise_cuda_graph,
         piecewise_cuda_graph_max_tokens=self.sglang_piecewise_cuda_graph_max_tokens,
         piecewise_cuda_graph_tokens=self.sglang_piecewise_cuda_graph_tokens,
         ep_size=self.sglang_ep_size,
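
For callers, the net effect is a CLI flag rename: anything that passed --sglang-enable-piecewise-cuda-graph must now pass --sglang-enforce-piecewise-cuda-graph. A self-contained sketch of the round trip through argparse (the parser here is a toy mirroring the add_args pattern above, not the project's actual parser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--sglang-enforce-piecewise-cuda-graph", action="store_true")

    args = parser.parse_args(["--sglang-enforce-piecewise-cuda-graph"])
    # argparse converts dashes to underscores, which is why the dataclass
    # field is named sglang_enforce_piecewise_cuda_graph.
    assert args.sglang_enforce_piecewise_cuda_graph is True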

specforge/modeling/draft/llama3_eagle.py

Lines changed: 23 additions & 0 deletions
@@ -272,6 +272,17 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
             "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
         )

+    def rebuild_buffers(self, device):
+        """Rebuild non-persistent RoPE buffers corrupted by transformers 5.x meta-device init."""
+        self.inv_freq = 1.0 / (
+            self.base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
+        )
+        self._set_cos_sin_cache(
+            seq_len=self.max_position_embeddings + 20,
+            device=device,
+            dtype=torch.get_default_dtype(),
+        )
+
     @torch.compile(dynamic=True)
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -1314,6 +1325,16 @@ class LlamaForCausalLMEagle3(Eagle3DraftModel):

     config_class = LlamaConfig

+    def _init_weights(self, module):
+        # Override the transformers 5.x default _init_weights which would
+        # re-randomize all Linear/Embedding weights with normal_(0, 0.02).
+        # Draft model weights come from checkpoint, not random init.
+        #
+        # For RotaryEmbedding: rebuild non-persistent buffers (inv_freq,
+        # cos_cached, sin_cached) corrupted by meta-device materialization.
+        if isinstance(module, LlamaRotaryEmbedding):
+            module.rebuild_buffers(module.inv_freq.device)
+
     def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
         super().__init__(config)
         self.config = config
@@ -1346,6 +1367,8 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
         self.register_buffer("t2d", t2d)
         self.register_buffer("d2t", d2t)

+        self.post_init()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
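
The override matters because transformers' post_init() applies _init_weights to every submodule; under the 5.x default described in the comments above, that pass would clobber the weights the draft model just loaded. A self-contained toy reproducing the mechanism (this mimics, but is not, the actual PreTrainedModel code):

    import torch
    import torch.nn as nn

    class TinyDraft(nn.Module):
        def __init__(self):
            super().__init__()
            self.lm_head = nn.Linear(8, 8)

        def _init_weights(self, module):
            # No-op for Linear/Embedding: checkpoint weights must survive.
            # (The real override additionally rebuilds RoPE buffers, which are
            # non-persistent and therefore never stored in the checkpoint.)
            pass

        def post_init(self):
            # Mirrors transformers' post_init, which walks all submodules.
            self.apply(self._init_weights)

    model = TinyDraft()
    before = model.lm_head.weight.clone()
    model.post_init()
    assert torch.equal(before, model.lm_head.weight)  # weights untouched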

specforge/modeling/target/custom_backend/gpt_oss.py

Lines changed: 4 additions & 2 deletions
@@ -36,7 +36,8 @@
 from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
-from transformers.utils.generic import check_model_inputs
+from transformers.utils.generic import merge_with_config_defaults
+from transformers.utils.output_capturing import capture_outputs

 from specforge.distributed import get_tp_group, shard_tensor
 from specforge.layers import (
@@ -585,7 +586,8 @@ def __init__(self, config: GptOssConfig):
         # Initialize weights and apply final processing
         self.post_init()

-    @check_model_inputs
+    @merge_with_config_defaults
+    @capture_outputs
     @auto_docstring
     def forward(
         self,
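
The same one-for-two decorator swap repeats in llama.py, llama4.py, and phi3.py below. Judging by the names, transformers 5.x splits check_model_inputs into a kwargs-defaulting concern and an output-capturing concern; stacking order then matters, since the decorator listed closest to def wraps the function first and runs innermost. A self-contained toy illustrating that ordering with stand-in decorators (these are not the real transformers implementations):

    from functools import wraps

    calls = []

    def merge_with_config_defaults_toy(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            calls.append("merge")      # outer wrapper: runs first on the way in
            return fn(*args, **kwargs)
        return wrapper

    def capture_outputs_toy(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            calls.append("capture")    # inner wrapper: runs closest to forward
            return fn(*args, **kwargs)
        return wrapper

    @merge_with_config_defaults_toy
    @capture_outputs_toy
    def forward():
        calls.append("forward")

    forward()
    assert calls == ["merge", "capture", "forward"]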

specforge/modeling/target/custom_backend/llama.py

Lines changed: 4 additions & 2 deletions
@@ -41,7 +41,8 @@
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, logging
-from transformers.utils.generic import check_model_inputs
+from transformers.utils.generic import merge_with_config_defaults
+from transformers.utils.output_capturing import capture_outputs

 from specforge.distributed import get_tp_group
 from specforge.layers import (
@@ -275,7 +276,8 @@ def __init__(self, config: LlamaConfig):
         # Initialize weights and apply final processing
         self.post_init()

-    @check_model_inputs
+    @merge_with_config_defaults
+    @capture_outputs
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

specforge/modeling/target/custom_backend/llama4.py

Lines changed: 4 additions & 2 deletions
@@ -52,7 +52,8 @@
     logging,
 )
 from transformers.utils.deprecation import deprecate_kwarg
-from transformers.utils.generic import check_model_inputs
+from transformers.utils.generic import merge_with_config_defaults
+from transformers.utils.output_capturing import capture_outputs

 # [MODIFIED] Import from transformers library
 from specforge.distributed import get_tp_group, shard_tensor
@@ -431,7 +432,8 @@ def __init__(self, config: Llama4TextConfig):
         self.post_init()

     @can_return_tuple
-    @check_model_inputs
+    @merge_with_config_defaults
+    @capture_outputs
     @auto_docstring
     def forward(
         self,

specforge/modeling/target/custom_backend/phi3.py

Lines changed: 4 additions & 2 deletions
@@ -43,7 +43,8 @@
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
 from transformers.utils.deprecation import deprecate_kwarg
-from transformers.utils.generic import check_model_inputs
+from transformers.utils.generic import merge_with_config_defaults
+from transformers.utils.output_capturing import capture_outputs

 from specforge.distributed import get_tp_group
 from specforge.layers import (
@@ -284,7 +285,8 @@ def __init__(self, config: Phi3Config):
         # Initialize weights and apply final processing
         self.post_init()

-    @check_model_inputs
+    @merge_with_config_defaults
+    @capture_outputs
     @auto_docstring
     def forward(
         self,

specforge/modeling/target/sglang_backend/patch.py

Lines changed: 0 additions & 2 deletions
@@ -140,7 +140,6 @@ def initialize_model_parallel(
             "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
         ),
         group_name="tp",
-        pynccl_use_current_stream=duplicate_tp_group,
     )

     if duplicate_tp_group:
@@ -156,7 +155,6 @@
             "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
         ),
         group_name="pdmux_prefill_tp",
-        pynccl_use_current_stream=True,
     )
     # NOTE: Check pynccl_comm exists before accessing it (may be None in sglang 0.5.9)
     if parallel_state._TP.pynccl_comm is not None:
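
The NOTE in the hunk's context is the companion to the removals above: with pynccl_use_current_stream dropped from the group constructor, any later touch of the communicator has to tolerate its absence. A self-contained toy of that guard (SimpleNamespace stands in for sglang's TP group object; the attribute name comes from the diff):

    from types import SimpleNamespace

    # On sglang 0.5.9 the TP group could be created without a pynccl
    # communicator, so pynccl_comm may be None.
    tp_group = SimpleNamespace(pynccl_comm=None)

    if tp_group.pynccl_comm is not None:
        # Only configure the communicator when it actually exists.
        print("configuring pynccl communicator")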
