Skip to content

Commit 8e5697b

Browse files
vbaddi and asmigosw
authored and committed
nit: rebase to mainline
Signed-off-by: vbaddi <vbaddi@qti.qualcomm.com>
1 parent aee7e88 commit 8e5697b

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

tests/unit_test/models/test_model_quickcheck.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,9 @@ def test_whisper_export_smoke(tmp_path):
471471
@pytest.mark.llm_model
472472
def test_causal_subfunction_export_smoke(tmp_path):
473473
model_id = CAUSAL_RUNTIME_MODEL_IDS["gpt2"]
474-
model_hf = AutoModelForCausalLM.from_pretrained(model_id, **MODEL_KWARGS, low_cpu_mem_usage=False)
474+
model_hf = AutoModelForCausalLM.from_pretrained(
475+
model_id, **MODEL_KWARGS, low_cpu_mem_usage=False, torch_dtype=torch.float32
476+
)
475477
model_hf.eval()
476478
qeff_model = QEFFAutoModelForCausalLM(model_hf)
477479

@@ -499,7 +501,9 @@ def test_causal_subfunction_export_smoke(tmp_path):
499501
def test_causal_compile_with_subfunctions_all_models(model_type, model_id, tmp_path):
500502
del model_type
501503
try:
502-
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
504+
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
505+
model_id, trust_remote_code=True, torch_dtype=torch.float32
506+
)
503507
except Exception as exc:
504508
_skip_on_model_fetch_error(exc, model_id)
505509

@@ -607,6 +611,7 @@ def test_causal_subfunction_and_proxy_export_smoke_gpt2(tmp_path):
607611
model_id,
608612
trust_remote_code=True,
609613
enable_proxy=True,
614+
torch_dtype=torch.float32,
610615
)
611616
except Exception as exc:
612617
_skip_on_model_fetch_error(exc, model_id)
@@ -621,7 +626,9 @@ def test_causal_subfunction_and_proxy_export_smoke_gpt2(tmp_path):
621626

622627
@pytest.mark.llm_model
623628
def test_prefix_caching_continuous_batching_export_and_ort_smoke(tmp_path):
624-
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(PREFIX_CACHING_MODEL_ID, continuous_batching=True)
629+
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
630+
PREFIX_CACHING_MODEL_ID, continuous_batching=True, torch_dtype=torch.float32
631+
)
625632
onnx_path = _exported_onnx_path(qeff_model.export(tmp_path / "prefix-caching"))
626633
onnx_model = onnx.load(onnx_path, load_external_data=False)
627634

@@ -638,7 +645,9 @@ def test_prefix_caching_continuous_batching_export_and_ort_smoke(tmp_path):
638645
def test_awq_export_smoke(tmp_path):
639646
replace_transformers_quantizers()
640647
try:
641-
model_hf = AutoModelForCausalLM.from_pretrained(TINY_AWQ_MODEL_ID, low_cpu_mem_usage=False)
648+
model_hf = AutoModelForCausalLM.from_pretrained(
649+
TINY_AWQ_MODEL_ID, low_cpu_mem_usage=False, torch_dtype=torch.float32
650+
)
642651
except Exception as exc:
643652
_skip_on_model_fetch_error(exc, TINY_AWQ_MODEL_ID)
644653
model_hf.eval()
@@ -655,8 +664,12 @@ def test_awq_export_smoke(tmp_path):
655664
def test_proxy_toggle_onnx_transform_policy_for_causal_lm():
656665
model_id = CAUSAL_RUNTIME_MODEL_IDS["gpt2"]
657666
try:
658-
qeff_default = QEFFAutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
659-
qeff_proxy = QEFFAutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, enable_proxy=True)
667+
qeff_default = QEFFAutoModelForCausalLM.from_pretrained(
668+
model_id, trust_remote_code=True, torch_dtype=torch.float32
669+
)
670+
qeff_proxy = QEFFAutoModelForCausalLM.from_pretrained(
671+
model_id, trust_remote_code=True, enable_proxy=True, torch_dtype=torch.float32
672+
)
660673
except Exception as exc:
661674
_skip_on_model_fetch_error(exc, model_id)
662675

0 commit comments

Comments (0)