@@ -471,7 +471,9 @@ def test_whisper_export_smoke(tmp_path):
471471@pytest .mark .llm_model
472472def test_causal_subfunction_export_smoke (tmp_path ):
473473 model_id = CAUSAL_RUNTIME_MODEL_IDS ["gpt2" ]
474- model_hf = AutoModelForCausalLM .from_pretrained (model_id , ** MODEL_KWARGS , low_cpu_mem_usage = False )
474+ model_hf = AutoModelForCausalLM .from_pretrained (
475+ model_id , ** MODEL_KWARGS , low_cpu_mem_usage = False , torch_dtype = torch .float32
476+ )
475477 model_hf .eval ()
476478 qeff_model = QEFFAutoModelForCausalLM (model_hf )
477479
@@ -499,7 +501,9 @@ def test_causal_subfunction_export_smoke(tmp_path):
499501def test_causal_compile_with_subfunctions_all_models (model_type , model_id , tmp_path ):
500502 del model_type
501503 try :
502- qeff_model = QEFFAutoModelForCausalLM .from_pretrained (model_id , trust_remote_code = True )
504+ qeff_model = QEFFAutoModelForCausalLM .from_pretrained (
505+ model_id , trust_remote_code = True , torch_dtype = torch .float32
506+ )
503507 except Exception as exc :
504508 _skip_on_model_fetch_error (exc , model_id )
505509
@@ -607,6 +611,7 @@ def test_causal_subfunction_and_proxy_export_smoke_gpt2(tmp_path):
607611 model_id ,
608612 trust_remote_code = True ,
609613 enable_proxy = True ,
614+ torch_dtype = torch .float32 ,
610615 )
611616 except Exception as exc :
612617 _skip_on_model_fetch_error (exc , model_id )
@@ -621,7 +626,9 @@ def test_causal_subfunction_and_proxy_export_smoke_gpt2(tmp_path):
621626
622627@pytest .mark .llm_model
623628def test_prefix_caching_continuous_batching_export_and_ort_smoke (tmp_path ):
624- qeff_model = QEFFAutoModelForCausalLM .from_pretrained (PREFIX_CACHING_MODEL_ID , continuous_batching = True )
629+ qeff_model = QEFFAutoModelForCausalLM .from_pretrained (
630+ PREFIX_CACHING_MODEL_ID , continuous_batching = True , torch_dtype = torch .float32
631+ )
625632 onnx_path = _exported_onnx_path (qeff_model .export (tmp_path / "prefix-caching" ))
626633 onnx_model = onnx .load (onnx_path , load_external_data = False )
627634
@@ -638,7 +645,9 @@ def test_prefix_caching_continuous_batching_export_and_ort_smoke(tmp_path):
638645def test_awq_export_smoke (tmp_path ):
639646 replace_transformers_quantizers ()
640647 try :
641- model_hf = AutoModelForCausalLM .from_pretrained (TINY_AWQ_MODEL_ID , low_cpu_mem_usage = False )
648+ model_hf = AutoModelForCausalLM .from_pretrained (
649+ TINY_AWQ_MODEL_ID , low_cpu_mem_usage = False , torch_dtype = torch .float32
650+ )
642651 except Exception as exc :
643652 _skip_on_model_fetch_error (exc , TINY_AWQ_MODEL_ID )
644653 model_hf .eval ()
@@ -655,8 +664,12 @@ def test_awq_export_smoke(tmp_path):
655664def test_proxy_toggle_onnx_transform_policy_for_causal_lm ():
656665 model_id = CAUSAL_RUNTIME_MODEL_IDS ["gpt2" ]
657666 try :
658- qeff_default = QEFFAutoModelForCausalLM .from_pretrained (model_id , trust_remote_code = True )
659- qeff_proxy = QEFFAutoModelForCausalLM .from_pretrained (model_id , trust_remote_code = True , enable_proxy = True )
667+ qeff_default = QEFFAutoModelForCausalLM .from_pretrained (
668+ model_id , trust_remote_code = True , torch_dtype = torch .float32
669+ )
670+ qeff_proxy = QEFFAutoModelForCausalLM .from_pretrained (
671+ model_id , trust_remote_code = True , enable_proxy = True , torch_dtype = torch .float32
672+ )
660673 except Exception as exc :
661674 _skip_on_model_fetch_error (exc , model_id )
662675
0 commit comments