use std::future::Future; use std::io::{self, BufRead as _}; #[cfg(unix)] use std::os::unix::io::{AsRawFd, RawFd}; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use rustls::server::AcceptedAlert; use rustls::{ServerConfig, ServerConnection}; use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use crate::common::{IoSession, MidHandshake, Stream, SyncReadAdapter, SyncWriteAdapter, TlsState}; /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. #[derive(Clone)] pub struct TlsAcceptor { inner: Arc, } impl From> for TlsAcceptor { fn from(inner: Arc) -> Self { Self { inner } } } impl TlsAcceptor { #[inline] pub fn accept(&self, stream: IO) -> Accept where IO: AsyncRead + AsyncWrite + Unpin, { self.accept_with(stream, |_| ()) } pub fn accept_with(&self, stream: IO, f: F) -> Accept where IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ServerConnection), { let mut session = match ServerConnection::new(self.inner.clone()) { Ok(session) => session, Err(error) => { return Accept(MidHandshake::Error { io: stream, // TODO(eliza): should this really return an `io::Error`? // Probably not... error: io::Error::new(io::ErrorKind::Other, error), }); } }; f(&mut session); Accept(MidHandshake::Handshaking(TlsStream { session, io: stream, state: TlsState::Stream, need_flush: false, })) } /// Get a read-only reference to underlying config pub fn config(&self) -> &Arc { &self.inner } } pub struct LazyConfigAcceptor { acceptor: rustls::server::Acceptor, io: Option, alert: Option<(rustls::Error, AcceptedAlert)>, } impl LazyConfigAcceptor where IO: AsyncRead + AsyncWrite + Unpin, { #[inline] pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self { Self { acceptor, io: Some(io), alert: None, } } /// Takes back the client connection. Will return `None` if called more than once or if the /// connection has been accepted. /// /// # Example /// /// ```no_run /// # fn choose_server_config( /// # _: rustls::server::ClientHello, /// # ) -> std::sync::Arc { /// # unimplemented!(); /// # } /// # #[allow(unused_variables)] /// # async fn listen() { /// use tokio::io::AsyncWriteExt; /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap(); /// let (stream, _) = listener.accept().await.unwrap(); /// /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream); /// tokio::pin!(acceptor); /// /// match acceptor.as_mut().await { /// Ok(start) => { /// let clientHello = start.client_hello(); /// let config = choose_server_config(clientHello); /// let stream = start.into_stream(config).await.unwrap(); /// // Proceed with handling the ServerConnection... /// } /// Err(err) => { /// if let Some(mut stream) = acceptor.take_io() { /// stream /// .write_all( /// format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err) /// .as_bytes() /// ) /// .await /// .unwrap(); /// } /// } /// } /// # } /// ``` pub fn take_io(&mut self) -> Option { self.io.take() } } impl Future for LazyConfigAcceptor where IO: AsyncRead + AsyncWrite + Unpin, { type Output = Result, io::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); loop { let io = match this.io.as_mut() { Some(io) => io, None => { return Poll::Ready(Err(io::Error::new( io::ErrorKind::Other, "acceptor cannot be polled after acceptance", ))) } }; if let Some((err, mut alert)) = this.alert.take() { match alert.write(&mut SyncWriteAdapter { io, cx }) { Err(e) if e.kind() == io::ErrorKind::WouldBlock => { this.alert = Some((err, alert)); return Poll::Pending; } Ok(0) | Err(_) => { return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err))) } Ok(_) => { this.alert = Some((err, alert)); continue; } }; } let mut reader = SyncReadAdapter { io, cx }; match this.acceptor.read_tls(&mut reader) { Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(), Ok(_) => {} Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, Err(e) => return Err(e).into(), } match this.acceptor.accept() { Ok(Some(accepted)) => { let io = this.io.take().unwrap(); return Poll::Ready(Ok(StartHandshake { accepted, io })); } Ok(None) => {} Err((err, alert)) => { this.alert = Some((err, alert)); } } } } } /// An incoming connection received through [`LazyConfigAcceptor`]. /// /// This contains the generic `IO` asynchronous transport, /// [`ClientHello`](rustls::server::ClientHello) data, /// and all the state required to continue the TLS handshake (e.g. via /// [`StartHandshake::into_stream`]). #[non_exhaustive] #[derive(Debug)] pub struct StartHandshake { pub accepted: rustls::server::Accepted, pub io: IO, } impl StartHandshake where IO: AsyncRead + AsyncWrite + Unpin, { /// Create a new object from an `IO` transport and prior TLS metadata. pub fn from_parts(accepted: rustls::server::Accepted, transport: IO) -> Self { Self { accepted, io: transport, } } pub fn client_hello(&self) -> rustls::server::ClientHello<'_> { self.accepted.client_hello() } pub fn into_stream(self, config: Arc) -> Accept { self.into_stream_with(config, |_| ()) } pub fn into_stream_with(self, config: Arc, f: F) -> Accept where F: FnOnce(&mut ServerConnection), { let mut conn = match self.accepted.into_connection(config) { Ok(conn) => conn, Err((error, alert)) => { return Accept(MidHandshake::SendAlert { io: self.io, alert, // TODO(eliza): should this really return an `io::Error`? // Probably not... error: io::Error::new(io::ErrorKind::InvalidData, error), }); } }; f(&mut conn); Accept(MidHandshake::Handshaking(TlsStream { session: conn, io: self.io, state: TlsState::Stream, need_flush: false, })) } } /// Future returned from `TlsAcceptor::accept` which will resolve /// once the accept handshake has finished. pub struct Accept(MidHandshake>); impl Accept { #[inline] pub fn into_fallible(self) -> FallibleAccept { FallibleAccept(self.0) } pub fn get_ref(&self) -> Option<&IO> { match &self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), MidHandshake::End => None, } } pub fn get_mut(&mut self) -> Option<&mut IO> { match &mut self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), MidHandshake::End => None, } } } impl Future for Accept { type Output = io::Result>; #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) } } /// Like [Accept], but returns `IO` on failure. pub struct FallibleAccept(MidHandshake>); impl Future for FallibleAccept { type Output = Result, (io::Error, IO)>; #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(&mut self.0).poll(cx) } } /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. #[derive(Debug)] pub struct TlsStream { pub(crate) io: IO, pub(crate) session: ServerConnection, pub(crate) state: TlsState, pub(crate) need_flush: bool, } impl TlsStream { #[inline] pub fn get_ref(&self) -> (&IO, &ServerConnection) { (&self.io, &self.session) } #[inline] pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) { (&mut self.io, &mut self.session) } #[inline] pub fn into_inner(self) -> (IO, ServerConnection) { (self.io, self.session) } } impl IoSession for TlsStream { type Io = IO; type Session = ServerConnection; #[inline] fn skip_handshake(&self) -> bool { false } #[inline] fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool) { ( &mut self.state, &mut self.io, &mut self.session, &mut self.need_flush, ) } #[inline] fn into_io(self) -> Self::Io { self.io } } impl AsyncRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let data = ready!(self.as_mut().poll_fill_buf(cx))?; let len = data.len().min(buf.remaining()); buf.put_slice(&data[..len]); self.consume(len); Poll::Ready(Ok(())) } } impl AsyncBufRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.state { TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); let stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); match stream.poll_fill_buf(cx) { Poll::Ready(Ok(buf)) => { if buf.is_empty() { this.state.shutdown_read(); } Poll::Ready(Ok(buf)) } Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { this.state.shutdown_read(); Poll::Ready(Err(err)) } output => output, } } TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])), #[cfg(feature = "early-data")] ref s => unreachable!("server TLS can not hit this state: {:?}", s), } } fn consume(mut self: Pin<&mut Self>, amt: usize) { self.session.reader().consume(amt); } } impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { /// Note: that it does not guarantee the final data to be sent. /// To be cautious, you must manually call `flush`. fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); stream.as_mut_pin().poll_write(cx, buf) } /// Note: that it does not guarantee the final data to be sent. /// To be cautious, you must manually call `flush`. fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); stream.as_mut_pin().poll_write_vectored(cx, bufs) } #[inline] fn is_write_vectored(&self) -> bool { true } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); stream.as_mut_pin().poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); } let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); stream.as_mut_pin().poll_shutdown(cx) } } #[cfg(unix)] impl AsRawFd for TlsStream where IO: AsRawFd, { fn as_raw_fd(&self) -> RawFd { self.get_ref().0.as_raw_fd() } } #[cfg(windows)] impl AsRawSocket for TlsStream where IO: AsRawSocket, { fn as_raw_socket(&self) -> RawSocket { self.get_ref().0.as_raw_socket() } }