diff --git a/core/core/src/types/context/write.rs b/core/core/src/types/context/write.rs index f965eaae0fb9..a9c464ec45fd 100644 --- a/core/core/src/types/context/write.rs +++ b/core/core/src/types/context/write.rs @@ -154,10 +154,18 @@ impl WriteGenerator { // - write buffer + bs directly. if !self.exact { let fill_size = bs.len(); + // Save old buffer before pushing bs, so we can restore only the old + // buffer on failure. The caller will retry with the same bs. + let old_buffer = self.buffer.clone(); self.buffer.push(bs); let buf = self.buffer.take().collect(); - self.w.write_dyn(buf).await?; - return Ok(fill_size); + match self.w.write_dyn(buf).await { + Ok(()) => return Ok(fill_size), + Err(err) => { + self.buffer = old_buffer; + return Err(err); + } + } } // Condition: @@ -167,8 +175,12 @@ impl WriteGenerator { // Action: // - write existing buffer in chunk_size to make more rooms for writing data. if self.buffer.len() >= chunk_size { - let buf = self.buffer.take().collect(); - self.w.write_dyn(buf).await?; + let taken = self.buffer.take(); + let buf = taken.clone().collect(); + if let Err(err) = self.w.write_dyn(buf).await { + self.buffer = taken; + return Err(err); + } } // Condition @@ -190,8 +202,12 @@ impl WriteGenerator { break; } - let buf = self.buffer.take().collect(); - self.w.write_dyn(buf).await?; + let taken = self.buffer.take(); + let buf = taken.clone().collect(); + if let Err(err) = self.w.write_dyn(buf).await { + self.buffer = taken; + return Err(err); + } } self.w.close().await @@ -548,4 +564,168 @@ mod tests { Ok(()) } + + /// A mock writer that fails the first N write calls, then succeeds. + /// Used to test that WriteGenerator retains buffered data on write failure. + struct FailOnceMockWriter { + buf: Arc>>, + fail_count: Arc>, + } + + impl Write for FailOnceMockWriter { + async fn write(&mut self, bs: Buffer) -> Result<()> { + let mut fail_count = self.fail_count.lock().await; + if *fail_count > 0 { + *fail_count -= 1; + return Err( + Error::new(ErrorKind::Unexpected, "write failed (simulated)").set_temporary(), + ); + } + drop(fail_count); + + let mut buf = self.buf.lock().await; + buf.put(bs); + Ok(()) + } + + async fn close(&mut self) -> Result { + Ok(Metadata::default()) + } + + async fn abort(&mut self) -> Result<()> { + Ok(()) + } + } + + /// Test that in inexact mode, a write failure retains only the old buffered data + /// (not the current bs), so the caller can safely retry without data duplication. + #[tokio::test] + async fn test_inexact_write_failure_retains_buffer() -> Result<()> { + setup(); + + let buf = Arc::new(Mutex::new(vec![])); + let fail_count = Arc::new(Mutex::new(1usize)); + let mut writer = WriteGenerator::new( + Box::new(FailOnceMockWriter { + buf: buf.clone(), + fail_count: fail_count.clone(), + }), + Some(10), + false, // inexact mode + ); + + // Write 5 bytes (buffered, below chunk_size=10) + let data1 = Bytes::from(vec![1u8, 2, 3, 4, 5]); + let n = writer.write(data1.into()).await?; + assert_eq!(n, 5); + + // Write 10 bytes — buffer (5) + new (10) = 15 >= chunk_size, triggers flush. + // First flush will fail. + let data2 = Bytes::from(vec![6u8, 7, 8, 9, 10, 11, 12, 13, 14, 15]); + let err = writer.write(data2.clone().into()).await; + assert!(err.is_err(), "first flush should fail"); + + // On failure, only old buffer (data1) is retained; data2 is NOT absorbed. + // Caller retries with the same data2 — now the mock writer succeeds. + let n = writer.write(data2.into()).await?; + assert_eq!(n, 10); + + writer.close().await?; + + // Verify no data was lost and no data was duplicated: exactly 15 bytes. + let buf = buf.lock().await; + assert_eq!(buf.len(), 15); + assert_eq!( + &*buf, + &[1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + ); + + Ok(()) + } + + /// Test that in exact mode, a write failure retains the buffered data + /// so that a subsequent write/close succeeds without data loss. + #[tokio::test] + async fn test_exact_write_failure_retains_buffer() -> Result<()> { + setup(); + + let buf = Arc::new(Mutex::new(vec![])); + let fail_count = Arc::new(Mutex::new(1usize)); + let mut writer = WriteGenerator::new( + Box::new(FailOnceMockWriter { + buf: buf.clone(), + fail_count: fail_count.clone(), + }), + Some(10), + true, // exact mode + ); + + // Fill buffer to exactly chunk_size + let data1 = Bytes::from(vec![1u8; 10]); + let mut remaining = data1.clone(); + while !remaining.is_empty() { + let n = writer.write(remaining.clone().into()).await?; + remaining.advance(n); + } + + // Write more data — buffer is full (10 bytes), so it flushes first. + // The first flush will fail. + let data2 = Bytes::from(vec![2u8; 5]); + let err = writer.write(data2.clone().into()).await; + assert!(err.is_err(), "first flush should fail"); + + // Retry — now succeeds. The buffer should still have the original 10 bytes. + // On retry, the flush succeeds and clears the buffer, then data2 is buffered. + let mut remaining = data2; + while !remaining.is_empty() { + let n = writer.write(remaining.clone().into()).await?; + remaining.advance(n); + } + + writer.close().await?; + + // All 15 bytes should be present. + let buf = buf.lock().await; + assert_eq!(buf.len(), 15); + assert_eq!(&buf[..10], &[1u8; 10]); + assert_eq!(&buf[10..], &[2u8; 5]); + + Ok(()) + } + + /// Test that close() retains buffered data when the underlying write fails, + /// and succeeds on retry. + #[tokio::test] + async fn test_close_failure_retains_buffer() -> Result<()> { + setup(); + + let buf = Arc::new(Mutex::new(vec![])); + let fail_count = Arc::new(Mutex::new(1usize)); + let mut writer = WriteGenerator::new( + Box::new(FailOnceMockWriter { + buf: buf.clone(), + fail_count: fail_count.clone(), + }), + Some(10), + false, + ); + + // Write 5 bytes (buffered, below chunk_size) + let data = Bytes::from(vec![42u8; 5]); + let n = writer.write(data.into()).await?; + assert_eq!(n, 5); + + // First close attempt fails during flush + let err = writer.close().await; + assert!(err.is_err(), "first close should fail"); + + // Second close attempt succeeds — buffer was retained + writer.close().await?; + + let buf = buf.lock().await; + assert_eq!(buf.len(), 5); + assert_eq!(&*buf, &[42u8; 5]); + + Ok(()) + } }