Commit dcd404b

[Fix]: guard LigerQwen3_5CausalLMOutputWithPast import (#1169)
LigerQwen3_5CausalLMOutputWithPast is only defined in output_classes.py when the installed transformers version includes the Qwen3.5 model. In older transformers versions, the base class import fails silently and the Liger subclass is never defined, causing an ImportError when qwen3_5.py unconditionally imports it at module level.

This breaks CI (Ascend) environments with a transformers version that does not yet ship Qwen3.5, as test_monkey_patch.py imports qwen3_5.py at collection time, causing the entire test run to abort with:

```
ImportError: cannot import name 'LigerQwen3_5CausalLMOutputWithPast' from 'liger_kernel.transformers.model.output_classes'
```

Fix by wrapping the import in a try/except block (falling back to None) and converting the return type annotation to a string literal to avoid evaluation at function definition time.

- Hardware Type: Atlas 800I A2
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent 3978fe2 commit dcd404b

1 file changed

+6
-2
lines changed
src/liger_kernel/transformers/model/qwen3_5.py

Lines changed: 6 additions & 2 deletions
@@ -7,7 +7,11 @@
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
 from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
-from liger_kernel.transformers.model.output_classes import LigerQwen3_5CausalLMOutputWithPast
+
+try:
+    from liger_kernel.transformers.model.output_classes import LigerQwen3_5CausalLMOutputWithPast
+except ImportError:
+    LigerQwen3_5CausalLMOutputWithPast = None
 
 
 def lce_forward(
@@ -138,7 +142,7 @@ def lce_forward_for_multimodal(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[tuple, LigerQwen3_5CausalLMOutputWithPast]:
+) -> Union[tuple, "LigerQwen3_5CausalLMOutputWithPast"]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
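
The guarded-import pattern in this diff can be tried in isolation. The sketch below is not code from the repository: `HypotheticalQwenOutput` and `forward` are made-up names, and the standard-library `types` module stands in for `output_classes.py` as a module that exists but does not define the requested name. It only illustrates that a failed `from ... import name` raises `ImportError`, that the fallback leaves the name bound to `None`, and that the module still imports and its functions remain callable.

```python
from typing import Union

# Hypothetical name: `types` is a real module, but it does not define
# HypotheticalQwenOutput, so this raises ImportError ("cannot import name ...").
try:
    from types import HypotheticalQwenOutput
except ImportError:
    # Fallback so the rest of the module can still be imported.
    HypotheticalQwenOutput = None


# The string annotation is kept as a forward reference, so the optional
# class does not have to exist for the signature to be well-formed.
def forward(x: int) -> Union[tuple, "HypotheticalQwenOutput"]:
    if HypotheticalQwenOutput is None:
        # Optional class unavailable: return the plain tuple form instead.
        return (x,)
    return HypotheticalQwenOutput(x)


print(HypotheticalQwenOutput)  # None
print(forward(3))              # (3,)
```

Callers that need the optional class should check it against `None` before use, as `forward` does here.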