Skip to content

Commit 5832c7e

Browse files
fix: reject invalid grpc model providers
Signed-off-by: Andrew Paprotsky <apaprotskyi@nvidia.com>
1 parent 2a10c23 commit 5832c7e

File tree

1 file changed

+79
-28
lines changed

1 file changed

+79
-28
lines changed

modelexpress_server/src/services.rs

Lines changed: 79 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ use modelexpress_common::{
1010
health::{HealthRequest, HealthResponse, health_service_server::HealthService},
1111
model::{
1212
FileChunk, ModelDownloadRequest, ModelFileInfo, ModelFileList, ModelFilesRequest,
13-
ModelStatusUpdate, model_service_server::ModelService,
13+
ModelProvider as GrpcModelProvider, ModelStatusUpdate,
14+
model_service_server::ModelService,
1415
},
1516
},
1617
models::{ModelProvider, ModelStatus},
@@ -41,23 +42,6 @@ fn get_server_cache_dir() -> Option<std::path::PathBuf> {
4142
}
4243
}
4344

44-
/// Convert gRPC provider to internal ModelProvider enum
45-
///
46-
/// Falls back to HuggingFace provider if the conversion fails or an invalid
47-
/// provider value is provided. A warning is logged when fallback occurs.
48-
fn convert_provider(grpc_provider: i32) -> ModelProvider {
49-
match modelexpress_common::grpc::model::ModelProvider::try_from(grpc_provider) {
50-
Ok(provider) => provider.into(),
51-
Err(_) => {
52-
tracing::warn!(
53-
"Invalid provider value {}, falling back to HuggingFace",
54-
grpc_provider
55-
);
56-
ModelProvider::HuggingFace
57-
}
58-
}
59-
}
60-
6145
/// Health service implementation
6246
#[derive(Debug, Default)]
6347
pub struct HealthServiceImpl;
@@ -183,7 +167,13 @@ impl ModelService for ModelServiceImpl {
183167
let (tx, rx) = tokio::sync::mpsc::channel(4);
184168

185169
// Convert gRPC provider to our enum
186-
let provider = convert_provider(model_request.provider);
170+
let grpc_provider = GrpcModelProvider::try_from(model_request.provider).map_err(|_| {
171+
Status::invalid_argument(format!(
172+
"Invalid provider value: {}",
173+
model_request.provider
174+
))
175+
})?;
176+
let provider = ModelProvider::from(grpc_provider);
187177
let model_name = download::canonical_model_name(&model_request.model_name, provider)
188178
.map_err(|e| Status::invalid_argument(e.to_string()))?;
189179
let ignore_weights = model_request.ignore_weights;
@@ -202,8 +192,7 @@ impl ModelService for ModelServiceImpl {
202192
Some("Previous download failed - retrying".to_string())
203193
}
204194
},
205-
provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
206-
as i32,
195+
provider: grpc_provider as i32,
207196
};
208197

209198
if tx.send(Ok(update)).await.is_err() {
@@ -232,7 +221,7 @@ impl ModelService for ModelServiceImpl {
232221
ModelStatus::ERROR => Some("Model download failed".to_string()),
233222
ModelStatus::DOWNLOADING => Some("Download still in progress".to_string()),
234223
},
235-
provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
224+
provider: grpc_provider as i32,
236225
};
237226

238227
let _ = tx.send(Ok(final_update)).await;
@@ -253,7 +242,13 @@ impl ModelService for ModelServiceImpl {
253242
};
254243

