Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
309 changes: 281 additions & 28 deletions crates/wind-tuic/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

use std::{
collections::HashMap,
future::Future,
net::SocketAddr,
sync::{
Arc,
Expand Down Expand Up @@ -64,6 +65,21 @@ async fn spawn_logged(label: &str, fut: impl std::future::Future<Output = eyre::
}
}

fn spawn_connection_child<F>(cancel: CancellationToken, fut: F) -> tokio::task::JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
tokio::spawn(
async move {
tokio::select! {
_ = cancel.cancelled() => {}
_ = fut => {}
}
}
.in_current_span(),
)
}

/// Wait for the connection to be authenticated. Returns `true` once an
/// [`AuthState`] is set; returns `false` if the auth timeout elapses first.
/// Callers that get `false` must drop the request.
Expand Down Expand Up @@ -412,6 +428,7 @@ pub async fn serve_connection<C, CB>(
let conn = connection.clone();
let cb = callback.clone();
let dg_cancel = acceptor_cancel.clone();
let dg_task_cancel = dg_cancel.clone();
tokio::spawn(
async move {
acceptor_loop(
Expand All @@ -421,12 +438,16 @@ pub async fn serve_connection<C, CB>(
|datagram| {
let conn = conn.clone();
let cb = cb.clone();
let task_cancel = dg_task_cancel.clone();
async move {
if datagram.len() < 2 || !is_tuic_prefix([datagram[0], datagram[1]]) {
return;
}
if conn.auth.load().is_some() {
tokio::spawn(spawn_logged("Datagram", handle_datagram(conn, datagram, cb)).in_current_span());
spawn_connection_child(
task_cancel,
spawn_logged("Datagram", handle_datagram(conn, datagram, cb)),
);
} else if let Err(e) = handle_datagram(conn, datagram, cb).await {
error!("Datagram error: {e:?}");
}
Expand All @@ -447,6 +468,7 @@ pub async fn serve_connection<C, CB>(
let conn = connection.clone();
let cb = callback.clone();
let uni_cancel = acceptor_cancel.clone();
let uni_task_cancel = uni_cancel.clone();
let h3 = h3.clone();
let active = h3_active.clone();
tokio::spawn(
Expand All @@ -460,22 +482,20 @@ pub async fn serve_connection<C, CB>(
let cb = cb.clone();
let h3 = h3.clone();
let active = active.clone();
let task_cancel = uni_task_cancel.clone();
async move {
tokio::spawn(
async move {
let mut recv = recv;
let Some(prefix) = read_prefix(&mut recv).await else { return };
let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv);
if is_tuic_prefix(prefix) {
if let Err(e) = handle_uni_stream(conn, recv, cb).await {
error!("Uni stream error: {e:?}");
}
} else {
route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Uni(recv));
spawn_connection_child(task_cancel, async move {
let mut recv = recv;
let Some(prefix) = read_prefix(&mut recv).await else { return };
let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv);
if is_tuic_prefix(prefix) {
if let Err(e) = handle_uni_stream(conn, recv, cb).await {
error!("Uni stream error: {e:?}");
}
} else {
route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Uni(recv));
}
.in_current_span(),
);
});
}
},
)
Expand All @@ -490,6 +510,7 @@ pub async fn serve_connection<C, CB>(
let conn = connection.clone();
let cb = callback.clone();
let bi_cancel = acceptor_cancel.clone();
let bi_task_cancel = bi_cancel.clone();
let h3 = h3.clone();
let active = h3_active.clone();
tokio::spawn(
Expand All @@ -503,22 +524,20 @@ pub async fn serve_connection<C, CB>(
let cb = cb.clone();
let h3 = h3.clone();
let active = active.clone();
let task_cancel = bi_task_cancel.clone();
async move {
tokio::spawn(
async move {
let mut recv = recv;
let Some(prefix) = read_prefix(&mut recv).await else { return };
let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv);
if is_tuic_prefix(prefix) {
if let Err(e) = handle_bi_stream(conn, send, recv, cb).await {
error!("Bi stream error: {e:?}");
}
} else {
route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Bi(send, recv));
spawn_connection_child(task_cancel, async move {
let mut recv = recv;
let Some(prefix) = read_prefix(&mut recv).await else { return };
let recv = wind_quic::PrefixedRecv::new(bytes::Bytes::copy_from_slice(&prefix), recv);
if is_tuic_prefix(prefix) {
if let Err(e) = handle_bi_stream(conn, send, recv, cb).await {
error!("Bi stream error: {e:?}");
}
} else {
route_non_tuic(&conn, h3.as_ref(), &active, H3Stream::Bi(send, recv));
}
.in_current_span(),
);
});
}
},
)
Expand All @@ -539,6 +558,9 @@ pub async fn serve_connection<C, CB>(
}
}
acceptor_cancel.cancel();
connection.udp_root_cancel.cancel();
connection.udp_sessions.invalidate_all();
connection.udp_sessions.run_pending_tasks().await;

// Drop this connection from the live-connection registry (no-op if it never
// authenticated and so was never registered).
Expand Down Expand Up @@ -1125,13 +1147,244 @@ async fn handle_dissociate<C: QuicConnection>(connection: &InboundCtx<C>, assoc_

#[cfg(test)]
mod tests {
use std::sync::atomic::AtomicUsize;
use std::{
future, io,
pin::Pin,
sync::{
Arc,
atomic::{AtomicBool, AtomicUsize},
},
task::{Context, Poll},
};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use wind_core::{InboundCallback, types::TargetAddr, udp::UdpStream as CoreUdpStream};
use wind_quic::{QuicRecvStream, QuicSendStream};

// Brings in Arc, Duration, Ordering, CancellationToken, QuicError, CmdType,
// and the private helpers under test (`acceptor_loop`, `is_tuic_prefix`,
// `read_prefix`).
use super::*;

struct DummyQuicStream(tokio::io::DuplexStream);

impl DummyQuicStream {
fn pair() -> (Self, Self) {
let (a, b) = tokio::io::duplex(64);
(Self(a), Self(b))
}
}

impl AsyncRead for DummyQuicStream {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

impl AsyncWrite for DummyQuicStream {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}

impl QuicSendStream for DummyQuicStream {
fn finish(&mut self) -> Result<(), QuicError> {
Ok(())
}

fn reset(&mut self, _code: u64) {}

fn id(&self) -> u64 {
0
}
}

impl QuicRecvStream for DummyQuicStream {
fn stop(&mut self, _code: u64) {}

fn id(&self) -> u64 {
0
}
}

#[derive(Clone)]
struct DummyConn;

impl QuicConnection for DummyConn {
type RecvStream = DummyQuicStream;
type SendStream = DummyQuicStream;

async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), QuicError> {
Ok(DummyQuicStream::pair())
}

async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), QuicError> {
future::pending().await
}

async fn open_uni(&self) -> Result<Self::SendStream, QuicError> {
Ok(DummyQuicStream::pair().0)
}

async fn accept_uni(&self) -> Result<Self::RecvStream, QuicError> {
future::pending().await
}

fn send_datagram(&self, _data: bytes::Bytes) -> Result<(), QuicError> {
Ok(())
}

async fn read_datagram(&self) -> Result<bytes::Bytes, QuicError> {
future::pending().await
}

fn max_datagram_size(&self) -> Option<usize> {
Some(1200)
}

async fn export_keying_material(&self, _out: &mut [u8], _label: &[u8], _context: &[u8]) -> Result<(), QuicError> {
Ok(())
}

fn close(&self, _code: u32, _reason: &[u8]) {}

async fn closed(&self) {
future::pending::<()>().await;
}
}

#[derive(Clone)]
struct HangingUdpCallback {
started: Arc<Notify>,
dropped: Arc<Notify>,
was_dropped: Arc<AtomicBool>,
}

struct NotifyOnDrop {
dropped: Arc<Notify>,
was_dropped: Arc<AtomicBool>,
}

impl Drop for NotifyOnDrop {
fn drop(&mut self) {
self.was_dropped.store(true, Ordering::SeqCst);
self.dropped.notify_waiters();
}
}

#[tokio::test]
async fn connection_child_task_is_dropped_on_cancel() {
let cancel = CancellationToken::new();
let dropped = Arc::new(Notify::new());
let was_dropped = Arc::new(AtomicBool::new(false));
let task = {
let dropped = dropped.clone();
let was_dropped = was_dropped.clone();
spawn_connection_child(cancel.clone(), async move {
let _guard = NotifyOnDrop { dropped, was_dropped };
future::pending::<()>().await;
})
};

tokio::task::yield_now().await;
assert!(!was_dropped.load(Ordering::SeqCst));

let dropped_notified = dropped.notified();
cancel.cancel();
tokio::time::timeout(Duration::from_secs(1), task)
.await
.expect("connection child task did not exit after cancellation")
.expect("connection child task panicked");
tokio::time::timeout(Duration::from_secs(1), dropped_notified)
.await
.expect("connection child future was not dropped");
assert!(was_dropped.load(Ordering::SeqCst));
}

impl InboundCallback for HangingUdpCallback {
async fn handle_tcpstream(
&self,
_target_addr: TargetAddr,
_stream: impl wind_core::tcp::AbstractTcpStream + 'static,
) -> eyre::Result<()> {
Ok(())
}

async fn handle_udpstream(&self, _udp_stream: CoreUdpStream) -> eyre::Result<()> {
let _guard = NotifyOnDrop {
dropped: self.dropped.clone(),
was_dropped: self.was_dropped.clone(),
};
self.started.notify_waiters();
future::pending().await
}
}

#[tokio::test]
async fn udp_sessions_are_cancelled_when_connection_tears_down() {
let udp_root_cancel = CancellationToken::new();
let started = Arc::new(Notify::new());
let dropped = Arc::new(Notify::new());
let was_dropped = Arc::new(AtomicBool::new(false));
let cb = HangingUdpCallback {
started: started.clone(),
dropped: dropped.clone(),
was_dropped: was_dropped.clone(),
};

let eviction_cancel = move |_k: Arc<u16>, v: UdpSession<DummyConn>, _cause| -> moka::notification::ListenerFuture {
Box::pin(async move {
v.cancel.cancel();
})
};
let ctx = Arc::new(InboundCtx {
conn: DummyConn,
conn_span: tracing::Span::none(),
auth: ArcSwapOption::from(Some(Arc::new(AuthState {
user: UserId::from("test-user"),
}))),
auth_notify: Arc::new(Notify::new()),
users: Arc::new(HashMap::new()),
auth_timeout: Duration::from_secs(1),
udp_sessions: Cache::builder()
.max_capacity(MAX_UDP_SESSIONS_PER_CONN)
.async_eviction_listener(eviction_cancel)
.build(),
udp_root_cancel: udp_root_cancel.clone(),
hooks: InboundHooks::default(),
conn_info: ConnInfo {
remote_addr: "127.0.0.1:12345".parse().unwrap(),
protocol: Protocol::Tuic,
conn_id: 1,
},
active: None,
conn_cancel: CancellationToken::new(),
});

let _stream = get_or_create_session(&ctx, 7, &cb).await.unwrap();
tokio::time::timeout(Duration::from_secs(1), started.notified())
.await
.expect("UDP callback did not start");
assert!(!was_dropped.load(Ordering::SeqCst));

udp_root_cancel.cancel();
ctx.udp_sessions.invalidate_all();
ctx.udp_sessions.run_pending_tasks().await;

tokio::time::timeout(Duration::from_secs(1), dropped.notified())
.await
.expect("UDP callback future was not dropped after teardown");
assert!(was_dropped.load(Ordering::SeqCst));
}

/// Cancellation must interrupt an accept that is parked forever. This is
/// the core of the graceful-shutdown chain: every per-connection acceptor
/// loop is blocked in `accept()` when shutdown fires, and must unstick
Expand Down
Loading