11use anyhow:: { Result , bail} ;
2- use futures_util:: { StreamExt , sink:: SinkExt } ;
2+ use futures_util:: { StreamExt , sink:: SinkExt , stream :: SplitSink } ;
33use serde:: { Deserialize , Serialize } ;
44use std:: { collections:: HashMap , time:: Duration } ;
55use tokio:: net:: TcpStream ;
@@ -23,6 +23,11 @@ pub use messages::{
2323} ;
2424
2525const SHAPER_VERSION_STRING : & str = include_str ! ( "../../../../VERSION_STRING" ) ;
26+ const CONTROL_CHANNEL_QUEUE_DEPTH : usize = 256 ;
27+ const CONNECTION_COMMAND_QUEUE_DEPTH : usize = 1024 ;
28+ const SOCKET_SENDER_QUEUE_DEPTH : usize = 32 ;
29+
30+ type ControlSocketWriter = SplitSink < WebSocketStream < MaybeTlsStream < TcpStream > > , Message > ;
2631
2732fn current_shaper_version ( ) -> Option < String > {
2833 let version = SHAPER_VERSION_STRING . trim ( ) ;
@@ -133,7 +138,7 @@ pub struct ControlChannelBuilder {
133138
134139pub fn init_control_channel ( ) -> Result < ControlChannelBuilder > {
135140 // Doing this two-step: make the channel here and then spawn the task
136- let ( tx, rx) = tokio:: sync:: mpsc:: channel ( 256 ) ;
141+ let ( tx, rx) = tokio:: sync:: mpsc:: channel ( CONTROL_CHANNEL_QUEUE_DEPTH ) ;
137142 Ok ( ControlChannelBuilder { tx, rx } )
138143}
139144
@@ -148,7 +153,7 @@ pub async fn start_control_channel(builder: ControlChannelBuilder) -> Result<()>
148153
149154async fn control_channel_loop ( mut builder : ControlChannelBuilder ) -> Result < ( ) > {
150155 // Handle the persistent channel to Insight here
151- let ( tx, rx) = tokio:: sync:: mpsc:: channel :: < ConnectionCommand > ( 1024 ) ;
156+ let ( tx, rx) = tokio:: sync:: mpsc:: channel :: < ConnectionCommand > ( CONNECTION_COMMAND_QUEUE_DEPTH ) ;
152157 tokio:: spawn ( persistent_connection ( rx) ) ;
153158
154159 while let Some ( cmd) = builder. rx . recv ( ) . await {
@@ -273,6 +278,72 @@ const TCP_TIMEOUT: Duration = Duration::from_secs(30);
273278// Prevent unbounded growth while waiting for Welcome
274279const MAX_PENDING_CHATBOT_MESSAGES : usize = 256 ;
275280
281+ fn encode_ws_binary (
282+ message : & messages:: WsMessage ,
283+ label : & ' static str ,
284+ ) -> std:: result:: Result < Message , ( ) > {
285+ let Ok ( ( _, _, bytes) ) = message. to_bytes ( ) else {
286+ error ! ( "Failed to serialize {label} message" ) ;
287+ return Err ( ( ) ) ;
288+ } ;
289+ Ok ( Message :: Binary ( bytes) )
290+ }
291+
292+ async fn send_ws_message_now (
293+ write : & mut ControlSocketWriter ,
294+ message : Message ,
295+ label : & ' static str ,
296+ ) -> std:: result:: Result < ( ) , ( ) > {
297+ let Ok ( Ok ( _) ) = timeout ( TCP_TIMEOUT , write. send ( message) ) . await else {
298+ error ! ( "Failed to send {label} message" ) ;
299+ return Err ( ( ) ) ;
300+ } ;
301+ Ok ( ( ) )
302+ }
303+
304+ async fn send_ingest_batch (
305+ write : & mut ControlSocketWriter ,
306+ serial : usize ,
307+ chunks : Vec < Vec < u8 > > ,
308+ ) -> std:: result:: Result < usize , ( ) > {
309+ let n_chunks = chunks. len ( ) ;
310+ let byte_count = chunks. iter ( ) . map ( |chunk| chunk. len ( ) ) . sum :: < usize > ( ) ;
311+
312+ let begin = messages:: WsMessage :: BeginIngest {
313+ unique_id : serial as u64 ,
314+ n_chunks : n_chunks as u64 ,
315+ } ;
316+ send_ws_message_now (
317+ write,
318+ encode_ws_binary ( & begin, "BeginIngest" ) ?,
319+ "BeginIngest" ,
320+ )
321+ . await ?;
322+
323+ for ( chunk_index, chunk) in chunks. into_iter ( ) . enumerate ( ) {
324+ let ingest_chunk = messages:: WsMessage :: IngestChunk {
325+ unique_id : serial as u64 ,
326+ chunk : chunk_index as u64 ,
327+ n_chunks : n_chunks as u64 ,
328+ data : chunk,
329+ } ;
330+ send_ws_message_now (
331+ write,
332+ encode_ws_binary ( & ingest_chunk, "IngestChunk" ) ?,
333+ "IngestChunk" ,
334+ )
335+ . await ?;
336+ }
337+
338+ let end = messages:: WsMessage :: EndIngest {
339+ unique_id : serial as u64 ,
340+ n_chunks : n_chunks as u64 ,
341+ } ;
342+ send_ws_message_now ( write, encode_ws_binary ( & end, "EndIngest" ) ?, "EndIngest" ) . await ?;
343+
344+ Ok ( byte_count)
345+ }
346+
276347async fn persistent_connection (
277348 mut rx : tokio:: sync:: mpsc:: Receiver < ConnectionCommand > ,
278349) -> std:: result:: Result < ( ) , String > {
@@ -303,7 +374,7 @@ async fn persistent_connection(
303374 // Split the socket
304375 let ( mut write, mut read) = socket. split ( ) ;
305376 let ( socket_sender_tx, mut socket_sender_rx) =
306- tokio:: sync:: mpsc:: channel :: < Message > ( 32 ) ;
377+ tokio:: sync:: mpsc:: channel :: < Message > ( SOCKET_SENDER_QUEUE_DEPTH ) ;
307378 let mut ping_interval = tokio:: time:: interval ( Duration :: from_secs ( 10 ) ) ;
308379 let mut license_interval = tokio:: time:: interval ( Duration :: from_secs ( 60 * 15 ) ) ; // 15 minutes
309380 let mut pending_history: HashMap <
@@ -363,61 +434,11 @@ async fn persistent_connection(
363434 info!( "Not permitted to send chunks yet" ) ;
364435 continue ' message_pump;
365436 }
366- let n_chunks = chunks. len( ) ;
367- let byte_count = chunks. iter( ) . map( |c| c. len( ) ) . sum:: <usize >( ) ;
368-
369- // Send BeginIngest
370- let Ok ( ( _, _, bytes) ) = messages:: WsMessage :: BeginIngest { unique_id: serial as u64 , n_chunks: n_chunks as u64 } . to_bytes( ) else {
371- error!( "Failed to serialize BeginIngest message" ) ;
372- break ' message_pump;
373- } ;
374- if let Err ( e) = socket_sender_tx. try_send( Message :: Binary ( bytes) ) {
375- match e {
376- TrySendError :: Full ( _) => {
377- warn!( "Send unavailable: BeginIngest queue full; dropping message" ) ;
378- }
379- TrySendError :: Closed ( _) => {
380- error!( "Failed to send BeginIngest message: channel closed" ) ;
381- break ' message_pump;
382- }
383- }
384- }
385-
386- // Submit Each Chunk
387- for ( i, chunk) in chunks. into_iter( ) . enumerate( ) {
388- let Ok ( ( _, _, bytes) ) = messages:: WsMessage :: IngestChunk { unique_id: serial as u64 , chunk: i as u64 , n_chunks: n_chunks as u64 , data: chunk } . to_bytes( ) else {
389- error!( "Failed to serialize IngestChunk message" ) ;
390- break ' message_pump;
391- } ;
392- if let Err ( e) = socket_sender_tx. try_send( Message :: Binary ( bytes) ) {
393- match e {
394- TrySendError :: Full ( _) => {
395- warn!( "Send unavailable: IngestChunk queue full; dropping chunk" ) ;
396- }
397- TrySendError :: Closed ( _) => {
398- error!( "Failed to send IngestChunk message: channel closed" ) ;
399- break ' message_pump;
400- }
401- }
402- }
403- }
404-
405- // Send EndIngest
406- let Ok ( ( _, _, bytes) ) = messages:: WsMessage :: EndIngest { unique_id: serial as u64 , n_chunks: n_chunks as u64 } . to_bytes( ) else {
407- error!( "Failed to serialize EndIngest message" ) ;
437+ let Ok ( byte_count) =
438+ send_ingest_batch( & mut write, serial, chunks) . await
439+ else {
408440 break ' message_pump;
409441 } ;
410- if let Err ( e) = socket_sender_tx. try_send( Message :: Binary ( bytes) ) {
411- match e {
412- TrySendError :: Full ( _) => {
413- warn!( "Send unavailable: EndIngest queue full; dropping message" ) ;
414- }
415- TrySendError :: Closed ( _) => {
416- error!( "Failed to send EndIngest message: channel closed" ) ;
417- break ' message_pump;
418- }
419- }
420- }
421442 debug!( "Submitted {} bytes for ingestion" , byte_count) ;
422443 }
423444 Some ( ConnectionCommand :: FetchHistory { request, responder } ) => {
@@ -1148,6 +1169,38 @@ async fn persistent_connection(
11481169 }
11491170}
11501171
1172+ #[ cfg( test) ]
1173+ mod tests {
1174+ use super :: * ;
1175+
1176+ #[ test]
1177+ fn encode_ws_binary_round_trips_begin_ingest ( ) {
1178+ let encoded = encode_ws_binary (
1179+ & messages:: WsMessage :: BeginIngest {
1180+ unique_id : 42 ,
1181+ n_chunks : 7 ,
1182+ } ,
1183+ "BeginIngest" ,
1184+ )
1185+ . expect ( "BeginIngest should encode" ) ;
1186+
1187+ let Message :: Binary ( bytes) = encoded else {
1188+ panic ! ( "BeginIngest should encode to a binary websocket frame" ) ;
1189+ } ;
1190+
1191+ match messages:: WsMessage :: from_bytes ( & bytes) . expect ( "BeginIngest should decode" ) {
1192+ messages:: WsMessage :: BeginIngest {
1193+ unique_id,
1194+ n_chunks,
1195+ } => {
1196+ assert_eq ! ( unique_id, 42 ) ;
1197+ assert_eq ! ( n_chunks, 7 ) ;
1198+ }
1199+ other => panic ! ( "unexpected decoded message: {other:?}" ) ,
1200+ }
1201+ }
1202+ }
1203+
11511204async fn connect ( ) -> anyhow:: Result < WebSocketStream < MaybeTlsStream < TcpStream > > > {
11521205 let remote_host = crate :: lts2_sys:: lts2_client:: get_remote_host ( ) ;
11531206 let target = format ! ( "wss://{}:443/shaper_gateway/ws" , & remote_host) ;
0 commit comments