255244
// Convert gRPC provider to our enum
256-
let provider = convert_provider(files_request.provider);
245+
let grpc_provider = GrpcModelProvider::try_from(files_request.provider).map_err(|_| {
246+
Status::invalid_argument(format!(
247+
"Invalid provider value: {}",
248+
files_request.provider
249+
))
250+
})?;
251+
let provider = ModelProvider::from(grpc_provider);
257252
let model_name = download::canonical_model_name(&files_request.model_name, provider)
258253
.map_err(|e| Status::invalid_argument(e.to_string()))?;
259254
let provider_impl = download::get_provider(provider);
@@ -426,7 +421,13 @@ impl ModelService for ModelServiceImpl {
426421
let files_request = request.into_inner();
427422

428423
// Convert gRPC provider to our enum
429-
let provider = convert_provider(files_request.provider);
424+
let grpc_provider = GrpcModelProvider::try_from(files_request.provider).map_err(|_| {
425+
Status::invalid_argument(format!(
426+
"Invalid provider value: {}",
427+
files_request.provider
428+
))
429+
})?;
430+
let provider = ModelProvider::from(grpc_provider);
430431
let model_name = download::canonical_model_name(&files_request.model_name, provider)
431432
.map_err(|e| Status::invalid_argument(e.to_string()))?;
432433
let provider_impl = download::get_provider(provider);
@@ -549,7 +550,7 @@ impl ModelDownloadTracker {
549550
model_name: model_name.clone(),
550551
status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
551552
message,
552-
provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
553+
provider: GrpcModelProvider::from(provider) as i32,
553554
};
554555

555556
for channel in channels {
@@ -619,8 +620,7 @@ impl ModelDownloadTracker {
619620
status: modelexpress_common::grpc::model::ModelStatus::from(ModelStatus::ERROR)
620621
as i32,
621622
message: Some("Database error occurred".to_string()),
622-
provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
623-
as i32,
623+
provider: GrpcModelProvider::from(provider) as i32,
624624
};
625625
let _ = tx.send(Ok(error_update)).await;
626626
return ModelStatus::ERROR;
@@ -636,7 +636,7 @@ impl ModelDownloadTracker {
636636
ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
637637
ModelStatus::ERROR => Some("Previous download failed - retrying".to_string()),
638638
},
639-
provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
639+
provider: GrpcModelProvider::from(provider) as i32,
640640
};
641641

642642
let _ = tx.send(Ok(update)).await;
@@ -1221,6 +1221,57 @@ mod tests {
12211221
assert_eq!(status.code(), tonic::Code::NotFound);
12221222
}
12231223

1224+
#[tokio::test]
1225+
async fn test_ensure_model_downloaded_rejects_invalid_provider() {
1226+
let service = ModelServiceImpl;
1227+
1228+
let request = Request::new(ModelDownloadRequest {
1229+
model_name: "test/model".to_string(),
1230+
provider: 99,
1231+
ignore_weights: false,
1232+
});
1233+
1234+
let result = service.ensure_model_downloaded(request).await;
1235+
assert!(result.is_err());
1236+
let status = result.expect_err("Should return error");
1237+
assert_eq!(status.code(), tonic::Code::InvalidArgument);
1238+
assert!(status.message().contains("Invalid provider value"));
1239+
}
1240+
1241+
#[tokio::test]
1242+
async fn test_list_model_files_rejects_invalid_provider() {
1243+
let service = ModelServiceImpl;
1244+
1245+
let request = Request::new(ModelFilesRequest {
1246+
model_name: "test/model".to_string(),
1247+
provider: 99,
1248+
chunk_size: 0,
1249+
});
1250+
1251+
let result = service.list_model_files(request).await;
1252+
assert!(result.is_err());
1253+
let status = result.expect_err("Should return error");
1254+
assert_eq!(status.code(), tonic::Code::InvalidArgument);
1255+
assert!(status.message().contains("Invalid provider value"));
1256+
}
1257+
1258+
#[tokio::test]
1259+
async fn test_stream_model_files_rejects_invalid_provider() {
1260+
let service = ModelServiceImpl;
1261+
1262+
let request = Request::new(ModelFilesRequest {
1263+
model_name: "test/model".to_string(),
1264+
provider: 99,
1265+
chunk_size: 1024,
1266+
});
1267+
1268+
let result = service.stream_model_files(request).await;
1269+
assert!(result.is_err());
1270+
let status = result.expect_err("Should return error");
1271+
assert_eq!(status.code(), tonic::Code::InvalidArgument);
1272+
assert!(status.message().contains("Invalid provider value"));
1273+
}
1274+
12241275
#[tokio::test]
12251276
#[allow(clippy::await_holding_lock)]
12261277
async fn test_stream_model_files_hf_first_chunk_includes_commit_hash() {

0 commit comments

Comments
 (0)