diff --git a/crates/wind-tuic/src/server/mod.rs b/crates/wind-tuic/src/server/mod.rs index 517f8f7..f4666aa 100644 --- a/crates/wind-tuic/src/server/mod.rs +++ b/crates/wind-tuic/src/server/mod.rs @@ -13,6 +13,7 @@ use std::{ collections::HashMap, + future::Future, net::SocketAddr, sync::{ Arc, @@ -64,6 +65,21 @@ async fn spawn_logged(label: &str, fut: impl std::future::Future(cancel: CancellationToken, fut: F) -> tokio::task::JoinHandle<()> +where + F: Future + Send + 'static, +{ + tokio::spawn( + async move { + tokio::select! { + _ = cancel.cancelled() => {} + _ = fut => {} + } + } + .in_current_span(), + ) +} + /// Wait for the connection to be authenticated. Returns `true` once an /// [`AuthState`] is set; returns `false` if the auth timeout elapses first. /// Callers that get `false` must drop the request. @@ -412,6 +428,7 @@ pub async fn serve_connection( let conn = connection.clone(); let cb = callback.clone(); let dg_cancel = acceptor_cancel.clone(); + let dg_task_cancel = dg_cancel.clone(); tokio::spawn( async move { acceptor_loop( @@ -421,12 +438,16 @@ pub async fn serve_connection( |datagram| { let conn = conn.clone(); let cb = cb.clone(); + let task_cancel = dg_task_cancel.clone(); async move { if datagram.len() < 2 || !is_tuic_prefix([datagram[0], datagram[1]]) { return; } if conn.auth.load().is_some() { - tokio::spawn(spawn_logged("Datagram", handle_datagram(conn, datagram, cb)).in_current_span()); + spawn_connection_child( + task_cancel, + spawn_logged("Datagram", handle_datagram(conn, datagram, cb)), + ); } else if let Err(e) = handle_datagram(conn, datagram, cb).await { error!("Datagram error: {e:?}"); } @@ -447,6 +468,7 @@ pub async fn serve_connection( let conn = connection.clone(); let cb = callback.clone(); let uni_cancel = acceptor_cancel.clone(); + let uni_task_cancel = uni_cancel.clone(); let h3 = h3.clone(); let active = h3_active.clone(); tokio::spawn( @@ -460,22 +482,20 @@ pub async fn serve_connection( let cb = cb.clone(); let h3 = h3.clone(); let active = active.clone(); + let task_cancel = uni_task_cancel.clone(); async move { - tokio::spawn( - async move { - let mut recv = recv; - let Some(prefix) = read_prefix(&mut recv).await else { return }; - let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv); - if is_tuic_prefix(prefix) { - if let Err(e) = handle_uni_stream(conn, recv, cb).await { - error!("Uni stream error: {e:?}"); - } - } else { - route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Uni(recv)); + spawn_connection_child(task_cancel, async move { + let mut recv = recv; + let Some(prefix) = read_prefix(&mut recv).await else { return }; + let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv); + if is_tuic_prefix(prefix) { + if let Err(e) = handle_uni_stream(conn, recv, cb).await { + error!("Uni stream error: {e:?}"); } + } else { + route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Uni(recv)); } - .in_current_span(), - ); + }); } }, ) @@ -490,6 +510,7 @@ pub async fn serve_connection( let conn = connection.clone(); let cb = callback.clone(); let bi_cancel = acceptor_cancel.clone(); + let bi_task_cancel = bi_cancel.clone(); let h3 = h3.clone(); let active = h3_active.clone(); tokio::spawn( @@ -503,22 +524,20 @@ pub async fn serve_connection( let cb = cb.clone(); let h3 = h3.clone(); let active = active.clone(); + let task_cancel = bi_task_cancel.clone(); async move { - tokio::spawn( - async move { - let mut recv = recv; - let Some(prefix) = read_prefix(&mut recv).await else { return }; - let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv); - if is_tuic_prefix(prefix) { - if let Err(e) = handle_bi_stream(conn, send, recv, cb).await { - error!("Bi stream error: {e:?}"); - } - } else { - route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Bi(send, recv)); + spawn_connection_child(task_cancel, async move { + let mut recv = recv; + let Some(prefix) = read_prefix(&mut recv).await else { return }; + let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv); + if is_tuic_prefix(prefix) { + if let Err(e) = handle_bi_stream(conn, send, recv, cb).await { + error!("Bi stream error: {e:?}"); } + } else { + route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Bi(send, recv)); } - .in_current_span(), - ); + }); } }, ) @@ -539,6 +558,9 @@ pub async fn serve_connection( } } acceptor_cancel.cancel(); + connection.udp_root_cancel.cancel(); + connection.udp_sessions.invalidate_all(); + connection.udp_sessions.run_pending_tasks().await; // Drop this connection from the live-connection registry (no-op if it never // authenticated and so was never registered). @@ -1125,13 +1147,244 @@ async fn handle_dissociate(connection: &InboundCtx, assoc_ #[cfg(test)] mod tests { - use std::sync::atomic::AtomicUsize; + use std::{ + future, io, + pin::Pin, + sync::{ + Arc, + atomic::{AtomicBool, AtomicUsize}, + }, + task::{Context, Poll}, + }; + + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use wind_core::{InboundCallback, types::TargetAddr, udp::UdpStream as CoreUdpStream}; + use wind_quic::{QuicRecvStream, QuicSendStream}; // Brings in Arc, Duration, Ordering, CancellationToken, QuicError, CmdType, // and the private helpers under test (`acceptor_loop`, `is_tuic_prefix`, // `read_prefix`). use super::*; + struct DummyQuicStream(tokio::io::DuplexStream); + + impl DummyQuicStream { + fn pair() -> (Self, Self) { + let (a, b) = tokio::io::duplex(64); + (Self(a), Self(b)) + } + } + + impl AsyncRead for DummyQuicStream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } + } + + impl AsyncWrite for DummyQuicStream { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + } + + impl QuicSendStream for DummyQuicStream { + fn finish(&mut self) -> Result<(), QuicError> { + Ok(()) + } + + fn reset(&mut self, _code: u64) {} + + fn id(&self) -> u64 { + 0 + } + } + + impl QuicRecvStream for DummyQuicStream { + fn stop(&mut self, _code: u64) {} + + fn id(&self) -> u64 { + 0 + } + } + + #[derive(Clone)] + struct DummyConn; + + impl QuicConnection for DummyConn { + type RecvStream = DummyQuicStream; + type SendStream = DummyQuicStream; + + async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), QuicError> { + Ok(DummyQuicStream::pair()) + } + + async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), QuicError> { + future::pending().await + } + + async fn open_uni(&self) -> Result { + Ok(DummyQuicStream::pair().0) + } + + async fn accept_uni(&self) -> Result { + future::pending().await + } + + fn send_datagram(&self, _data: bytes::Bytes) -> Result<(), QuicError> { + Ok(()) + } + + async fn read_datagram(&self) -> Result { + future::pending().await + } + + fn max_datagram_size(&self) -> Option { + Some(1200) + } + + async fn export_keying_material(&self, _out: &mut [u8], _label: &[u8], _context: &[u8]) -> Result<(), QuicError> { + Ok(()) + } + + fn close(&self, _code: u32, _reason: &[u8]) {} + + async fn closed(&self) { + future::pending::<()>().await; + } + } + + #[derive(Clone)] + struct HangingUdpCallback { + started: Arc, + dropped: Arc, + was_dropped: Arc, + } + + struct NotifyOnDrop { + dropped: Arc, + was_dropped: Arc, + } + + impl Drop for NotifyOnDrop { + fn drop(&mut self) { + self.was_dropped.store(true, Ordering::SeqCst); + self.dropped.notify_waiters(); + } + } + + #[tokio::test] + async fn connection_child_task_is_dropped_on_cancel() { + let cancel = CancellationToken::new(); + let dropped = Arc::new(Notify::new()); + let was_dropped = Arc::new(AtomicBool::new(false)); + let task = { + let dropped = dropped.clone(); + let was_dropped = was_dropped.clone(); + spawn_connection_child(cancel.clone(), async move { + let _guard = NotifyOnDrop { dropped, was_dropped }; + future::pending::<()>().await; + }) + }; + + tokio::task::yield_now().await; + assert!(!was_dropped.load(Ordering::SeqCst)); + + let dropped_notified = dropped.notified(); + cancel.cancel(); + tokio::time::timeout(Duration::from_secs(1), task) + .await + .expect("connection child task did not exit after cancellation") + .expect("connection child task panicked"); + tokio::time::timeout(Duration::from_secs(1), dropped_notified) + .await + .expect("connection child future was not dropped"); + assert!(was_dropped.load(Ordering::SeqCst)); + } + + impl InboundCallback for HangingUdpCallback { + async fn handle_tcpstream( + &self, + _target_addr: TargetAddr, + _stream: impl wind_core::tcp::AbstractTcpStream + 'static, + ) -> eyre::Result<()> { + Ok(()) + } + + async fn handle_udpstream(&self, _udp_stream: CoreUdpStream) -> eyre::Result<()> { + let _guard = NotifyOnDrop { + dropped: self.dropped.clone(), + was_dropped: self.was_dropped.clone(), + }; + self.started.notify_waiters(); + future::pending().await + } + } + + #[tokio::test] + async fn udp_sessions_are_cancelled_when_connection_tears_down() { + let udp_root_cancel = CancellationToken::new(); + let started = Arc::new(Notify::new()); + let dropped = Arc::new(Notify::new()); + let was_dropped = Arc::new(AtomicBool::new(false)); + let cb = HangingUdpCallback { + started: started.clone(), + dropped: dropped.clone(), + was_dropped: was_dropped.clone(), + }; + + let eviction_cancel = move |_k: Arc, v: UdpSession, _cause| -> moka::notification::ListenerFuture { + Box::pin(async move { + v.cancel.cancel(); + }) + }; + let ctx = Arc::new(InboundCtx { + conn: DummyConn, + conn_span: tracing::Span::none(), + auth: ArcSwapOption::from(Some(Arc::new(AuthState { + user: UserId::from("test-user"), + }))), + auth_notify: Arc::new(Notify::new()), + users: Arc::new(HashMap::new()), + auth_timeout: Duration::from_secs(1), + udp_sessions: Cache::builder() + .max_capacity(MAX_UDP_SESSIONS_PER_CONN) + .async_eviction_listener(eviction_cancel) + .build(), + udp_root_cancel: udp_root_cancel.clone(), + hooks: InboundHooks::default(), + conn_info: ConnInfo { + remote_addr: "127.0.0.1:12345".parse().unwrap(), + protocol: Protocol::Tuic, + conn_id: 1, + }, + active: None, + conn_cancel: CancellationToken::new(), + }); + + let _stream = get_or_create_session(&ctx, 7, &cb).await.unwrap(); + tokio::time::timeout(Duration::from_secs(1), started.notified()) + .await + .expect("UDP callback did not start"); + assert!(!was_dropped.load(Ordering::SeqCst)); + + udp_root_cancel.cancel(); + ctx.udp_sessions.invalidate_all(); + ctx.udp_sessions.run_pending_tasks().await; + + tokio::time::timeout(Duration::from_secs(1), dropped.notified()) + .await + .expect("UDP callback future was not dropped after teardown"); + assert!(was_dropped.load(Ordering::SeqCst)); + } + /// Cancellation must interrupt an accept that is parked forever. This is /// the core of the graceful-shutdown chain: every per-connection acceptor /// loop is blocked in `accept()` when shutdown fires, and must unstick