Skip to content

Commit

Permalink
net: add TcpStream::into_std (#3189)
Browse files Browse the repository at this point in the history
  • Loading branch information
liufuyang committed Dec 6, 2020
1 parent 0dbba13 commit 0707f4c
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 2 deletions.
8 changes: 8 additions & 0 deletions tokio/src/io/poll_evented.rs
Expand Up @@ -124,6 +124,14 @@ impl<E: Source> PollEvented<E> {
pub(crate) fn registration(&self) -> &Registration {
&self.registration
}

/// Deregister the inner io from the registration and returns a Result containing the inner io
#[cfg(feature = "net")]
pub(crate) fn into_inner(mut self) -> io::Result<E> {
let mut inner = self.io.take().unwrap(); // As io shouldn't ever be None, just unwrap here.
self.registration.deregister(&mut inner)?;
Ok(inner)
}
}

feature! {
Expand Down
56 changes: 54 additions & 2 deletions tokio/src/net/tcp/stream.rs
Expand Up @@ -9,10 +9,10 @@ use std::fmt;
use std::io;
use std::net::{Shutdown, SocketAddr};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, FromRawSocket};
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};

#[cfg(unix)]
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
Expand Down Expand Up @@ -184,6 +184,58 @@ impl TcpStream {
Ok(TcpStream { io })
}

/// Turn a [`tokio::net::TcpStream`] into a [`std::net::TcpStream`].
///
/// The returned [`std::net::TcpStream`] will have `nonblocking mode` set as `true`.
/// Use [`set_nonblocking`] to change the blocking mode if needed.
///
/// # Examples
///
/// ```
/// use std::error::Error;
/// use std::io::Read;
/// use tokio::net::TcpListener;
/// # use tokio::net::TcpStream;
/// # use tokio::io::AsyncWriteExt;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
/// let mut data = [0u8; 12];
/// let listener = TcpListener::bind("127.0.0.1:34254").await?;
/// # let handle = tokio::spawn(async {
/// # let mut stream: TcpStream = TcpStream::connect("127.0.0.1:34254").await.unwrap();
/// # stream.write(b"Hello world!").await.unwrap();
/// # });
/// let (tokio_tcp_stream, _) = listener.accept().await?;
/// let mut std_tcp_stream = tokio_tcp_stream.into_std()?;
/// # handle.await.expect("The task being joined has panicked");
/// std_tcp_stream.set_nonblocking(false)?;
/// std_tcp_stream.read_exact(&mut data)?;
/// # assert_eq!(b"Hello world!", &data);
/// Ok(())
/// }
/// ```
/// [`tokio::net::TcpStream`]: TcpStream
/// [`std::net::TcpStream`]: std::net::TcpStream
/// [`set_nonblocking`]: fn@std::net::TcpStream::set_nonblocking
pub fn into_std(self) -> io::Result<std::net::TcpStream> {
#[cfg(unix)]
{
self.io
.into_inner()
.map(|io| io.into_raw_fd())
.map(|raw_fd| unsafe { std::net::TcpStream::from_raw_fd(raw_fd) })
}

#[cfg(windows)]
{
self.io
.into_inner()
.map(|io| io.into_raw_socket())
.map(|raw_socket| unsafe { std::net::TcpStream::from_raw_socket(raw_socket) })
}
}

/// Returns the local address that this stream is bound to.
///
/// # Examples
Expand Down
44 changes: 44 additions & 0 deletions tokio/tests/tcp_into_std.rs
@@ -0,0 +1,44 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use std::io::Read;
use std::io::Result;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::net::TcpStream;

#[tokio::test]
async fn tcp_into_std() -> Result<()> {
let mut data = [0u8; 12];
let listener = TcpListener::bind("127.0.0.1:34254").await?;

let handle = tokio::spawn(async {
let stream: TcpStream = TcpStream::connect("127.0.0.1:34254").await.unwrap();
stream
});

let (tokio_tcp_stream, _) = listener.accept().await?;
let mut std_tcp_stream = tokio_tcp_stream.into_std()?;
std_tcp_stream
.set_nonblocking(false)
.expect("set_nonblocking call failed");

let mut client = handle.await.expect("The task being joined has panicked");
client.write_all(b"Hello world!").await?;

std_tcp_stream
.read_exact(&mut data)
.expect("std TcpStream read failed!");
assert_eq!(b"Hello world!", &data);

// test back to tokio stream
std_tcp_stream
.set_nonblocking(true)
.expect("set_nonblocking call failed");
let mut tokio_tcp_stream = TcpStream::from_std(std_tcp_stream)?;
client.write_all(b"Hello tokio!").await?;
let _size = tokio_tcp_stream.read_exact(&mut data).await?;
assert_eq!(b"Hello tokio!", &data);

Ok(())
}

0 comments on commit 0707f4c

Please sign in to comment.