Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion specforge/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,32 @@
from sglang.srt.server_args import ATTENTION_BACKEND_CHOICES


def adapt_sglang_server_args_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Adapt piecewise cuda graph kwargs for sglang version compatibility.

New sglang (post-0.5.9) uses 'enforce_piecewise_cuda_graph' (piecewise cuda graph
is enabled by default, this flag forces it on even when auto-disabled).
Old sglang (<=0.5.9) uses 'enable_piecewise_cuda_graph' (disabled by default).

This function translates between the two based on the installed sglang version.
"""
from sglang.srt.server_args import ServerArgs

has_enforce = hasattr(ServerArgs, "enforce_piecewise_cuda_graph")
has_enable = hasattr(ServerArgs, "enable_piecewise_cuda_graph")

if "enforce_piecewise_cuda_graph" in kwargs and not has_enforce and has_enable:
kwargs["enable_piecewise_cuda_graph"] = kwargs.pop(
"enforce_piecewise_cuda_graph"
)
elif "enable_piecewise_cuda_graph" in kwargs and not has_enable and has_enforce:
kwargs["enforce_piecewise_cuda_graph"] = kwargs.pop(
"enable_piecewise_cuda_graph"
)
Comment on lines +22 to +29
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adapt_sglang_server_args_kwargs can silently overwrite values when both enable_piecewise_cuda_graph and enforce_piecewise_cuda_graph are present in kwargs (e.g., caller passes one manually while another is injected elsewhere). Please handle this explicitly (e.g., define precedence, raise a ValueError, or only convert when the destination key is absent) to avoid surprising behavior.

Suggested change
if "enforce_piecewise_cuda_graph" in kwargs and not has_enforce and has_enable:
kwargs["enable_piecewise_cuda_graph"] = kwargs.pop(
"enforce_piecewise_cuda_graph"
)
elif "enable_piecewise_cuda_graph" in kwargs and not has_enable and has_enforce:
kwargs["enforce_piecewise_cuda_graph"] = kwargs.pop(
"enable_piecewise_cuda_graph"
)
# When only one of these attributes exists on ServerArgs, translate the
# corresponding kwarg name. If both the source and destination kwargs are
# present, avoid silently overwriting and instead either validate or raise.
if "enforce_piecewise_cuda_graph" in kwargs and not has_enforce and has_enable:
# Destination key already present: resolve potential conflict.
if "enable_piecewise_cuda_graph" in kwargs:
src_val = kwargs["enforce_piecewise_cuda_graph"]
dst_val = kwargs["enable_piecewise_cuda_graph"]
if src_val != dst_val:
raise ValueError(
"Both 'enforce_piecewise_cuda_graph' and "
"'enable_piecewise_cuda_graph' were provided with "
f"different values ({src_val!r} vs {dst_val!r}) while "
"only 'enable_piecewise_cuda_graph' is supported by the "
"installed sglang version. Please specify only one or "
"ensure they have the same value."
)
# Values are equal: drop the unsupported alias to avoid confusion.
kwargs.pop("enforce_piecewise_cuda_graph")
else:
# Only the unsupported name is present: translate it.
kwargs["enable_piecewise_cuda_graph"] = kwargs.pop(
"enforce_piecewise_cuda_graph"
)
elif "enable_piecewise_cuda_graph" in kwargs and not has_enable and has_enforce:
if "enforce_piecewise_cuda_graph" in kwargs:
src_val = kwargs["enable_piecewise_cuda_graph"]
dst_val = kwargs["enforce_piecewise_cuda_graph"]
if src_val != dst_val:
raise ValueError(
"Both 'enable_piecewise_cuda_graph' and "
"'enforce_piecewise_cuda_graph' were provided with "
f"different values ({src_val!r} vs {dst_val!r}) while "
"only 'enforce_piecewise_cuda_graph' is supported by the "
"installed sglang version. Please specify only one or "
"ensure they have the same value."
)
kwargs.pop("enable_piecewise_cuda_graph")
else:
kwargs["enforce_piecewise_cuda_graph"] = kwargs.pop(
"enable_piecewise_cuda_graph"
)

Copilot uses AI. Check for mistakes.

return kwargs
Comment on lines +8 to +31
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function modifies the kwargs dictionary in-place, which can lead to unexpected side effects for the caller, especially when **kwargs is passed to a function. It's a better practice to work on a copy of the dictionary to avoid side effects.

I suggest creating a copy of kwargs at the beginning of the function. This makes the function pure with respect to its inputs.

def adapt_sglang_server_args_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Adapt piecewise cuda graph kwargs for sglang version compatibility.

    New sglang (post-0.5.9) uses 'enforce_piecewise_cuda_graph' (piecewise cuda graph
    is enabled by default, this flag forces it on even when auto-disabled).
    Old sglang (<=0.5.9) uses 'enable_piecewise_cuda_graph' (disabled by default).

    This function translates between the two names based on which attribute the
    installed ``ServerArgs`` class exposes. Translation only happens when exactly
    one of the two names is supported; if ``ServerArgs`` exposes both or neither,
    the kwargs are returned unchanged.

    Args:
        kwargs: Keyword arguments intended for ``ServerArgs``. The mapping is
            never modified; a shallow copy is adapted and returned instead.

    Returns:
        A new dict in which the unsupported flag name, if present, has been
        replaced by the supported one.

    Raises:
        ValueError: If both flag names are present with different values, so
            renaming one would silently discard the other.
    """
    # Imported lazily so the attribute check runs against the installed sglang.
    from sglang.srt.server_args import ServerArgs

    enforce_key = "enforce_piecewise_cuda_graph"
    enable_key = "enable_piecewise_cuda_graph"
    has_enforce = hasattr(ServerArgs, enforce_key)
    has_enable = hasattr(ServerArgs, enable_key)

    # Work on a copy so the caller's dict (often a **kwargs mapping) is untouched.
    adapted = kwargs.copy()

    # Both or neither name supported: there is no unambiguous translation.
    if has_enforce == has_enable:
        return adapted

    if has_enforce:
        supported_key, unsupported_key = enforce_key, enable_key
    else:
        supported_key, unsupported_key = enable_key, enforce_key

    if unsupported_key not in adapted:
        return adapted

    alias_value = adapted.pop(unsupported_key)
    if supported_key not in adapted:
        adapted[supported_key] = alias_value
    elif adapted[supported_key] != alias_value:
        # Refuse to pick a winner between two conflicting settings.
        raise ValueError(
            f"Both '{unsupported_key}' and '{supported_key}' were provided with "
            f"different values ({alias_value!r} vs {adapted[supported_key]!r}) "
            f"while only '{supported_key}' is supported by the installed sglang "
            "version. Please specify only one or ensure they have the same value."
        )
    # Otherwise both names agree: the unsupported alias has simply been dropped.

    return adapted



