Skip to content

Commit cad701a

Browse files
author
Name
committed
changes
1 parent 980c470 commit cad701a

File tree

3 files changed

+116
-65
lines changed

3 files changed

+116
-65
lines changed

src/rust/lqosd/src/lts2_sys/control_channel.rs

Lines changed: 110 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use anyhow::{Result, bail};
2-
use futures_util::{StreamExt, sink::SinkExt};
2+
use futures_util::{StreamExt, sink::SinkExt, stream::SplitSink};
33
use serde::{Deserialize, Serialize};
44
use std::{collections::HashMap, time::Duration};
55
use tokio::net::TcpStream;
@@ -23,6 +23,11 @@ pub use messages::{
2323
};
2424

2525
const SHAPER_VERSION_STRING: &str = include_str!("../../../../VERSION_STRING");
26+
const CONTROL_CHANNEL_QUEUE_DEPTH: usize = 256;
27+
const CONNECTION_COMMAND_QUEUE_DEPTH: usize = 1024;
28+
const SOCKET_SENDER_QUEUE_DEPTH: usize = 32;
29+
30+
type ControlSocketWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
2631

2732
fn current_shaper_version() -> Option<String> {
2833
let version = SHAPER_VERSION_STRING.trim();
@@ -133,7 +138,7 @@ pub struct ControlChannelBuilder {
133138

134139
pub fn init_control_channel() -> Result<ControlChannelBuilder> {
135140
// Doing this two-step: make the channel here and then spawn the task
136-
let (tx, rx) = tokio::sync::mpsc::channel(256);
141+
let (tx, rx) = tokio::sync::mpsc::channel(CONTROL_CHANNEL_QUEUE_DEPTH);
137142
Ok(ControlChannelBuilder { tx, rx })
138143
}
139144

@@ -148,7 +153,7 @@ pub async fn start_control_channel(builder: ControlChannelBuilder) -> Result<()>
148153

149154
async fn control_channel_loop(mut builder: ControlChannelBuilder) -> Result<()> {
150155
// Handle the persistent channel to Insight here
151-
let (tx, rx) = tokio::sync::mpsc::channel::<ConnectionCommand>(1024);
156+
let (tx, rx) = tokio::sync::mpsc::channel::<ConnectionCommand>(CONNECTION_COMMAND_QUEUE_DEPTH);
152157
tokio::spawn(persistent_connection(rx));
153158

154159
while let Some(cmd) = builder.rx.recv().await {
@@ -273,6 +278,72 @@ const TCP_TIMEOUT: Duration = Duration::from_secs(30);
273278
// Prevent unbounded growth while waiting for Welcome
274279
const MAX_PENDING_CHATBOT_MESSAGES: usize = 256;
275280

281+
fn encode_ws_binary(
282+
message: &messages::WsMessage,
283+
label: &'static str,
284+
) -> std::result::Result<Message, ()> {
285+
let Ok((_, _, bytes)) = message.to_bytes() else {
286+
error!("Failed to serialize {label} message");
287+
return Err(());
288+
};
289+
Ok(Message::Binary(bytes))
290+
}
291+
292+
async fn send_ws_message_now(
293+
write: &mut ControlSocketWriter,
294+
message: Message,
295+
label: &'static str,
296+
) -> std::result::Result<(), ()> {
297+
let Ok(Ok(_)) = timeout(TCP_TIMEOUT, write.send(message)).await else {
298+
error!("Failed to send {label} message");
299+
return Err(());
300+
};
301+
Ok(())
302+
}
303+
304+
async fn send_ingest_batch(
305+
write: &mut ControlSocketWriter,
306+
serial: usize,
307+
chunks: Vec<Vec<u8>>,
308+
) -> std::result::Result<usize, ()> {
309+
let n_chunks = chunks.len();
310+
let byte_count = chunks.iter().map(|chunk| chunk.len()).sum::<usize>();
311+
312+
let begin = messages::WsMessage::BeginIngest {
313+
unique_id: serial as u64,
314+
n_chunks: n_chunks as u64,
315+
};
316+
send_ws_message_now(
317+
write,
318+
encode_ws_binary(&begin, "BeginIngest")?,
319+
"BeginIngest",
320+
)
321+
.await?;
322+
323+
for (chunk_index, chunk) in chunks.into_iter().enumerate() {
324+
let ingest_chunk = messages::WsMessage::IngestChunk {
325+
unique_id: serial as u64,
326+
chunk: chunk_index as u64,
327+
n_chunks: n_chunks as u64,
328+
data: chunk,
329+
};
330+
send_ws_message_now(
331+
write,
332+
encode_ws_binary(&ingest_chunk, "IngestChunk")?,
333+
"IngestChunk",
334+
)
335+
.await?;
336+
}
337+
338+
let end = messages::WsMessage::EndIngest {
339+
unique_id: serial as u64,
340+
n_chunks: n_chunks as u64,
341+
};
342+
send_ws_message_now(write, encode_ws_binary(&end, "EndIngest")?, "EndIngest").await?;
343+
344+
Ok(byte_count)
345+
}
346+
276347
async fn persistent_connection(
277348
mut rx: tokio::sync::mpsc::Receiver<ConnectionCommand>,
278349
) -> std::result::Result<(), String> {
@@ -303,7 +374,7 @@ async fn persistent_connection(
303374
// Split the socket
304375
let (mut write, mut read) = socket.split();
305376
let (socket_sender_tx, mut socket_sender_rx) =
306-
tokio::sync::mpsc::channel::<Message>(32);
377+
tokio::sync::mpsc::channel::<Message>(SOCKET_SENDER_QUEUE_DEPTH);
307378
let mut ping_interval = tokio::time::interval(Duration::from_secs(10));
308379
let mut license_interval = tokio::time::interval(Duration::from_secs(60 * 15)); // 15 minutes
309380
let mut pending_history: HashMap<
@@ -363,61 +434,11 @@ async fn persistent_connection(
363434
info!("Not permitted to send chunks yet");
364435
continue 'message_pump;
365436
}
366-
let n_chunks = chunks.len();
367-
let byte_count = chunks.iter().map(|c| c.len()).sum::<usize>();
368-
369-
// Send BeginIngest
370-
let Ok((_, _, bytes)) = messages::WsMessage::BeginIngest { unique_id: serial as u64, n_chunks: n_chunks as u64 }.to_bytes() else {
371-
error!("Failed to serialize BeginIngest message");
372-
break 'message_pump;
373-
};
374-
if let Err(e) = socket_sender_tx.try_send(Message::Binary(bytes)) {
375-
match e {
376-
TrySendError::Full(_) => {
377-
warn!("Send unavailable: BeginIngest queue full; dropping message");
378-
}
379-
TrySendError::Closed(_) => {
380-
error!("Failed to send BeginIngest message: channel closed");
381-
break 'message_pump;
382-
}
383-
}
384-
}
385-
386-
// Submit Each Chunk
387-
for (i, chunk) in chunks.into_iter().enumerate() {
388-
let Ok((_, _, bytes)) = messages::WsMessage::IngestChunk { unique_id: serial as u64, chunk: i as u64, n_chunks: n_chunks as u64, data: chunk }.to_bytes() else {
389-
error!("Failed to serialize IngestChunk message");
390-
break 'message_pump;
391-
};
392-
if let Err(e) = socket_sender_tx.try_send(Message::Binary(bytes)) {
393-
match e {
394-
TrySendError::Full(_) => {
395-
warn!("Send unavailable: IngestChunk queue full; dropping chunk");
396-
}
397-
TrySendError::Closed(_) => {
398-
error!("Failed to send IngestChunk message: channel closed");
399-
break 'message_pump;
400-
}
401-
}
402-
}
403-
}
404-
405-
// Send EndIngest
406-
let Ok((_, _, bytes)) = messages::WsMessage::EndIngest { unique_id: serial as u64, n_chunks: n_chunks as u64 }.to_bytes() else {
407-
error!("Failed to serialize EndIngest message");
437+
let Ok(byte_count) =
438+
send_ingest_batch(&mut write, serial, chunks).await
439+
else {
408440
break 'message_pump;
409441
};
410-
if let Err(e) = socket_sender_tx.try_send(Message::Binary(bytes)) {
411-
match e {
412-
TrySendError::Full(_) => {
413-
warn!("Send unavailable: EndIngest queue full; dropping message");
414-
}
415-
TrySendError::Closed(_) => {
416-
error!("Failed to send EndIngest message: channel closed");
417-
break 'message_pump;
418-
}
419-
}
420-
}
421442
debug!("Submitted {} bytes for ingestion", byte_count);
422443
}
423444
Some(ConnectionCommand::FetchHistory { request, responder }) => {
@@ -1148,6 +1169,38 @@ async fn persistent_connection(
11481169
}
11491170
}
11501171

1172+
#[cfg(test)]
1173+
mod tests {
1174+
use super::*;
1175+
1176+
#[test]
1177+
fn encode_ws_binary_round_trips_begin_ingest() {
1178+
let encoded = encode_ws_binary(
1179+
&messages::WsMessage::BeginIngest {
1180+
unique_id: 42,
1181+
n_chunks: 7,
1182+
},
1183+
"BeginIngest",
1184+
)
1185+
.expect("BeginIngest should encode");
1186+
1187+
let Message::Binary(bytes) = encoded else {
1188+
panic!("BeginIngest should encode to a binary websocket frame");
1189+
};
1190+
1191+
match messages::WsMessage::from_bytes(&bytes).expect("BeginIngest should decode") {
1192+
messages::WsMessage::BeginIngest {
1193+
unique_id,
1194+
n_chunks,
1195+
} => {
1196+
assert_eq!(unique_id, 42);
1197+
assert_eq!(n_chunks, 7);
1198+
}
1199+
other => panic!("unexpected decoded message: {other:?}"),
1200+
}
1201+
}
1202+
}
1203+
11511204
async fn connect() -> anyhow::Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
11521205
let remote_host = crate::lts2_sys::lts2_client::get_remote_host();
11531206
let target = format!("wss://{}:443/shaper_gateway/ws", &remote_host);

src/rust/lqosd/src/node_manager/local_api/scheduler.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,10 @@ mod tests {
424424
assert_eq!(status.updated_unix, Some(last_seen_unix));
425425
assert!(!status.stale);
426426
assert_eq!(details.updated_unix, Some(last_seen_unix));
427-
assert!(details
428-
.details
429-
.contains(&format!("Last updated Unix: {last_seen_unix}")));
427+
assert!(
428+
details
429+
.details
430+
.contains(&format!("Last updated Unix: {last_seen_unix}"))
431+
);
430432
}
431433
}

src/rust/lqosd/src/tool_status.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,7 @@ pub fn is_scheduler_available() -> bool {
8686
/// Returns the Unix timestamp of the scheduler's last heartbeat, if one has been seen.
8787
pub fn scheduler_last_seen_unix() -> Option<u64> {
8888
let last = SCHEDULER_LAST_SEEN.load(Ordering::Relaxed);
89-
if last == 0 {
90-
None
91-
} else {
92-
Some(last)
93-
}
89+
if last == 0 { None } else { Some(last) }
9490
}
9591

9692
/// Returns the current scheduler error message, if any.

0 commit comments

Comments
 (0)