Skip to content

Commit dca6139

Browse files
feat!: add GCS model provider
- Add first-class GCS support across common, client, and server layers, including proto/provider enums, DB mapping, cache stats/listing/clear, and provider-aware request routing.
- Introduce `modelexpress_common::providers::gcs` with strict `gs://bucket/path` canonicalization, safe cache layout validation, ancestor/descendant overlap protection, and resumable downloads with CRC32C verification.
- Store GCS cache metadata in `.mx/manifest.json` and use that manifest for cache discovery, listing, and download planning instead of marker-file metadata.
- Canonicalize model names at request boundaries and keep the server as the source of truth for `Client::request_model`, with direct download used only when the initial client connection cannot be established.
- Update dependency wiring for Google Cloud Storage support and align tests with the final provider-aware client and cache behavior.

BREAKING CHANGE:
- `Client::request_model` now requires an explicit `provider` argument.
- Removed public APIs: `preload_model_to_cache`, `request_model_with_provider_and_fallback`, `request_model_server_only`, and `download_model_directly`.
- Renamed `request_model_with_provider` to `request_model_on_server`.

Signed-off-by: Andrew Paprotsky <apaprotskyi@nvidia.com>
1 parent 75d0dfc commit dca6139

File tree

18 files changed

+4919
-440
lines changed

Cargo.lock

