Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,17 @@ def is_trainable(self):
@property
def is_compileable(self) -> bool:
    """Whether models quantized by this quantizer can be run through ``torch.compile``.

    Always ``True``: this quantizer places no restriction on compilation.
    """
    return True

def _dequantize(self, model):
    """Convert every torchao-quantized ``nn.Linear`` weight in *model* back to a
    plain ``nn.Parameter`` and return the model.

    Modules whose ``extra_repr`` was patched on the instance (e.g. by the
    quantization hooks) get the stock ``nn.Linear.extra_repr`` restored so the
    repr no longer advertises quantization.
    """
    from torchao.utils import TorchAOBaseTensor

    for _, submodule in model.named_modules():
        if not isinstance(submodule, nn.Linear):
            continue
        if not isinstance(submodule.weight, TorchAOBaseTensor):
            continue

        target_device = submodule.weight.device
        plain_weight = submodule.weight.dequantize().to(target_device)
        submodule.weight = nn.Parameter(plain_weight)

        # Restore the default extra_repr if a bound override was installed.
        override = getattr(submodule.extra_repr, "__func__", None)
        if override is not None and override is not nn.Linear.extra_repr:
            submodule.extra_repr = types.MethodType(nn.Linear.extra_repr, submodule)

    return model
72 changes: 47 additions & 25 deletions tests/models/testing_utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ def _is_module_quantized(self, module):
except (AssertionError, AttributeError):
return False

def _get_dummy_inputs_for_model(self, model):
inputs = self.get_dummy_inputs()
model_dtype = next(model.parameters()).dtype
return {
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs.items()
}

def _load_unquantized_model(self):
kwargs = getattr(self, "pretrained_model_kwargs", {})
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
Expand Down Expand Up @@ -174,7 +182,7 @@ def _test_quantization_inference(self, config_kwargs):
model_quantized = self._create_quantized_model(config_kwargs)
model_quantized.to(torch_device)

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model_quantized)
output = model_quantized(**inputs, return_dict=False)[0]

assert output is not None, "Model output is None"
Expand Down Expand Up @@ -222,7 +230,8 @@ def _test_quantization_lora_inference(self, config_kwargs):
# Move LoRA adapter weights to device (they default to CPU)
model.to(torch_device)

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]

assert output is not None, "Model output is None with LoRA"
Expand All @@ -236,7 +245,8 @@ def _test_quantization_serialization(self, config_kwargs, tmp_path):

model_loaded = self.model_class.from_pretrained(str(tmp_path))

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model_loaded)

output = model_loaded(**inputs, return_dict=False)[0]
assert not torch.isnan(output).any(), "Loaded model output contains NaN"

Expand Down Expand Up @@ -334,7 +344,8 @@ def _test_quantization_device_map(self, config_kwargs):
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
assert model.hf_device_map is not None, "hf_device_map should not be None"

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
Expand All @@ -360,14 +371,7 @@ def _test_dequantize(self, config_kwargs):
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"

# Get model dtype from first parameter
model_dtype = next(model.parameters()).dtype

inputs = self.get_dummy_inputs()
# Cast inputs to model dtype
inputs = {
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs.items()
}
inputs = self._get_dummy_inputs_for_model(model)
output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None after dequantization"
assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
Expand Down Expand Up @@ -413,9 +417,10 @@ def _test_quantization_training(self, config_kwargs):
pytest.skip("No attention layers found in model for adapter training test")

# Step 3: run forward and backward pass
inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

with torch.amp.autocast(torch_device, dtype=torch.float16):
# Use bfloat16 instead of float16 to avoid gradient underflow with quantized layers
with torch.amp.autocast(torch_device, dtype=torch.bfloat16):
out = model(**inputs, return_dict=False)[0]
out.norm().backward()

Expand Down Expand Up @@ -597,7 +602,8 @@ def test_bnb_keep_modules_in_fp32(self):
f"Module {name} should be uint8 but is {module.weight.dtype}"
)

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

_ = model(**inputs)
finally:
if original_fp32_modules is not None:
Expand Down Expand Up @@ -818,6 +824,10 @@ class TorchAoConfigMixin:
@staticmethod
def _get_quant_config(config_name):
    """Instantiate the named torchao quantization config, wrapped in ``TorchAoConfig``.

    Int4 weight-only quantization on Intel XPU requires the ``plain_int32``
    packing format, so that combination is special-cased.
    """
    quant_cls = getattr(_torchao_quantization, config_name)
    if torch_device == "xpu" and config_name == "Int4WeightOnlyConfig":
        return TorchAoConfig(quant_cls(int4_packing_format="plain_int32"))
    return TorchAoConfig(quant_cls())

def _create_quantized_model(self, config_name, **extra_kwargs):
Expand All @@ -829,11 +839,12 @@ def _create_quantized_model(self, config_name, **extra_kwargs):
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

def _verify_if_layer_quantized(self, name, module, config_kwargs):
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"

from torchao.utils import TorchAOBaseTensor

# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
_int4wo_skip = pytest.mark.skipif(torch_device != "cuda", reason="int4wo quantization requires CUDA")
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
assert isinstance(module.weight, TorchAOBaseTensor), (
f"Layer {name} weight is {type(module.weight)}, expected TorchAOBaseTensor"
)


@is_torchao
Expand Down Expand Up @@ -861,7 +872,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
@pytest.mark.parametrize(
"quant_type",
[
pytest.param("int4wo", marks=_int4wo_skip),
"int4wo",
"int8wo",
"int8dq",
],
Expand All @@ -873,7 +884,7 @@ def test_torchao_quantization_num_parameters(self, quant_type):
@pytest.mark.parametrize(
"quant_type",
[
pytest.param("int4wo", marks=_int4wo_skip),
"int4wo",
"int8wo",
"int8dq",
],
Expand All @@ -888,7 +899,7 @@ def test_torchao_quantization_memory_footprint(self, quant_type):
@pytest.mark.parametrize(
"quant_type",
[
pytest.param("int4wo", marks=_int4wo_skip),
"int4wo",
"int8wo",
"int8dq",
],
Expand All @@ -915,7 +926,8 @@ def test_torchao_quantization_serialization(self, quant_type, tmp_path):

model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device))

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model_loaded)

output = model_loaded(**inputs, return_dict=False)[0]
assert not torch.isnan(output).any(), "Loaded model output contains NaN"

Expand Down Expand Up @@ -1172,6 +1184,14 @@ class QuantizationCompileTesterMixin:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""

def _get_dummy_inputs_for_model(self, model):
inputs = self.get_dummy_inputs()
model_dtype = next(model.parameters()).dtype
return {
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs.items()
}

def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
Expand All @@ -1197,7 +1217,8 @@ def _test_torch_compile(self, config_kwargs):
model = torch.compile(model, fullgraph=True)

with torch._dynamo.config.patch(error_on_recompile=True):
inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
Expand Down Expand Up @@ -1228,7 +1249,8 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False
model.enable_group_offload(**group_offload_kwargs)
model = torch.compile(model)

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
(1, 36, 5, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
Expand All @@ -228,10 +228,10 @@ def get_dummy_inputs(self):
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
(1, 16, 4, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
(1, 3, 13, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
Expand Down
Loading