//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/rustls/rustls).
//!
//! # Why do I need to call `poll_flush`?
//!
//! Most TLS implementations have an internal buffer to improve throughput,
//! and rustls is no exception.
//!
//! When we write data to `TlsStream`, we always write to the rustls buffer first,
//! then take the encrypted data packets out of rustls and write them to the data channel
//! (like `TcpStream`). When the data channel is pending, some data may remain in the rustls buffer.
//!
//! To keep `tokio-rustls` simple and correct, [TlsStream] behaves like `BufWriter`.
//! For `TlsStream`, this means that data written by `poll_write` is not guaranteed to be written to `TcpStream`.
//! You must call `poll_flush` to ensure that it is written to `TcpStream`.
//!
//! You should call `poll_flush` at the appropriate time,
//! such as when a series of `poll_write` calls is complete and there is no more data to write.
//!
//! ## Why don't we write during `poll_read`?
//!
//! We did this in the early days of `tokio-rustls`, but it caused some bugs.
//! These bugs can be worked around, but the workarounds cause performance degradation (reverse false wakeup).
//!
//! Reverse writes would also prevent us from implementing full duplex in the future.
//!
//! see
// NOTE(review): the reference that followed "see" is missing from this copy — restore the link.
//!
//! ## Why can't we handle it like `native-tls`?
//!
//! When the data channel returns pending, `native-tls` falsely reports the number of bytes it consumed.
//! This means that if data written by `poll_write` is not actually written to the data channel,
//! it will not return `Ready`, thus avoiding the need to call `poll_flush`.
//!
//! However, this does not conform to the convention of the `AsyncWrite` trait.
//! It means that if you pass inconsistent data in two `poll_write` calls, it may cause unexpected behavior.
//!
see #![warn(unreachable_pub, clippy::use_self)] use std::io; #[cfg(unix)] use std::os::unix::io::{AsRawFd, RawFd}; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; use std::task::{Context, Poll}; pub use rustls; use rustls::CommonState; use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; macro_rules! ready { ( $e:expr ) => { match $e { std::task::Poll::Ready(t) => t, std::task::Poll::Pending => return std::task::Poll::Pending, } }; } pub mod client; pub use client::{Connect, FallibleConnect, TlsConnector, TlsConnectorWithAlpn}; mod common; pub mod server; pub use server::{Accept, FallibleAccept, LazyConfigAcceptor, StartHandshake, TlsAcceptor}; /// Unified TLS stream type /// /// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use /// a single type to keep both client- and server-initiated TLS-encrypted connections. #[allow(clippy::large_enum_variant)] // https://github.com/rust-lang/rust-clippy/issues/9798 #[derive(Debug)] pub enum TlsStream { Client(client::TlsStream), Server(server::TlsStream), } impl TlsStream { pub fn get_ref(&self) -> (&T, &CommonState) { use TlsStream::*; match self { Client(io) => { let (io, session) = io.get_ref(); (io, session) } Server(io) => { let (io, session) = io.get_ref(); (io, session) } } } pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) { use TlsStream::*; match self { Client(io) => { let (io, session) = io.get_mut(); (io, &mut *session) } Server(io) => { let (io, session) = io.get_mut(); (io, &mut *session) } } } } impl From> for TlsStream { fn from(s: client::TlsStream) -> Self { Self::Client(s) } } impl From> for TlsStream { fn from(s: server::TlsStream) -> Self { Self::Server(s) } } #[cfg(unix)] impl AsRawFd for TlsStream where S: AsRawFd, { fn as_raw_fd(&self) -> RawFd { self.get_ref().0.as_raw_fd() } } #[cfg(windows)] impl AsRawSocket for TlsStream where S: AsRawSocket, { fn as_raw_socket(&self) -> RawSocket { 
self.get_ref().0.as_raw_socket() } } impl AsyncRead for TlsStream where T: AsyncRead + AsyncWrite + Unpin, { #[inline] fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { match self.get_mut() { Self::Client(x) => Pin::new(x).poll_read(cx, buf), Self::Server(x) => Pin::new(x).poll_read(cx, buf), } } } impl AsyncBufRead for TlsStream where T: AsyncRead + AsyncWrite + Unpin, { #[inline] fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { Self::Client(x) => Pin::new(x).poll_fill_buf(cx), Self::Server(x) => Pin::new(x).poll_fill_buf(cx), } } #[inline] fn consume(self: Pin<&mut Self>, amt: usize) { match self.get_mut() { Self::Client(x) => Pin::new(x).consume(amt), Self::Server(x) => Pin::new(x).consume(amt), } } } impl AsyncWrite for TlsStream where T: AsyncRead + AsyncWrite + Unpin, { #[inline] fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match self.get_mut() { Self::Client(x) => Pin::new(x).poll_write(cx, buf), Self::Server(x) => Pin::new(x).poll_write(cx, buf), } } #[inline] fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { match self.get_mut() { Self::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs), Self::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs), } } #[inline] fn is_write_vectored(&self) -> bool { match self { Self::Client(x) => x.is_write_vectored(), Self::Server(x) => x.is_write_vectored(), } } #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { Self::Client(x) => Pin::new(x).poll_flush(cx), Self::Server(x) => Pin::new(x).poll_flush(cx), } } #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { Self::Client(x) => Pin::new(x).poll_shutdown(cx), Self::Server(x) => Pin::new(x).poll_shutdown(cx), } } }