@dataclass
class TrackerArgs:
report_to: str = "none"
Expand Down Expand Up @@ -188,7 +214,7 @@ def from_args(args: argparse.Namespace) -> "SGLangBackendArgs":
)

def to_kwargs(self) -> Dict[str, Any]:
return dict(
kwargs = dict(
attention_backend=self.sglang_attention_backend,
mem_fraction_static=self.sglang_mem_fraction_static,
context_length=self.sglang_context_length,
Expand All @@ -204,3 +230,4 @@ def to_kwargs(self) -> Dict[str, Any]:
max_running_requests=self.sglang_max_running_requests,
max_total_tokens=self.sglang_max_total_tokens,
)
return adapt_sglang_server_args_kwargs(kwargs)
2 changes: 2 additions & 0 deletions specforge/modeling/target/dflash_target_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather
from transformers import AutoModelForCausalLM

from specforge.args import adapt_sglang_server_args_kwargs
from specforge.distributed import get_tp_group

from .sglang_backend import SGLangRunner
Expand Down Expand Up @@ -80,6 +81,7 @@ def from_pretrained(
**kwargs,
) -> "SGLangDFlashTargetModel":
tp_size = dist.get_world_size(get_tp_group())
kwargs = adapt_sglang_server_args_kwargs(kwargs)
server_args = ServerArgs(
model_path=pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
Expand Down
2 changes: 2 additions & 0 deletions specforge/modeling/target/eagle3_target_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather
from transformers import AutoModelForCausalLM

from specforge.args import adapt_sglang_server_args_kwargs
from specforge.distributed import get_tp_device_mesh, get_tp_group
from specforge.utils import padding

Expand Down Expand Up @@ -307,6 +308,7 @@ def from_pretrained(
# NOTE: sglang 0.5.9 requires dtype to be non-None
# If torch_dtype is None, use "auto" to let sglang decide the dtype
dtype_arg = torch_dtype if torch_dtype is not None else "auto"
kwargs = adapt_sglang_server_args_kwargs(kwargs)
server_args = ServerArgs(
model_path=pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
Expand Down
Loading