@@ -10,7 +10,8 @@ use modelexpress_common::{
1010 health:: { HealthRequest , HealthResponse , health_service_server:: HealthService } ,
1111 model:: {
1212 FileChunk , ModelDownloadRequest , ModelFileInfo , ModelFileList , ModelFilesRequest ,
13- ModelStatusUpdate , model_service_server:: ModelService ,
13+ ModelProvider as GrpcModelProvider , ModelStatusUpdate ,
14+ model_service_server:: ModelService ,
1415 } ,
1516 } ,
1617 models:: { ModelProvider , ModelStatus } ,
@@ -41,23 +42,6 @@ fn get_server_cache_dir() -> Option<std::path::PathBuf> {
4142 }
4243}
4344
44- /// Convert gRPC provider to internal ModelProvider enum
45- ///
46- /// Falls back to HuggingFace provider if the conversion fails or an invalid
47- /// provider value is provided. A warning is logged when fallback occurs.
48- fn convert_provider ( grpc_provider : i32 ) -> ModelProvider {
49- match modelexpress_common:: grpc:: model:: ModelProvider :: try_from ( grpc_provider) {
50- Ok ( provider) => provider. into ( ) ,
51- Err ( _) => {
52- tracing:: warn!(
53- "Invalid provider value {}, falling back to HuggingFace" ,
54- grpc_provider
55- ) ;
56- ModelProvider :: HuggingFace
57- }
58- }
59- }
60-
6145/// Health service implementation
6246#[ derive( Debug , Default ) ]
6347pub struct HealthServiceImpl ;
@@ -183,7 +167,13 @@ impl ModelService for ModelServiceImpl {
183167 let ( tx, rx) = tokio:: sync:: mpsc:: channel ( 4 ) ;
184168
185169 // Convert gRPC provider to our enum
186- let provider = convert_provider ( model_request. provider ) ;
170+ let grpc_provider = GrpcModelProvider :: try_from ( model_request. provider ) . map_err ( |_| {
171+ Status :: invalid_argument ( format ! (
172+ "Invalid provider value: {}" ,
173+ model_request. provider
174+ ) )
175+ } ) ?;
176+ let provider = ModelProvider :: from ( grpc_provider) ;
187177 let model_name = download:: canonical_model_name ( & model_request. model_name , provider)
188178 . map_err ( |e| Status :: invalid_argument ( e. to_string ( ) ) ) ?;
189179 let ignore_weights = model_request. ignore_weights ;
@@ -202,8 +192,7 @@ impl ModelService for ModelServiceImpl {
202192 Some ( "Previous download failed - retrying" . to_string ( ) )
203193 }
204194 } ,
205- provider : modelexpress_common:: grpc:: model:: ModelProvider :: from ( provider)
206- as i32 ,
195+ provider : grpc_provider as i32 ,
207196 } ;
208197
209198 if tx. send ( Ok ( update) ) . await . is_err ( ) {
@@ -232,7 +221,7 @@ impl ModelService for ModelServiceImpl {
232221 ModelStatus :: ERROR => Some ( "Model download failed" . to_string ( ) ) ,
233222 ModelStatus :: DOWNLOADING => Some ( "Download still in progress" . to_string ( ) ) ,
234223 } ,
235- provider : modelexpress_common :: grpc :: model :: ModelProvider :: from ( provider ) as i32 ,
224+ provider : grpc_provider as i32 ,
236225 } ;
237226
238227 let _ = tx. send ( Ok ( final_update) ) . await ;
@@ -253,7 +242,13 @@ impl ModelService for ModelServiceImpl {
253242 } ;
254243
255244 // Convert gRPC provider to our enum
256- let provider = convert_provider ( files_request. provider ) ;
245+ let grpc_provider = GrpcModelProvider :: try_from ( files_request. provider ) . map_err ( |_| {
246+ Status :: invalid_argument ( format ! (
247+ "Invalid provider value: {}" ,
248+ files_request. provider
249+ ) )
250+ } ) ?;
251+ let provider = ModelProvider :: from ( grpc_provider) ;
257252 let model_name = download:: canonical_model_name ( & files_request. model_name , provider)
258253 . map_err ( |e| Status :: invalid_argument ( e. to_string ( ) ) ) ?;
259254 let provider_impl = download:: get_provider ( provider) ;
@@ -426,7 +421,13 @@ impl ModelService for ModelServiceImpl {
426421 let files_request = request. into_inner ( ) ;
427422
428423 // Convert gRPC provider to our enum
429- let provider = convert_provider ( files_request. provider ) ;
424+ let grpc_provider = GrpcModelProvider :: try_from ( files_request. provider ) . map_err ( |_| {
425+ Status :: invalid_argument ( format ! (
426+ "Invalid provider value: {}" ,
427+ files_request. provider
428+ ) )
429+ } ) ?;
430+ let provider = ModelProvider :: from ( grpc_provider) ;
430431 let model_name = download:: canonical_model_name ( & files_request. model_name , provider)
431432 . map_err ( |e| Status :: invalid_argument ( e. to_string ( ) ) ) ?;
432433 let provider_impl = download:: get_provider ( provider) ;
@@ -549,7 +550,7 @@ impl ModelDownloadTracker {
549550 model_name : model_name. clone ( ) ,
550551 status : modelexpress_common:: grpc:: model:: ModelStatus :: from ( status) as i32 ,
551552 message,
552- provider : modelexpress_common :: grpc :: model :: ModelProvider :: from ( provider) as i32 ,
553+ provider : GrpcModelProvider :: from ( provider) as i32 ,
553554 } ;
554555
555556 for channel in channels {
@@ -619,8 +620,7 @@ impl ModelDownloadTracker {
619620 status : modelexpress_common:: grpc:: model:: ModelStatus :: from ( ModelStatus :: ERROR )
620621 as i32 ,
621622 message : Some ( "Database error occurred" . to_string ( ) ) ,
622- provider : modelexpress_common:: grpc:: model:: ModelProvider :: from ( provider)
623- as i32 ,
623+ provider : GrpcModelProvider :: from ( provider) as i32 ,
624624 } ;
625625 let _ = tx. send ( Ok ( error_update) ) . await ;
626626 return ModelStatus :: ERROR ;
@@ -636,7 +636,7 @@ impl ModelDownloadTracker {
636636 ModelStatus :: DOWNLOADING => Some ( "Model download in progress" . to_string ( ) ) ,
637637 ModelStatus :: ERROR => Some ( "Previous download failed - retrying" . to_string ( ) ) ,
638638 } ,
639- provider : modelexpress_common :: grpc :: model :: ModelProvider :: from ( provider) as i32 ,
639+ provider : GrpcModelProvider :: from ( provider) as i32 ,
640640 } ;
641641
642642 let _ = tx. send ( Ok ( update) ) . await ;
@@ -1221,6 +1221,57 @@ mod tests {
12211221 assert_eq ! ( status. code( ) , tonic:: Code :: NotFound ) ;
12221222 }
12231223
1224+ #[ tokio:: test]
1225+ async fn test_ensure_model_downloaded_rejects_invalid_provider ( ) {
1226+ let service = ModelServiceImpl ;
1227+
1228+ let request = Request :: new ( ModelDownloadRequest {
1229+ model_name : "test/model" . to_string ( ) ,
1230+ provider : 99 ,
1231+ ignore_weights : false ,
1232+ } ) ;
1233+
1234+ let result = service. ensure_model_downloaded ( request) . await ;
1235+ assert ! ( result. is_err( ) ) ;
1236+ let status = result. expect_err ( "Should return error" ) ;
1237+ assert_eq ! ( status. code( ) , tonic:: Code :: InvalidArgument ) ;
1238+ assert ! ( status. message( ) . contains( "Invalid provider value" ) ) ;
1239+ }
1240+
1241+ #[ tokio:: test]
1242+ async fn test_list_model_files_rejects_invalid_provider ( ) {
1243+ let service = ModelServiceImpl ;
1244+
1245+ let request = Request :: new ( ModelFilesRequest {
1246+ model_name : "test/model" . to_string ( ) ,
1247+ provider : 99 ,
1248+ chunk_size : 0 ,
1249+ } ) ;
1250+
1251+ let result = service. list_model_files ( request) . await ;
1252+ assert ! ( result. is_err( ) ) ;
1253+ let status = result. expect_err ( "Should return error" ) ;
1254+ assert_eq ! ( status. code( ) , tonic:: Code :: InvalidArgument ) ;
1255+ assert ! ( status. message( ) . contains( "Invalid provider value" ) ) ;
1256+ }
1257+
1258+ #[ tokio:: test]
1259+ async fn test_stream_model_files_rejects_invalid_provider ( ) {
1260+ let service = ModelServiceImpl ;
1261+
1262+ let request = Request :: new ( ModelFilesRequest {
1263+ model_name : "test/model" . to_string ( ) ,
1264+ provider : 99 ,
1265+ chunk_size : 1024 ,
1266+ } ) ;
1267+
1268+ let result = service. stream_model_files ( request) . await ;
1269+ assert ! ( result. is_err( ) ) ;
1270+ let status = result. expect_err ( "Should return error" ) ;
1271+ assert_eq ! ( status. code( ) , tonic:: Code :: InvalidArgument ) ;
1272+ assert ! ( status. message( ) . contains( "Invalid provider value" ) ) ;
1273+ }
1274+
12241275 #[ tokio:: test]
12251276 #[ allow( clippy:: await_holding_lock) ]
12261277 async fn test_stream_model_files_hf_first_chunk_includes_commit_hash ( ) {
0 commit comments