use std::future::Future; #[cfg(unix)] use std::os::unix::io::{AsRawFd, RawFd}; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; #[cfg(feature = "early-data")] use std::task::Waker; use std::task::{Context, Poll}; use std::{ io::{self, BufRead as _}, sync::Arc, }; use rustls::{pki_types::ServerName, ClientConfig, ClientConnection}; use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use crate::common::{IoSession, MidHandshake, Stream, TlsState}; /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. #[derive(Clone)] pub struct TlsConnector { inner: Arc, #[cfg(feature = "early-data")] early_data: bool, } impl TlsConnector { /// Enable 0-RTT. /// /// If you want to use 0-RTT, /// You must also set `ClientConfig.enable_early_data` to `true`. #[cfg(feature = "early-data")] pub fn early_data(mut self, flag: bool) -> Self { self.early_data = flag; self } #[inline] pub fn connect(&self, domain: ServerName<'static>, stream: IO) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, { self.connect_impl(domain, stream, None, |_| ()) } #[inline] pub fn connect_with(&self, domain: ServerName<'static>, stream: IO, f: F) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ClientConnection), { self.connect_impl(domain, stream, None, f) } fn connect_impl( &self, domain: ServerName<'static>, stream: IO, alpn_protocols: Option>>, f: F, ) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ClientConnection), { let alpn = alpn_protocols.unwrap_or_else(|| self.inner.alpn_protocols.clone()); let mut session = match ClientConnection::new_with_alpn(self.inner.clone(), domain, alpn) { Ok(session) => session, Err(error) => { return Connect(MidHandshake::Error { io: stream, // TODO(eliza): should this really return an `io::Error`? // Probably not... error: io::Error::new(io::ErrorKind::Other, error), }); } }; f(&mut session); Connect(MidHandshake::Handshaking(TlsStream { io: stream, #[cfg(not(feature = "early-data"))] state: TlsState::Stream, #[cfg(feature = "early-data")] state: if self.early_data && session.early_data().is_some() { TlsState::EarlyData(0, Vec::new()) } else { TlsState::Stream }, need_flush: false, #[cfg(feature = "early-data")] early_waker: None, session, })) } pub fn with_alpn(&self, alpn_protocols: Vec>) -> TlsConnectorWithAlpn<'_> { TlsConnectorWithAlpn { inner: self, alpn_protocols, } } /// Get a read-only reference to underlying config pub fn config(&self) -> &Arc { &self.inner } } impl From> for TlsConnector { fn from(inner: Arc) -> Self { Self { inner, #[cfg(feature = "early-data")] early_data: false, } } } pub struct TlsConnectorWithAlpn<'c> { inner: &'c TlsConnector, alpn_protocols: Vec>, } impl TlsConnectorWithAlpn<'_> { #[inline] pub fn connect(self, domain: ServerName<'static>, stream: IO) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, { self.inner .connect_impl(domain, stream, Some(self.alpn_protocols), |_| ()) } #[inline] pub fn connect_with(self, domain: ServerName<'static>, stream: IO, f: F) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ClientConnection), { self.inner .connect_impl(domain, stream, Some(self.alpn_protocols), f) } } /// Future returned from `TlsConnector::connect` which will resolve /// once the connection handshake has finished. pub struct Connect(MidHandshake>); impl Connect { #[inline] pub fn into_fallible(self) -> FallibleConnect { FallibleConnect(self.0) } pub fn get_ref(&self) -> Option<&IO> { match &self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), MidHandshake::End => None, } } pub fn get_mut(&mut self) -> Option<&mut IO> { match &mut self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), MidHandshake::End => None, } } } impl Future for Connect { type Output = io::Result>; #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) } } impl Future for FallibleConnect { type Output = Result, (io::Error, IO)>; #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(&mut self.0).poll(cx) } } /// Like [Connect], but returns `IO` on failure. pub struct FallibleConnect(MidHandshake>); /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. #[derive(Debug)] pub struct TlsStream { pub(crate) io: IO, pub(crate) session: ClientConnection, pub(crate) state: TlsState, pub(crate) need_flush: bool, #[cfg(feature = "early-data")] pub(crate) early_waker: Option, } impl TlsStream { #[inline] pub fn get_ref(&self) -> (&IO, &ClientConnection) { (&self.io, &self.session) } #[inline] pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) { (&mut self.io, &mut self.session) } #[inline] pub fn into_inner(self) -> (IO, ClientConnection) { (self.io, self.session) } } #[cfg(unix)] impl AsRawFd for TlsStream where S: AsRawFd, { fn as_raw_fd(&self) -> RawFd { self.get_ref().0.as_raw_fd() } } #[cfg(windows)] impl AsRawSocket for TlsStream where S: AsRawSocket, { fn as_raw_socket(&self) -> RawSocket { self.get_ref().0.as_raw_socket() } } impl IoSession for TlsStream { type Io = IO; type Session = ClientConnection; #[inline] fn skip_handshake(&self) -> bool { self.state.is_early_data() } #[inline] fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool) { ( &mut self.state, &mut self.io, &mut self.session, &mut self.need_flush, ) } #[inline] fn into_io(self) -> Self::Io { self.io } } #[cfg(feature = "early-data")] impl TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { fn poll_early_data(&mut self, cx: &mut Context<'_>) { // In the EarlyData state, we have not really established a Tls connection. // Before writing data through `AsyncWrite` and completing the tls handshake, // we ignore read readiness and return to pending. // // In order to avoid event loss, // we need to register a waker and wake it up after tls is connected. if self .early_waker .as_ref() .filter(|waker| cx.waker().will_wake(waker)) .is_none() { self.early_waker = Some(cx.waker().clone()); } } } impl AsyncRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let data = ready!(self.as_mut().poll_fill_buf(cx))?; let len = data.len().min(buf.remaining()); buf.put_slice(&data[..len]); self.consume(len); Poll::Ready(Ok(())) } } impl AsyncBufRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.state { #[cfg(feature = "early-data")] TlsState::EarlyData(..) => { self.get_mut().poll_early_data(cx); Poll::Pending } TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); let stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); match stream.poll_fill_buf(cx) { Poll::Ready(Ok(buf)) => { if buf.is_empty() { this.state.shutdown_read(); } Poll::Ready(Ok(buf)) } Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { this.state.shutdown_read(); Poll::Ready(Err(err)) } output => output, } } TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])), } } fn consume(mut self: Pin<&mut Self>, amt: usize) { self.session.reader().consume(amt); } } impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { /// Note: that it does not guarantee the final data to be sent. /// To be cautious, you must manually call `flush`. fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()) .set_need_flush(this.need_flush); #[cfg(feature = "early-data")] { let bufs = [io::IoSlice::new(buf)]; let written = poll_handle_early_data( &mut this.state, &mut stream, &mut this.early_waker, cx, &bufs, )?; match written { Poll::Ready(0) => {} Poll::Ready(written) => return Poll::Ready(Ok(written)), Poll::Pending => { this.need_flush = stream.need_flush; return Poll::Pending; } } } stream.as_mut_pin().poll_write(cx, buf) } /// Note: that it does not guarantee the final data to be sent. /// To be cautious, you must manually call `flush`. fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()) .set_need_flush(this.need_flush); #[cfg(feature = "early-data")] { let written = poll_handle_early_data( &mut this.state, &mut stream, &mut this.early_waker, cx, bufs, )?; match written { Poll::Ready(0) => {} Poll::Ready(written) => return Poll::Ready(Ok(written)), Poll::Pending => { this.need_flush = stream.need_flush; return Poll::Pending; } } } stream.as_mut_pin().poll_write_vectored(cx, bufs) } #[inline] fn is_write_vectored(&self) -> bool { true } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()) .set_need_flush(this.need_flush); #[cfg(feature = "early-data")] { let written = poll_handle_early_data( &mut this.state, &mut stream, &mut this.early_waker, cx, &[], )?; if written.is_pending() { this.need_flush = stream.need_flush; return Poll::Pending; } } stream.as_mut_pin().poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { #[cfg(feature = "early-data")] { // complete handshake if matches!(self.state, TlsState::EarlyData(..)) { ready!(self.as_mut().poll_flush(cx))?; } } if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); } let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); stream.as_mut_pin().poll_shutdown(cx) } } #[cfg(feature = "early-data")] fn poll_handle_early_data( state: &mut TlsState, stream: &mut Stream, early_waker: &mut Option, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> where IO: AsyncRead + AsyncWrite + Unpin, { if let TlsState::EarlyData(pos, data) = state { use std::io::Write; // write early data if let Some(mut early_data) = stream.session.early_data() { let mut written = 0; for buf in bufs { if buf.is_empty() { continue; } let len = match early_data.write(buf) { Ok(0) => break, Ok(n) => n, Err(err) => return Poll::Ready(Err(err)), }; written += len; data.extend_from_slice(&buf[..len]); if len < buf.len() { break; } } if written != 0 { return Poll::Ready(Ok(written)); } } // complete handshake while stream.session.is_handshaking() { ready!(stream.handshake(cx))?; } // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; *pos += len; } } // end *state = TlsState::Stream; if let Some(waker) = early_waker.take() { waker.wake(); } } Poll::Ready(Ok(0)) }