From bef284cd640172fb06d7499200cae00bc5bccdc7 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 20 Apr 2026 18:55:16 -0700
Subject: [PATCH 01/10] enable int4wo tests on XPU

Signed-off-by: jiqing-feng
---
 tests/models/testing_utils/quantization.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 4403cacc6966..63eef8b381bb 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -832,10 +832,6 @@ def _verify_if_layer_quantized(self, name, module, config_kwargs):
         assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
 
 
-# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
-_int4wo_skip = pytest.mark.skipif(torch_device != "cuda", reason="int4wo quantization requires CUDA")
-
-
 @is_torchao
 @require_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
@@ -861,7 +857,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo"
             "int8wo",
             "int8dq",
         ],
@@ -873,7 +869,7 @@ def test_torchao_quantization_num_parameters(self, quant_type):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo"
             "int8wo",
             "int8dq",
         ],
@@ -888,7 +884,7 @@ def test_torchao_quantization_memory_footprint(self, quant_type):
     @pytest.mark.parametrize(
         "quant_type",
        [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo"
             "int8wo",
             "int8dq",
         ],

From ca507a8cf5226029cb75f813a6ec59ce41f1653e Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 20 Apr 2026 18:56:58 -0700
Subject: [PATCH 02/10] fix typo

Signed-off-by: jiqing-feng
---
 tests/models/testing_utils/quantization.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 63eef8b381bb..d9b07e3347aa 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -857,7 +857,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            "int4wo"
+            "int4wo",
             "int8wo",
             "int8dq",
         ],
@@ -869,7 +869,7 @@ def test_torchao_quantization_num_parameters(self, quant_type):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            "int4wo"
+            "int4wo",
             "int8wo",
             "int8dq",
         ],
@@ -884,7 +884,7 @@ def test_torchao_quantization_memory_footprint(self, quant_type):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            "int4wo"
+            "int4wo",
             "int8wo",
             "int8dq",
         ],

From c51708e3eafaed1ea194d6373388fd2c5c9fabd5 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 21 Apr 2026 12:45:57 +0800
Subject: [PATCH 03/10] fix input dtype

Signed-off-by: jiqing-feng
---
 tests/models/testing_utils/quantization.py | 50 ++++++++++++++--------
 1 file changed, 33 insertions(+), 17 deletions(-)

diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index d9b07e3347aa..7ef6c7454313 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -142,6 +142,14 @@ def _is_module_quantized(self, module):
         except (AssertionError, AttributeError):
             return False
 
+    def _get_dummy_inputs_for_model(self, model):
+        inputs = self.get_dummy_inputs()
+        model_dtype = next(model.parameters()).dtype
+        return {
+            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
+            for k, v in inputs.items()
+        }
+
     def _load_unquantized_model(self):
         kwargs = getattr(self, "pretrained_model_kwargs", {})
         return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
@@ -174,7 +182,7 @@ def _test_quantization_inference(self, config_kwargs):
         model_quantized = self._create_quantized_model(config_kwargs)
         model_quantized.to(torch_device)
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_quantized)
         output = model_quantized(**inputs, return_dict=False)[0]
 
         assert output is not None, "Model output is None"
@@ -222,7 +230,8 @@ def _test_quantization_lora_inference(self, config_kwargs):
         # Move LoRA adapter weights to device (they default to CPU)
         model.to(torch_device)
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
 
         assert output is not None, "Model output is None with LoRA"
@@ -236,7 +245,8 @@ def _test_quantization_serialization(self, config_kwargs, tmp_path):
 
         model_loaded = self.model_class.from_pretrained(str(tmp_path))
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_loaded)
+
         output = model_loaded(**inputs, return_dict=False)[0]
 
         assert not torch.isnan(output).any(), "Loaded model output contains NaN"
@@ -334,7 +344,8 @@ def _test_quantization_device_map(self, config_kwargs):
         assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
         assert model.hf_device_map is not None, "hf_device_map should not be None"
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None"
         assert not torch.isnan(output).any(), "Model output contains NaN"
@@ -360,14 +371,7 @@ def _test_dequantize(self, config_kwargs):
             assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
 
         # Get model dtype from first parameter
-        model_dtype = next(model.parameters()).dtype
-
-        inputs = self.get_dummy_inputs()
-        # Cast inputs to model dtype
-        inputs = {
-            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
-            for k, v in inputs.items()
-        }
+        inputs = self._get_dummy_inputs_for_model(model)
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None after dequantization"
         assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
@@ -413,7 +417,7 @@ def _test_quantization_training(self, config_kwargs):
             pytest.skip("No attention layers found in model for adapter training test")
 
         # Step 3: run forward and backward pass
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
 
         with torch.amp.autocast(torch_device, dtype=torch.float16):
             out = model(**inputs, return_dict=False)[0]
@@ -597,7 +601,8 @@ def test_bnb_keep_modules_in_fp32(self):
                         f"Module {name} should be uint8 but is {module.weight.dtype}"
                     )
 
-            inputs = self.get_dummy_inputs()
+            inputs = self._get_dummy_inputs_for_model(model)
+
             _ = model(**inputs)
         finally:
             if original_fp32_modules is not None:
@@ -911,7 +916,8 @@ def test_torchao_quantization_serialization(self, quant_type, tmp_path):
 
         model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device))
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_loaded)
+
         output = model_loaded(**inputs, return_dict=False)[0]
         assert not torch.isnan(output).any(), "Loaded model output contains NaN"
 
@@ -1168,6 +1174,14 @@ class QuantizationCompileTesterMixin:
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
     """
 
+    def _get_dummy_inputs_for_model(self, model):
+        inputs = self.get_dummy_inputs()
+        model_dtype = next(model.parameters()).dtype
+        return {
+            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
+            for k, v in inputs.items()
+        }
+
     def setup_method(self):
         gc.collect()
         backend_empty_cache(torch_device)
@@ -1193,7 +1207,8 @@ def _test_torch_compile(self, config_kwargs):
         model = torch.compile(model, fullgraph=True)
 
         with torch._dynamo.config.patch(error_on_recompile=True):
-            inputs = self.get_dummy_inputs()
+            inputs = self._get_dummy_inputs_for_model(model)
+
             output = model(**inputs, return_dict=False)[0]
             assert output is not None, "Model output is None"
             assert not torch.isnan(output).any(), "Model output contains NaN"
@@ -1224,7 +1239,8 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False
         model.enable_group_offload(**group_offload_kwargs)
         model = torch.compile(model)
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None"
         assert not torch.isnan(output).any(), "Model output contains NaN"

From 6df4b31b950c10f998031cb71eae49a5f8bd95e4 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 21 Apr 2026 13:03:02 +0800
Subject: [PATCH 04/10] fix int4 config for xpu

Signed-off-by: jiqing-feng
---
 tests/models/testing_utils/quantization.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 7ef6c7454313..f86e29a89c39 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -823,6 +823,9 @@ class TorchAoConfigMixin:
     @staticmethod
     def _get_quant_config(config_name):
         config_cls = getattr(_torchao_quantization, config_name)
+        # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU
+        if torch_device == "xpu":
+            config_cls.int4_packing_format = "plain_int32"
         return TorchAoConfig(config_cls())
 
     def _create_quantized_model(self, config_name, **extra_kwargs):

From 8a9013d52c9490db0cff8813a46530a9594cf537 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 21 Apr 2026 13:11:43 +0800
Subject: [PATCH 05/10] fix format

Signed-off-by: jiqing-feng
---
 tests/models/testing_utils/quantization.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index f86e29a89c39..50e3fc34e184 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -825,7 +825,8 @@ def _get_quant_config(config_name):
         config_cls = getattr(_torchao_quantization, config_name)
         # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU
         if torch_device == "xpu":
-            config_cls.int4_packing_format = "plain_int32"
+            return TorchAoConfig(config_cls(int4_packing_format="plain_int32"))
+
         return TorchAoConfig(config_cls())
 
     def _create_quantized_model(self, config_name, **extra_kwargs):

From 81e701516ccd336f425e3ccd3806872b4988463d Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 21 Apr 2026 13:16:25 +0800
Subject: [PATCH 06/10] only int4wo need specific format

Signed-off-by: jiqing-feng
---
 tests/models/testing_utils/quantization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 50e3fc34e184..4e30bba34bbf 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -824,7 +824,7 @@ class TorchAoConfigMixin:
     def _get_quant_config(config_name):
         config_cls = getattr(_torchao_quantization, config_name)
         # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU
-        if torch_device == "xpu":
+        if config_name == "int4wo" and torch_device == "xpu":
             return TorchAoConfig(config_cls(int4_packing_format="plain_int32"))
 
         return TorchAoConfig(config_cls())

From 4e4e759cbed2e97b25a9248a45ae0e70966eca3d Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 21 Apr 2026 13:19:16 +0800
Subject: [PATCH 07/10] fix config name

Signed-off-by: jiqing-feng
---
 tests/models/testing_utils/quantization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 4e30bba34bbf..9ec154914135 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -824,7 +824,7 @@ class TorchAoConfigMixin:
     def _get_quant_config(config_name):
         config_cls = getattr(_torchao_quantization, config_name)
         # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU
-        if config_name == "int4wo" and torch_device == "xpu":
+        if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu":
             return TorchAoConfig(config_cls(int4_packing_format="plain_int32"))
 
         return TorchAoConfig(config_cls())

From 8180979e182bdfe2f83244b526cdbc1f9106f3d3 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 21 Apr 2026 14:49:51 +0800
Subject: [PATCH 08/10] fix dequantize and training

Signed-off-by: jiqing-feng
---
 .../quantizers/torchao/torchao_quantizer.py | 14 ++++++++++++++
 tests/models/testing_utils/quantization.py  |  8 +++++++-
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
index 3a20dca88ecf..59387e41654e 100644
--- a/src/diffusers/quantizers/torchao/torchao_quantizer.py
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -376,3 +376,17 @@ def is_trainable(self):
     @property
     def is_compileable(self) -> bool:
         return True
+
+    def _dequantize(self, model):
+        from torchao.utils import TorchAOBaseTensor
+
+        for name, module in model.named_modules():
+            if isinstance(module, nn.Linear) and isinstance(module.weight, TorchAOBaseTensor):
+                device = module.weight.device
+                dequantized_weight = module.weight.dequantize().to(device)
+                module.weight = nn.Parameter(dequantized_weight)
+                # Reset extra_repr if it was overridden
+                if hasattr(module.extra_repr, "__func__") and module.extra_repr.__func__ is not nn.Linear.extra_repr:
+                    module.extra_repr = types.MethodType(nn.Linear.extra_repr, module)
+
+        return model
diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 9ec154914135..6172a22ad02a 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -419,7 +419,8 @@ def _test_quantization_training(self, config_kwargs):
         # Step 3: run forward and backward pass
         inputs = self._get_dummy_inputs_for_model(model)
 
-        with torch.amp.autocast(torch_device, dtype=torch.float16):
+        # Use bfloat16 instead of float16 to avoid gradient underflow with quantized layers
+        with torch.amp.autocast(torch_device, dtype=torch.bfloat16):
             out = model(**inputs, return_dict=False)[0]
 
         out.norm().backward()
@@ -838,7 +839,12 @@ def _create_quantized_model(self, config_name, **extra_kwargs):
         return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
 
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
+        from torchao.utils import TorchAOBaseTensor
+
         assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
+        assert isinstance(module.weight, TorchAOBaseTensor), (
+            f"Layer {name} weight is {type(module.weight)}, expected TorchAOBaseTensor"
+        )
 
 
 @is_torchao

From d210d4a0953092416060a74e07e010543e3c7871 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 21 Apr 2026 16:45:18 +0800
Subject: [PATCH 09/10] fix test size

Signed-off-by: jiqing-feng
---
 .../transformers/test_models_transformer_wan_animate.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py
index ac0ef0698c63..67cb28d1728f 100644
--- a/tests/models/transformers/test_models_transformer_wan_animate.py
+++ b/tests/models/transformers/test_models_transformer_wan_animate.py
@@ -219,7 +219,7 @@ def get_dummy_inputs(self):
         """Override to provide inputs matching the tiny Wan Animate model dimensions."""
         return {
             "hidden_states": randn_tensor(
-                (1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 36, 5, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "encoder_hidden_states": randn_tensor(
                 (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
@@ -228,10 +228,10 @@ def get_dummy_inputs(self):
                 (1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "pose_hidden_states": randn_tensor(
-                (1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 16, 4, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "face_pixel_values": randn_tensor(
-                (1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 3, 5, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
         }

From 0ba8682067803af0cf6394ae85f41bf1f3a7bf5d Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 21 Apr 2026 16:50:38 +0800
Subject: [PATCH 10/10] fix size

Signed-off-by: jiqing-feng
---
 .../models/transformers/test_models_transformer_wan_animate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py
index 67cb28d1728f..569e3507825e 100644
--- a/tests/models/transformers/test_models_transformer_wan_animate.py
+++ b/tests/models/transformers/test_models_transformer_wan_animate.py
@@ -231,7 +231,7 @@ def get_dummy_inputs(self):
                 (1, 16, 4, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "face_pixel_values": randn_tensor(
-                (1, 3, 5, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 3, 13, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
         }
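
Note (illustrative sketch, not part of the series): the net effect of PATCHes 04-07 on
TorchAoConfigMixin._get_quant_config is to pick the int4 packing format per device. The
standalone sketch below assumes torchao >= 0.7.0, where Int4WeightOnlyConfig accepts
int4_packing_format as PATCH 05 relies on, and diffusers' TorchAoConfig; the function
name and the torch_device parameter are illustrative, not part of the diffs above.

    from torchao import quantization as _torchao_quantization

    from diffusers import TorchAoConfig


    def get_quant_config(config_name: str, torch_device: str) -> TorchAoConfig:
        # Resolve the torchao config class by name, e.g. "Int4WeightOnlyConfig".
        config_cls = getattr(_torchao_quantization, config_name)
        # TorchAO int4 quantization requires the plain_int32 packing format on
        # Intel XPU; the CUDA default packing relies on CUDA-specific ops such as
        # _convert_weight_to_int4pack (per the comment removed in PATCH 01).
        if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu":
            return TorchAoConfig(config_cls(int4_packing_format="plain_int32"))
        return TorchAoConfig(config_cls())


    # e.g. quantization_config = get_quant_config("Int4WeightOnlyConfig", "xpu")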