Lines changed: 1061 additions & 87 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ axum = "0.8"
2828
chrono = { version = "0.4", features = ["serde"] }
2929
clap = { version = "4.5", features = ["derive", "env"] }
3030
config = { version = "0.15", features = ["yaml", "toml", "json"] }
31+
crc32c = "0.6.8"
3132
colored = "3.0.0"
33+
google-cloud-storage = { version = "1.9.0", default-features = false }
3234
hf-hub = { version = "0.4.3", default-features = false, features = [
3335
"tokio",
3436
"rustls-tls",
@@ -39,6 +41,7 @@ modelexpress-client = { path = "modelexpress_client", version = "0.3.0" }
3941
modelexpress-server = { path = "modelexpress_server", version = "0.3.0" }
4042
once_cell = "1.21.3"
4143
prost = "0.13"
44+
rustls = { version = "0.23.37", default-features = false, features = ["ring", "std"] }
4245
rusqlite = { version = "0.37", features = ["bundled", "chrono"] }
4346
serde = { version = "1.0", features = ["derive"] }
4447
serde_json = "1.0"

modelexpress_client/src/bin/cli.rs

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ async fn main() {
115115
#[allow(clippy::expect_used)]
116116
mod tests {
117117
use super::modules::args::{Cli, Commands};
118-
use clap::Parser;
118+
use clap::{Parser, ValueEnum};
119119
use modelexpress_client::ModelProvider;
120120

121121
#[test]
@@ -139,6 +139,16 @@ mod tests {
139139
assert_eq!(provider, ModelProvider::HuggingFace);
140140
}
141141

142+
#[test]
143+
fn test_cli_model_provider_value_enum() {
144+
let parsed = ModelProvider::from_str("hugging-face", false)
145+
.expect("Failed to parse hugging-face provider");
146+
assert_eq!(parsed, ModelProvider::HuggingFace);
147+
148+
let parsed = ModelProvider::from_str("gcs", false).expect("Failed to parse gcs provider");
149+
assert_eq!(parsed, ModelProvider::Gcs);
150+
}
151+
142152
#[test]
143153
fn test_cli_flattened_client_args_parsing() {
144154
// Test that the flattened ClientArgs fields are accessible through Cli
@@ -200,20 +210,40 @@ mod tests {
200210
}
201211

202212
#[test]
203-
fn test_cli_model_clear_defaults_to_hugging_face_provider() {
213+
fn test_cli_model_clear_parses_explicit_provider() {
204214
let parsed = Cli::try_parse_from([
205215
"modelexpress-cli",
206216
"model",
207217
"clear",
208218
"--provider",
209-
"hugging-face",
210-
"dev/bake/qwen/rev123",
211-
]);
212-
assert!(parsed.is_ok());
219+
"gcs",
220+
"gs://bucket/dev/bake/qwen/rev123",
221+
])
222+
.expect("Expected clear command to parse with explicit provider");
213223

214-
let missing_provider =
215-
Cli::try_parse_from(["modelexpress-cli", "model", "clear", "dev/bake/qwen/rev123"])
216-
.expect("Expected clear command to parse without provider");
224+
let Commands::Model { command } = parsed.command else {
225+
panic!("Expected model command");
226+
};
227+
let super::modules::args::ModelCommands::Clear {
228+
provider,
229+
model_name,
230+
} = command
231+
else {
232+
panic!("Expected clear subcommand");
233+
};
234+
assert_eq!(provider, ModelProvider::Gcs);
235+
assert_eq!(model_name, "gs://bucket/dev/bake/qwen/rev123");
236+
}
237+
238+
#[test]
239+
fn test_cli_model_clear_defaults_to_hugging_face_provider() {
240+
let missing_provider = Cli::try_parse_from([
241+
"modelexpress-cli",
242+
"model",
243+
"clear",
244+
"gs://bucket/dev/bake/qwen/rev123",
245+
])
246+
.expect("Expected clear command to parse without provider");
217247

218248
let Commands::Model { command } = missing_provider.command else {
219249
panic!("Expected model command");
@@ -226,6 +256,6 @@ mod tests {
226256
panic!("Expected clear subcommand");
227257
};
228258
assert_eq!(provider, ModelProvider::HuggingFace);
229-
assert_eq!(model_name, "dev/bake/qwen/rev123");
259+
assert_eq!(model_name, "gs://bucket/dev/bake/qwen/rev123");
230260
}
231261
}

modelexpress_client/src/bin/fallback_test.rs

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,21 @@
44
#![allow(clippy::expect_used)]
55

66
use modelexpress_client::{Client, ClientConfig, ModelProvider};
7-
use tracing::{error, info};
7+
use tracing::info;
88

99
#[tokio::main]
1010
async fn main() -> Result<(), Box<dyn std::error::Error>> {
11-
// Initialize logging
1211
tracing_subscriber::fmt::init();
1312

14-
info!("Testing model download with server fallback...");
13+
info!("Testing smart fallback with unavailable server...");
1514

16-
let model_name = "google-t5/t5-small";
17-
18-
// Test smart fallback - this should work whether server is running or not
19-
info!("Attempting to download model with smart fallback...");
20-
21-
match Client::request_model_with_smart_fallback(
22-
model_name,
15+
Client::request_model_with_smart_fallback(
16+
"google-t5/t5-small",
2317
ModelProvider::HuggingFace,
24-
ClientConfig::default(),
18+
ClientConfig::for_testing("http://127.0.0.1:54321"),
2519
false,
2620
)
27-
.await
28-
{
29-
Ok(()) => {
30-
info!("✅ SUCCESS: Model '{model_name}' downloaded successfully!");
31-
info!(
32-
"The download worked either via server (if running) or direct download (if server unavailable)"
33-
);
34-
}
35-
Err(e) => {
36-
error!("❌ FAILED: Could not download model '{model_name}': {e}");
37-
return Err(e.into());
38-
}
39-
}
21+
.await?;
4022

4123
Ok(())
4224
}

modelexpress_client/src/bin/modules/handlers.rs

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ use super::output::{print_human_readable, print_output};
66
use super::payload::read_payload;
77
use colored::*;
88
use modelexpress_client::{Client, ClientConfig, ModelProvider};
9-
use modelexpress_common::cache::{CacheConfig, CacheStats, ModelInfo};
9+
use modelexpress_common::{
10+
cache::{CacheConfig, CacheStats, ModelInfo},
11+
download,
12+
};
1013
use serde_json::Value;
1114
use std::io::Write;
1215
use std::path::PathBuf;
@@ -174,20 +177,13 @@ async fn download_model(
174177
let result = match strategy {
175178
DownloadStrategy::SmartFallback => {
176179
debug!("Using smart fallback strategy");
180+
let mut config = config.clone();
177181
if let Some(cache_config) = cache_config {
178-
let mut client = Client::new_with_cache(config.clone(), cache_config).await?;
179-
client
180-
.preload_model_to_cache(&model_name, provider, false)
181-
.await
182-
} else {
183-
Client::request_model_with_smart_fallback(
184-
model_name.clone(),
185-
provider,
186-
config,
187-
false,
188-
)
189-
.await
182+
config.cache = cache_config;
190183
}
184+
Client::request_model_with_smart_fallback(model_name.clone(), provider, config, false)
185+
.await
186+
.map(|_| ())
191187
}
192188
DownloadStrategy::ServerOnly => {
193189
debug!("Using server-only strategy");
@@ -197,12 +193,23 @@ async fn download_model(
197193
Client::new(config.clone()).await?
198194
};
199195
client
200-
.request_model_with_provider(&model_name, provider, false)
196+
.request_model(&model_name, provider, false)
201197
.await
198+
.map(|_| ())
202199
}
203200
DownloadStrategy::Direct => {
204201
debug!("Using direct download strategy");
205-
Client::download_model_directly(model_name.clone(), provider, false).await
202+
download::download_model(
203+
&model_name,
204+
provider,
205+
cache_config.map(|config| config.local_path),
206+
false,
207+
)
208+
.await
209+
.map(|_| ())
210+
.map_err(|e| {
211+
modelexpress_common::Error::Server(format!("Direct download failed: {e}")).into()
212+
})
206213
}
207214
};
208215

@@ -246,7 +253,7 @@ async fn download_model(
246253
print_output(&output, format);
247254
}
248255
}
249-
return Err(Box::new(e));
256+
return Err(e);
250257
}
251258
}
252259

modelexpress_client/src/bin/test_client.rs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#![allow(clippy::expect_used)]
1010

1111
use modelexpress_client::{Client, ClientConfig};
12+
use modelexpress_common::download;
1213
use modelexpress_common::models::ModelProvider;
1314
use std::env;
1415
use std::time::{Duration, Instant};
@@ -100,7 +101,7 @@ async fn run_concurrent_model_test(model_name: &str) -> Result<(), Box<dyn std::
100101
info!("Client 1: Requesting model {model_name1}");
101102
let start = Instant::now();
102103
client1
103-
.request_model(model_name1, false)
104+
.request_model(model_name1, ModelProvider::default(), false)
104105
.await
105106
.expect("Client 1 failed to download model");
106107
info!("Client 1: Model downloaded in {:?}", start.elapsed());
@@ -116,7 +117,7 @@ async fn run_concurrent_model_test(model_name: &str) -> Result<(), Box<dyn std::
116117
info!("Client 2: Requesting model {model_name2}");
117118
let start = Instant::now();
118119
client2
119-
.request_model(model_name2, false)
120+
.request_model(model_name2, ModelProvider::default(), false)
120121
.await
121122
.expect("Client 2 failed to download model");
122123
info!("Client 2: Model downloaded in {:?}", start.elapsed());
@@ -140,7 +141,10 @@ async fn run_single_model_test(model_name: &str) -> Result<(), Box<dyn std::erro
140141
info!("Client: Requesting model {model_name}");
141142
let start = Instant::now();
142143

143-
match client.request_model(model_name.to_string(), false).await {
144+
match client
145+
.request_model(model_name.to_string(), ModelProvider::default(), false)
146+
.await
147+
{
144148
Ok(()) => {
145149
info!("Client: Model downloaded in {:?}", start.elapsed());
146150
info!("Client completed in {:?}", start_time.elapsed());
@@ -154,7 +158,7 @@ async fn run_single_model_test(model_name: &str) -> Result<(), Box<dyn std::erro
154158
}
155159
}
156160

157-
/// Test fallback functionality including server fallback, direct download, and smart fallback
161+
/// Test download functionality including server fallback, direct download, and smart fallback
158162
async fn run_fallback_test(model_name: &str) -> Result<(), Box<dyn std::error::Error>> {
159163
info!("Testing fallback functionality (assuming server is running)...");
160164
let mut client = Client::new(ClientConfig::default()).await?;
@@ -163,7 +167,7 @@ async fn run_fallback_test(model_name: &str) -> Result<(), Box<dyn std::error::E
163167

164168
// This should work via server since it's running
165169
match client
166-
.request_model_with_provider_and_fallback(model_name, ModelProvider::HuggingFace, false)
170+
.request_model(model_name, ModelProvider::HuggingFace, false)
167171
.await
168172
{
169173
Ok(()) => {
@@ -181,21 +185,27 @@ async fn run_fallback_test(model_name: &str) -> Result<(), Box<dyn std::error::E
181185
info!("Testing direct download (bypassing server)...");
182186
let start_direct = Instant::now();
183187

184-
match Client::download_model_directly(model_name, ModelProvider::HuggingFace, false).await {
185-
Ok(()) => {
188+
match download::download_model(
189+
model_name,
190+
ModelProvider::HuggingFace,
191+
Some(ClientConfig::default().cache.local_path.clone()),
192+
false,
193+
)
194+
.await
195+
{
196+
Ok(_) => {
186197
info!("Model downloaded directly in {:?}", start_direct.elapsed());
187198
}
188199
Err(e) => {
189200
return Err(format!("Failed to download model directly: {e}").into());
190201
}
191202
}
192203

193-
// Test smart fallback (will use server if available, direct download if not)
194204
info!("Testing smart fallback...");
195205
let start_smart = Instant::now();
196206

197207
match Client::request_model_with_smart_fallback(
198-
model_name,
208+
model_name.to_string(),
199209
ModelProvider::HuggingFace,
200210
ClientConfig::default(),
201211
false,
@@ -207,11 +217,12 @@ async fn run_fallback_test(model_name: &str) -> Result<(), Box<dyn std::error::E
207217
"Model downloaded with smart fallback in {:?}",
208218
start_smart.elapsed()
209219
);
210-
info!(
211-
"FALLBACK TEST PASSED: Server-with-fallback, direct download, and smart fallback all work"
212-
);
213-
Ok(())
214220
}
215-
Err(e) => Err(format!("Failed to download model with smart fallback: {e}").into()),
221+
Err(e) => {
222+
return Err(format!("Failed to download model with smart fallback: {e}").into());
223+
}
216224
}
225+
226+
info!("FALLBACK TEST PASSED: Server, direct download, and smart fallback paths all work");
227+
Ok(())
217228
}

0 commit comments

Comments
 (0)