From d6bfa0fb7e4d8089d01975feab16287a0e8840b8 Mon Sep 17 00:00:00 2001 From: iHsin Date: Thu, 25 Jun 2026 19:05:41 +0800 Subject: [PATCH] fix(tuic): clean up connection child resources Cancel spawned per-connection stream/datagram handlers when the parent connection ends, so TCP relay futures cannot outlive the QUIC connection. Also cancel and invalidate UDP association sessions during connection teardown so outbound UDP bridge tasks and sockets are released promptly. Assisted-by: Codex --- crates/wind-tuic/src/server/mod.rs | 309 ++++++++++++++++++++++++++--- 1 file changed, 281 insertions(+), 28 deletions(-) diff --git a/crates/wind-tuic/src/server/mod.rs b/crates/wind-tuic/src/server/mod.rs index 517f8f7..f4666aa 100644 --- a/crates/wind-tuic/src/server/mod.rs +++ b/crates/wind-tuic/src/server/mod.rs @@ -13,6 +13,7 @@ use std::{ collections::HashMap, + future::Future, net::SocketAddr, sync::{ Arc, @@ -64,6 +65,21 @@ async fn spawn_logged(label: &str, fut: impl std::future::Future(cancel: CancellationToken, fut: F) -> tokio::task::JoinHandle<()> +where + F: Future + Send + 'static, +{ + tokio::spawn( + async move { + tokio::select! { + _ = cancel.cancelled() => {} + _ = fut => {} + } + } + .in_current_span(), + ) +} + /// Wait for the connection to be authenticated. Returns `true` once an /// [`AuthState`] is set; returns `false` if the auth timeout elapses first. /// Callers that get `false` must drop the request. @@ -412,6 +428,7 @@ pub async fn serve_connection( let conn = connection.clone(); let cb = callback.clone(); let dg_cancel = acceptor_cancel.clone(); + let dg_task_cancel = dg_cancel.clone(); tokio::spawn( async move { acceptor_loop( @@ -421,12 +438,16 @@ pub async fn serve_connection( |datagram| { let conn = conn.clone(); let cb = cb.clone(); + let task_cancel = dg_task_cancel.clone(); async move { if datagram.len() < 2 || !is_tuic_prefix([datagram[0], datagram[1]]) { return; } if conn.auth.load().is_some() { - tokio::spawn(spawn_logged("Datagram", handle_datagram(conn, datagram, cb)).in_current_span()); + spawn_connection_child( + task_cancel, + spawn_logged("Datagram", handle_datagram(conn, datagram, cb)), + ); } else if let Err(e) = handle_datagram(conn, datagram, cb).await { error!("Datagram error: {e:?}"); } @@ -447,6 +468,7 @@ pub async fn serve_connection( let conn = connection.clone(); let cb = callback.clone(); let uni_cancel = acceptor_cancel.clone(); + let uni_task_cancel = uni_cancel.clone(); let h3 = h3.clone(); let active = h3_active.clone(); tokio::spawn( @@ -460,22 +482,20 @@ pub async fn serve_connection( let cb = cb.clone(); let h3 = h3.clone(); let active = active.clone(); + let task_cancel = uni_task_cancel.clone(); async move { - tokio::spawn( - async move { - let mut recv = recv; - let Some(prefix) = read_prefix(&mut recv).await else { return }; - let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv); - if is_tuic_prefix(prefix) { - if let Err(e) = handle_uni_stream(conn, recv, cb).await { - error!("Uni stream error: {e:?}"); - } - } else { - route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Uni(recv)); + spawn_connection_child(task_cancel, async move { + let mut recv = recv; + let Some(prefix) = read_prefix(&mut recv).await else { return }; + let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv); + if is_tuic_prefix(prefix) { + if let Err(e) = handle_uni_stream(conn, recv, cb).await { + error!("Uni stream error: {e:?}"); } + } else { + route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Uni(recv)); } - .in_current_span(), - ); + }); } }, ) @@ -490,6 +510,7 @@ pub async fn serve_connection( let conn = connection.clone(); let cb = callback.clone(); let bi_cancel = acceptor_cancel.clone(); + let bi_task_cancel = bi_cancel.clone(); let h3 = h3.clone(); let active = h3_active.clone(); tokio::spawn( @@ -503,22 +524,20 @@ pub async fn serve_connection( let cb = cb.clone(); let h3 = h3.clone(); let active = active.clone(); + let task_cancel = bi_task_cancel.clone(); async move { - tokio::spawn( - async move { - let mut recv = recv; - let Some(prefix) = read_prefix(&mut recv).await else { return }; - let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv); - if is_tuic_prefix(prefix) { - if let Err(e) = handle_bi_stream(conn, send, recv, cb).await { - error!("Bi stream error: {e:?}"); - } - } else { - route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Bi(send, recv)); + spawn_connection_child(task_cancel, async move { + let mut recv = recv; + let Some(prefix) = read_prefix(&mut recv).await else { return }; + let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv); + if is_tuic_prefix(prefix) { + if let Err(e) = handle_bi_stream(conn, send, recv, cb).await { + error!("Bi stream error: {e:?}"); } + } else { + route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Bi(send, recv)); } - .in_current_span(), - ); + }); } }, ) @@ -539,6 +558,9 @@ pub async fn serve_connection( } } acceptor_cancel.cancel(); + connection.udp_root_cancel.cancel(); + connection.udp_sessions.invalidate_all(); + connection.udp_sessions.run_pending_tasks().await; // Drop this connection from the live-connection registry (no-op if it never // authenticated and so was never registered). @@ -1125,13 +1147,244 @@ async fn handle_dissociate(connection: &InboundCtx, assoc_ #[cfg(test)] mod tests { - use std::sync::atomic::AtomicUsize; + use std::{ + future, io, + pin::Pin, + sync::{ + Arc, + atomic::{AtomicBool, AtomicUsize}, + }, + task::{Context, Poll}, + }; + + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use wind_core::{InboundCallback, types::TargetAddr, udp::UdpStream as CoreUdpStream}; + use wind_quic::{QuicRecvStream, QuicSendStream}; // Brings in Arc, Duration, Ordering, CancellationToken, QuicError, CmdType, // and the private helpers under test (`acceptor_loop`, `is_tuic_prefix`, // `read_prefix`). use super::*; + struct DummyQuicStream(tokio::io::DuplexStream); + + impl DummyQuicStream { + fn pair() -> (Self, Self) { + let (a, b) = tokio::io::duplex(64); + (Self(a), Self(b)) + } + } + + impl AsyncRead for DummyQuicStream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } + } + + impl AsyncWrite for DummyQuicStream { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + } + + impl QuicSendStream for DummyQuicStream { + fn finish(&mut self) -> Result<(), QuicError> { + Ok(()) + } + + fn reset(&mut self, _code: u64) {} + + fn id(&self) -> u64 { + 0 + } + } + + impl QuicRecvStream for DummyQuicStream { + fn stop(&mut self, _code: u64) {} + + fn id(&self) -> u64 { + 0 + } + } + + #[derive(Clone)] + struct DummyConn; + + impl QuicConnection for DummyConn { + type RecvStream = DummyQuicStream; + type SendStream = DummyQuicStream; + + async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), QuicError> { + Ok(DummyQuicStream::pair()) + } + + async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), QuicError> { + future::pending().await + } + + async fn open_uni(&self) -> Result { + Ok(DummyQuicStream::pair().0) + } + + async fn accept_uni(&self) -> Result { + future::pending().await + } + + fn send_datagram(&self, _data: bytes::Bytes) -> Result<(), QuicError> { + Ok(()) + } + + async fn read_datagram(&self) -> Result { + future::pending().await + } + + fn max_datagram_size(&self) -> Option { + Some(1200) + } + + async fn export_keying_material(&self, _out: &mut [u8], _label: &[u8], _context: &[u8]) -> Result<(), QuicError> { + Ok(()) + } + + fn close(&self, _code: u32, _reason: &[u8]) {} + + async fn closed(&self) { + future::pending::<()>().await; + } + } + + #[derive(Clone)] + struct HangingUdpCallback { + started: Arc, + dropped: Arc, + was_dropped: Arc, + } + + struct NotifyOnDrop { + dropped: Arc, + was_dropped: Arc, + } + + impl Drop for NotifyOnDrop { + fn drop(&mut self) { + self.was_dropped.store(true, Ordering::SeqCst); + self.dropped.notify_waiters(); + } + } + + #[tokio::test] + async fn connection_child_task_is_dropped_on_cancel() { + let cancel = CancellationToken::new(); + let dropped = Arc::new(Notify::new()); + let was_dropped = Arc::new(AtomicBool::new(false)); + let task = { + let dropped = dropped.clone(); + let was_dropped = was_dropped.clone(); + spawn_connection_child(cancel.clone(), async move { + let _guard = NotifyOnDrop { dropped, was_dropped }; + future::pending::<()>().await; + }) + }; + + tokio::task::yield_now().await; + assert!(!was_dropped.load(Ordering::SeqCst)); + + let dropped_notified = dropped.notified(); + cancel.cancel(); + tokio::time::timeout(Duration::from_secs(1), task) + .await + .expect("connection child task did not exit after cancellation") + .expect("connection child task panicked"); + tokio::time::timeout(Duration::from_secs(1), dropped_notified) + .await + .expect("connection child future was not dropped"); + assert!(was_dropped.load(Ordering::SeqCst)); + } + + impl InboundCallback for HangingUdpCallback { + async fn handle_tcpstream( + &self, + _target_addr: TargetAddr, + _stream: impl wind_core::tcp::AbstractTcpStream + 'static, + ) -> eyre::Result<()> { + Ok(()) + } + + async fn handle_udpstream(&self, _udp_stream: CoreUdpStream) -> eyre::Result<()> { + let _guard = NotifyOnDrop { + dropped: self.dropped.clone(), + was_dropped: self.was_dropped.clone(), + }; + self.started.notify_waiters(); + future::pending().await + } + } + + #[tokio::test] + async fn udp_sessions_are_cancelled_when_connection_tears_down() { + let udp_root_cancel = CancellationToken::new(); + let started = Arc::new(Notify::new()); + let dropped = Arc::new(Notify::new()); + let was_dropped = Arc::new(AtomicBool::new(false)); + let cb = HangingUdpCallback { + started: started.clone(), + dropped: dropped.clone(), + was_dropped: was_dropped.clone(), + }; + + let eviction_cancel = move |_k: Arc, v: UdpSession, _cause| -> moka::notification::ListenerFuture { + Box::pin(async move { + v.cancel.cancel(); + }) + }; + let ctx = Arc::new(InboundCtx { + conn: DummyConn, + conn_span: tracing::Span::none(), + auth: ArcSwapOption::from(Some(Arc::new(AuthState { + user: UserId::from("test-user"), + }))), + auth_notify: Arc::new(Notify::new()), + users: Arc::new(HashMap::new()), + auth_timeout: Duration::from_secs(1), + udp_sessions: Cache::builder() + .max_capacity(MAX_UDP_SESSIONS_PER_CONN) + .async_eviction_listener(eviction_cancel) + .build(), + udp_root_cancel: udp_root_cancel.clone(), + hooks: InboundHooks::default(), + conn_info: ConnInfo { + remote_addr: "127.0.0.1:12345".parse().unwrap(), + protocol: Protocol::Tuic, + conn_id: 1, + }, + active: None, + conn_cancel: CancellationToken::new(), + }); + + let _stream = get_or_create_session(&ctx, 7, &cb).await.unwrap(); + tokio::time::timeout(Duration::from_secs(1), started.notified()) + .await + .expect("UDP callback did not start"); + assert!(!was_dropped.load(Ordering::SeqCst)); + + udp_root_cancel.cancel(); + ctx.udp_sessions.invalidate_all(); + ctx.udp_sessions.run_pending_tasks().await; + + tokio::time::timeout(Duration::from_secs(1), dropped.notified()) + .await + .expect("UDP callback future was not dropped after teardown"); + assert!(was_dropped.load(Ordering::SeqCst)); + } + /// Cancellation must interrupt an accept that is parked forever. This is /// the core of the graceful-shutdown chain: every per-connection acceptor /// loop is blocked in `accept()` when shutdown fires, and must unstick