diff --git a/specforge/args.py b/specforge/args.py
index fd6de14c..09ab44b9 100644
--- a/specforge/args.py
+++ b/specforge/args.py
@@ -5,6 +5,32 @@
 from sglang.srt.server_args import ATTENTION_BACKEND_CHOICES
 
 
+def adapt_sglang_server_args_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
+    """Adapt piecewise cuda graph kwargs for sglang version compatibility.
+
+    New sglang (post-0.5.9) uses 'enforce_piecewise_cuda_graph' (piecewise cuda
+    graph is enabled by default; this flag forces it on even when auto-disabled).
+    Old sglang (<=0.5.9) uses 'enable_piecewise_cuda_graph' (disabled by default).
+
+    This function translates between the two based on the installed sglang version.
+    """
+    from sglang.srt.server_args import ServerArgs
+
+    has_enforce = hasattr(ServerArgs, "enforce_piecewise_cuda_graph")
+    has_enable = hasattr(ServerArgs, "enable_piecewise_cuda_graph")
+
+    if "enforce_piecewise_cuda_graph" in kwargs and not has_enforce and has_enable:
+        kwargs["enable_piecewise_cuda_graph"] = kwargs.pop(
+            "enforce_piecewise_cuda_graph"
+        )
+    elif "enable_piecewise_cuda_graph" in kwargs and not has_enable and has_enforce:
+        kwargs["enforce_piecewise_cuda_graph"] = kwargs.pop(
+            "enable_piecewise_cuda_graph"
+        )
+
+    return kwargs
+
+
 @dataclass
 class TrackerArgs:
     report_to: str = "none"
@@ -188,7 +214,7 @@ def from_args(args: argparse.Namespace) -> "SGLangBackendArgs":
         )
 
     def to_kwargs(self) -> Dict[str, Any]:
-        return dict(
+        kwargs = dict(
             attention_backend=self.sglang_attention_backend,
             mem_fraction_static=self.sglang_mem_fraction_static,
             context_length=self.sglang_context_length,
@@ -204,3 +230,4 @@ def to_kwargs(self) -> Dict[str, Any]:
             max_running_requests=self.sglang_max_running_requests,
             max_total_tokens=self.sglang_max_total_tokens,
         )
+        return adapt_sglang_server_args_kwargs(kwargs)
diff --git a/specforge/modeling/target/dflash_target_model.py b/specforge/modeling/target/dflash_target_model.py
index 0df93823..dc133c74 100644
--- a/specforge/modeling/target/dflash_target_model.py
+++ b/specforge/modeling/target/dflash_target_model.py
@@ -17,6 +17,7 @@
 from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather
 from transformers import AutoModelForCausalLM
 
+from specforge.args import adapt_sglang_server_args_kwargs
 from specforge.distributed import get_tp_group
 
 from .sglang_backend import SGLangRunner
@@ -80,6 +81,7 @@ def from_pretrained(
         **kwargs,
     ) -> "SGLangDFlashTargetModel":
         tp_size = dist.get_world_size(get_tp_group())
+        kwargs = adapt_sglang_server_args_kwargs(kwargs)
         server_args = ServerArgs(
             model_path=pretrained_model_name_or_path,
             trust_remote_code=trust_remote_code,
diff --git a/specforge/modeling/target/eagle3_target_model.py b/specforge/modeling/target/eagle3_target_model.py
index 2acf50ba..0571b333 100644
--- a/specforge/modeling/target/eagle3_target_model.py
+++ b/specforge/modeling/target/eagle3_target_model.py
@@ -32,6 +32,7 @@
 from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather
 from transformers import AutoModelForCausalLM
 
+from specforge.args import adapt_sglang_server_args_kwargs
 from specforge.distributed import get_tp_device_mesh, get_tp_group
 from specforge.utils import padding
 
@@ -307,6 +308,7 @@ def from_pretrained(
         # NOTE: sglang 0.5.9 requires dtype to be non-None
         # If torch_dtype is None, use "auto" to let sglang decide the dtype
         dtype_arg = torch_dtype if torch_dtype is not None else "auto"
+        kwargs = adapt_sglang_server_args_kwargs(kwargs)
         server_args = ServerArgs(
             model_path=pretrained_model_name_or_path,
             trust_remote_code=trust_remote_code,
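
For reference, a minimal runnable sketch of what the adapter does. `_OldServerArgs` and `_adapt` are hypothetical stand-ins (not part of the patch) so the example runs without sglang installed; the real `adapt_sglang_server_args_kwargs` probes the installed `ServerArgs` class the same way.

from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class _OldServerArgs:
    # Hypothetical stand-in emulating old (<=0.5.9) sglang: only the
    # 'enable_' flag exists, and it is off by default.
    enable_piecewise_cuda_graph: bool = False


def _adapt(kwargs: Dict[str, Any], server_args_cls) -> Dict[str, Any]:
    # Same key-renaming logic as adapt_sglang_server_args_kwargs, with the
    # ServerArgs class injected so this sketch is self-contained.
    has_enforce = hasattr(server_args_cls, "enforce_piecewise_cuda_graph")
    has_enable = hasattr(server_args_cls, "enable_piecewise_cuda_graph")
    if "enforce_piecewise_cuda_graph" in kwargs and not has_enforce and has_enable:
        kwargs["enable_piecewise_cuda_graph"] = kwargs.pop(
            "enforce_piecewise_cuda_graph"
        )
    elif "enable_piecewise_cuda_graph" in kwargs and not has_enable and has_enforce:
        kwargs["enforce_piecewise_cuda_graph"] = kwargs.pop(
            "enable_piecewise_cuda_graph"
        )
    return kwargs


# A caller written against new sglang passes the new-style key; on an old
# install it is renamed in place before ServerArgs(**kwargs) is constructed.
print(_adapt({"enforce_piecewise_cuda_graph": True}, _OldServerArgs))
# -> {'enable_piecewise_cuda_graph': True}

Probing the class with hasattr, rather than parsing a version string, keeps the translation correct on any build that defines one flag or the other, and makes it a no-op when the installed ServerArgs already accepts the key the caller passed.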