diff --git a/.cirrus.yml b/.cirrus.yml index afb761b7109..7a5f9a6d6d1 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -1,5 +1,5 @@ freebsd_instance: - image: freebsd-12-0-release-amd64 + image: freebsd-12-1-release-amd64 # Test FreeBSD in a full VM on cirrus-ci.com. Test the i686 target too, in the # same VM. The binary will be built in 32-bit mode, but will execute on a @@ -39,4 +39,4 @@ task: # i686_test_script: # - . $HOME/.cargo/env # - | - # cargo test --all --exclude tokio-tls --exclude tokio-macros --target i686-unknown-freebsd + # cargo test --all --exclude tokio-macros --target i686-unknown-freebsd diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md deleted file mode 100644 index 269780639f3..00000000000 --- a/.github/ISSUE_TEMPLATE.md +++ /dev/null @@ -1,51 +0,0 @@ - - -## Version - - - -## Platform - - - -## Subcrates - - - -## Description - - diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000000..98cf6116a72 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,36 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: A-tokio, C-bug +assignees: '' + +--- + +**Version** +List the versions of all `tokio` crates you are using. The easiest way to get +this information is using `cargo-tree`. + +`cargo install cargo-tree` +(see install here: https://github.com/sfackler/cargo-tree) + +Then: + +`cargo tree | grep tokio` + +**Platform** +The output of `uname -a` (UNIX), or version and 32 or 64-bit (Windows) + +**Description** +Enter your issue details here. +One way to structure the description: + +[short summary of the bug] + +I tried this code: + +[code sample that causes the bug] + +I expected to see this happen: [explanation] + +Instead, this happened: [explanation] diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000000..e90a4933174 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: A-tokio, C-feature-request +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md new file mode 100644 index 00000000000..03112f55412 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.md @@ -0,0 +1,16 @@ +--- +name: Question +about: Please use the discussions tab for questions +title: '' +labels: '' +assignees: '' + +--- + +Please post your question as a discussion here: +https://github.com/tokio-rs/tokio/discussions + + +You may also be able to find help here: +https://discord.gg/tokio +https://users.rust-lang.org/ diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index bbff0233b35..6b3db9ae609 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -5,6 +5,9 @@ the requirements below. Bug fixes and new features should include tests. Contributors guide: https://github.com/tokio-rs/tokio/blob/master/CONTRIBUTING.md + +The contributors guide includes instructions for running rustfmt and building the +documentation, which requires special commands beyond `cargo fmt` and `cargo doc`. --> ## Motivation diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml new file mode 100644 index 00000000000..a901a0fd014 --- /dev/null +++ b/.github/workflows/audit.yml @@ -0,0 +1,22 @@ +name: Security Audit + +on: + push: + branches: + - master + paths: + - '**/Cargo.toml' + schedule: + - cron: '0 2 * * *' # run at 2 AM UTC + +jobs: + security-audit: + runs-on: ubuntu-latest + if: "!contains(github.event.head_commit.message, 'ci skip')" + steps: + - uses: actions/checkout@v2 + + - name: Audit Check + uses: actions-rs/audit-check@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/pr-audit.yml b/.github/workflows/pr-audit.yml new file mode 100644 index 00000000000..26c0ee2f119 --- /dev/null +++ b/.github/workflows/pr-audit.yml @@ -0,0 +1,32 @@ +name: Pull Request Security Audit + +on: + push: + paths: + - '**/Cargo.toml' + pull_request: + paths: + - '**/Cargo.toml' + +jobs: + security-audit: + runs-on: ubuntu-latest + if: "!contains(github.event.head_commit.message, 'ci skip')" + steps: + - uses: actions/checkout@v2 + + - name: Install cargo-audit + uses: actions-rs/cargo@v1 + with: + command: install + args: cargo-audit + + - name: Generate lockfile + uses: actions-rs/cargo@v1 + with: + command: generate-lockfile + + - name: Audit dependencies + uses: actions-rs/cargo@v1 + with: + command: audit diff --git a/.github/workflows/test_tokio.yml b/.github/workflows/test_tokio.yml new file mode 100644 index 00000000000..d6d32b8f54b --- /dev/null +++ b/.github/workflows/test_tokio.yml @@ -0,0 +1,129 @@ +on: + push: + branches: ["master"] + pull_request: + branches: ["master"] + +name: Test tokio + +env: + RUSTFLAGS: -Dwarnings + RUST_BACKTRACE: 1 + nightly: nightly-2020-01-25 + +jobs: + test_tokio: + name: Test tokio full + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - windows-latest + - ubuntu-latest + - macos-latest + steps: + - uses: actions/checkout@master + - name: Install Rust + run: rustup update stable + + # Run `tokio` with only `full` + - uses: actions-rs/cargo@v1 + with: + command: test + args: --manifest-path ${{ github.workspace }}/tokio/Cargo.toml --features full + name: tokio - cargo test --features full + + # Run `tokio` with "unstable" cfg flag + - uses: actions-rs/cargo@v1 + with: + command: test + args: --manifest-path ${{ github.workspace }}/tokio/Cargo.toml --features full + env: + RUSTFLAGS: '--cfg tokio_unstable' + name: tokio - cargo test --features full --cfg tokio_unstable + + test_cross_subcrates: + name: Test ${{ matrix.crate }} (${{ matrix.os }}) all-features + runs-on: ${{ matrix.os }} + strategy: + matrix: + crate: + - tokio + - tests-integration + os: + - windows-latest + - ubuntu-latest + - macos-latest + include: + - crate: tokio-macros + os: ubuntu-latest + - crate: tokio-test + os: ubuntu-latest + - crate: tokio-util + os: ubuntu-latest + - crate: examples + os: ubuntu-latest + + steps: + - uses: actions/checkout@master + - name: Install Rust + run: rustup update stable + + # Run with all crate features + - name: ${{ matrix.crate }} - cargo test --all-features + uses: actions-rs/cargo@v1 + with: + command: test + args: --manifest-path ${{ github.workspace }}/${{ matrix.crate }}/Cargo.toml --all-features + + # Check benches + - name: ${{ matrix.crate }} - cargo check --benches + uses: actions-rs/cargo@v1 + with: + command: check + args: --manifest-path ${{ github.workspace }}/${{ matrix.crate }}/Cargo.toml --all-features --benches + + - name: Patch Cargo.toml + shell: bash + run: | + set -e + + # Remove any existing patch statements + mv Cargo.toml Cargo.toml.bck + sed -n '/\[patch.crates-io\]/q;p' Cargo.toml.bck > Cargo.toml + + # Patch all crates + cat ci/patch.toml >> Cargo.toml + + # Print `Cargo.toml` for debugging + echo "~~~~ Cargo.toml ~~~~" + cat Cargo.toml + echo "~~~~~~~~~~~~~~~~~~~~" + + # Run with all crate features + - name: ${{ matrix.crate }} - cargo test --all-features + uses: actions-rs/cargo@v1 + with: + command: test + args: --manifest-path ${{ github.workspace }}/${{ matrix.crate }}/Cargo.toml --all-features + + test_integration: + name: Integration tests + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - windows-latest + - ubuntu-latest + - macos-latest + steps: + - uses: actions/checkout@master + - name: Install Rust + run: rustup update stable + - run: cargo install cargo-hack + name: Install cargo-hack + - uses: actions-rs/cargo@v1 + name: cargo hack test --each-feature + with: + command: hack + args: test --manifest-path ${{ github.workspace }}/tests-integration/Cargo.toml --each-feature diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0aa58273081..70bc4559b22 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,12 +15,14 @@ It should be considered a map to help you navigate the process. The [dev channel][dev] is available for any concerns not covered in this guide, please join us! -[dev]: https://discord.gg/6yGkFeN +[dev]: https://discord.gg/tokio ## Conduct The Tokio project adheres to the [Rust Code of Conduct][coc]. This describes -the _minimum_ behavior expected from all contributors. Instances of violations of the Code of Conduct can be reported by contacting the project team at [moderation@tokio.rs](mailto:moderation@tokio.rs). +the _minimum_ behavior expected from all contributors. Instances of violations of the +Code of Conduct can be reported by contacting the project team at +[moderation@tokio.rs](mailto:moderation@tokio.rs). [coc]: https://github.com/rust-lang/rust/blob/master/CODE_OF_CONDUCT.md @@ -29,8 +31,8 @@ the _minimum_ behavior expected from all contributors. Instances of violations o For any issue, there are fundamentally three ways an individual can contribute: 1. By opening the issue for discussion: For instance, if you believe that you - have uncovered a bug in Tokio, creating a new issue in the tokio-rs/tokio - issue tracker is the way to report it. + have discovered a bug in Tokio, creating a new issue in [the tokio-rs/tokio + issue tracker][issue] is the way to report it. 2. By helping to triage the issue: This can be done by providing supporting details (a test case that demonstrates a bug), providing @@ -42,21 +44,25 @@ For any issue, there are fundamentally three ways an individual can contribute: often, by opening a Pull Request that changes some bit of something in Tokio in a concrete and reviewable manner. +[issue]: https://github.com/tokio-rs/tokio/issues + **Anybody can participate in any stage of contribution**. We urge you to participate in the discussion around bugs and participate in reviewing PRs. ### Asking for General Help If you have reviewed existing documentation and still have questions or are -having problems, you can open an issue asking for help. +having problems, you can [open a discussion] asking for help. In exchange for receiving help, we ask that you contribute back a documentation PR that helps others avoid the problems that you encountered. +[open a discussion]: https://github.com/tokio-rs/tokio/discussions/new + ### Submitting a Bug Report -When opening a new issue in the Tokio issue tracker, users will be presented -with a [basic template][template] that should be filled in. If you believe that you have +When opening a new issue in the Tokio issue tracker, you will be presented +with a basic template that should be filled in. If you believe that you have uncovered a bug, please fill out this form, following the template to the best of your ability. Do not worry if you cannot answer every detail, just fill in what you can. @@ -72,7 +78,6 @@ cases should be limited, as much as possible, to using only Tokio APIs. See [How to create a Minimal, Complete, and Verifiable example][mcve]. [mcve]: https://stackoverflow.com/help/mcve -[template]: .github/PULL_REQUEST_TEMPLATE.md ### Triaging a Bug Report @@ -132,8 +137,13 @@ RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features ``` The `cargo fmt` command does not work on the Tokio codebase. You can use the command below instead: + ``` +# Mac or Linux rustfmt --check --edition 2018 $(find . -name '*.rs' -print) + +# Powershell +Get-ChildItem . -Filter "*.rs" -Recurse | foreach { rustfmt --check --edition 2018 $_.FullName } ``` The `--check` argument prints the things that need to be fixed. If you remove it, `rustfmt` will update your files locally instead. @@ -250,7 +260,7 @@ That said, if you have a number of commits that are "checkpoints" and don't represent a single logical change, please squash those together. Note that multiple commits often get squashed when they are landed (see the -notes about [commit squashing]). +notes about [commit squashing](#commit-squashing)). #### Commit message guidelines @@ -321,7 +331,7 @@ in order to evaluate whether the changes are correct and necessary. Keep an eye out for comments from code owners to provide guidance on conflicting feedback. -**Once the PR is open, do not rebase the commits**. See [Commit Squashing] for +**Once the PR is open, do not rebase the commits**. See [Commit Squashing](#commit-squashing) for more details. ### Commit Squashing diff --git a/Cargo.toml b/Cargo.toml index ebae0d3db6b..39d2936645b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ members = [ "tokio", "tokio-macros", "tokio-test", - "tokio-tls", "tokio-util", # Internal diff --git a/README.md b/README.md index 5a661998f5a..31b2ae12a6c 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ the Rust programming language. It is: [crates-badge]: https://img.shields.io/crates/v/tokio.svg [crates-url]: https://crates.io/crates/tokio [mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg -[mit-url]: LICENSE +[mit-url]: https://github.com/tokio-rs/tokio/blob/master/LICENSE [azure-badge]: https://dev.azure.com/tokio-rs/Tokio/_apis/build/status/tokio-rs.tokio?branchName=master [azure-url]: https://dev.azure.com/tokio-rs/Tokio/_build/latest?definitionId=1&branchName=master [discord-badge]: https://img.shields.io/discord/500028886025895936.svg?logo=discord&style=flat-square @@ -90,19 +90,23 @@ async fn main() -> Result<(), Box> { } ``` -More examples can be found [here](examples). +More examples can be found [here][examples]. For a larger "real world" example, see the +[mini-redis] repository. + +[examples]: https://github.com/tokio-rs/tokio/tree/master/examples +[mini-redis]: https://github.com/tokio-rs/mini-redis/ ## Getting Help First, see if the answer to your question can be found in the [Guides] or the [API documentation]. If the answer is not there, there is an active community in the [Tokio Discord server][chat]. We would be happy to try to answer your -question. Last, if that doesn't work, try opening an [issue] with the question. +question. You can also ask your question on [the discussions page][discussions]. [Guides]: https://tokio.rs/docs/overview/ [API documentation]: https://docs.rs/tokio/latest/tokio [chat]: https://discord.gg/tokio -[issue]: https://github.com/tokio-rs/tokio/issues/new +[discussions]: https://github.com/tokio-rs/tokio/discussions ## Contributing @@ -149,15 +153,15 @@ several other libraries, including: ## Supported Rust Versions -Tokio is built against the latest stable, nightly, and beta Rust releases. The -minimum version supported is the stable release from three months before the -current stable release version. For example, if the latest stable Rust is 1.29, -the minimum version supported is 1.26. The current Tokio version is not -guaranteed to build on Rust versions earlier than the minimum supported version. +Tokio is built against the latest stable release. The minimum supported version is 1.39. +The current Tokio version is not guaranteed to build on Rust versions earlier than the +minimum supported version. ## License -This project is licensed under the [MIT license](LICENSE). +This project is licensed under the [MIT license]. + +[MIT license]: https://github.com/tokio-rs/tokio/blob/master/LICENSE ### Contribution diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000000..bf155ff9e2f --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,13 @@ +## Report a security issue + +The Tokio project team welcomes security reports and is committed to providing prompt attention to security issues. Security issues should be reported privately via [security@tokio.rs](mailto:security@tokio.rs). Security issues should not be reported via the public Github Issue tracker. + +## Vulnerability coordination + +Remediation of security vulnerabilities is prioritized by the project team. The project team coordinates remediation with third-party project stakeholders via [Github Security Advisories](https://help.github.com/en/github/managing-security-vulnerabilities/about-github-security-advisories). Third-party stakeholders may include the reporter of the issue, affected direct or indirect users of Tokio, and maintainers of upstream dependencies if applicable. + +Downstream project maintainers and Tokio users can request participation in coordination of applicable security issues by sending your contact email address, Github username(s) and any other salient information to [security@tokio.rs](mailto:security@tokio.rs). Participation in security issue coordination processes is at the discretion of the Tokio team. + +## Security advisories + +The project team is committed to transparency in the security issue disclosure process. The Tokio team announces security issues via [project Github Release notes](https://github.com/tokio-rs/tokio/releases) and the [RustSec advisory database](https://github.com/RustSec/advisory-db) (i.e. `cargo-audit`). diff --git a/azure-pipelines.yml b/azure-pipelines.yml index cc50f3c88ca..ba0b9da389b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -3,7 +3,7 @@ pr: ["master"] variables: RUSTFLAGS: -Dwarnings - nightly: nightly-2020-01-25 + nightly: nightly-2020-05-29 jobs: # Test top level crate @@ -26,7 +26,6 @@ jobs: crates: - tokio-macros - tokio-test - - tokio-tls - tokio-util - examples @@ -48,6 +47,7 @@ jobs: - template: ci/azure-miri.yml parameters: name: miri + rust: $(nightly) # Try cross compiling - template: ci/azure-cross-compile.yml diff --git a/ci/azure-check-features.yml b/ci/azure-check-features.yml index f5985843e10..98921300523 100644 --- a/ci/azure-check-features.yml +++ b/ci/azure-check-features.yml @@ -30,3 +30,9 @@ jobs: # tracking-issue: https://github.com/rust-lang/cargo/issues/5133 - script: cargo hack check --all --each-feature -Z avoid-dev-deps displayName: cargo hack check --all --each-feature + + # Try with unstable feature flags + - script: cargo hack check --all --each-feature -Z avoid-dev-deps + displayName: cargo hack check --all --each-feature + env: + RUSTFLAGS: '--cfg tokio_unstable' diff --git a/ci/azure-clippy.yml b/ci/azure-clippy.yml index 58ab318f718..b843223f4b0 100644 --- a/ci/azure-clippy.yml +++ b/ci/azure-clippy.yml @@ -12,5 +12,5 @@ jobs: cargo clippy --version displayName: Install clippy - script: | - cargo clippy --all --all-features - displayName: cargo clippy --all + cargo clippy --all --all-features --tests + displayName: cargo clippy --all --tests diff --git a/ci/azure-cross-compile.yml b/ci/azure-cross-compile.yml index 74acaee282f..11c530ca15f 100644 --- a/ci/azure-cross-compile.yml +++ b/ci/azure-cross-compile.yml @@ -37,8 +37,8 @@ jobs: # Always patch - template: azure-patch-crates.yml - - script: cross check --all --exclude tokio-tls --target $(target) + - script: cross check --all --target $(target) displayName: Check source - # - script: cross check --tests --all --exclude tokio-tls --target $(target) + # - script: cross check --tests --all --target $(target) # displayName: Check tests diff --git a/ci/azure-loom.yml b/ci/azure-loom.yml index 001aedec263..df02711bb91 100644 --- a/ci/azure-loom.yml +++ b/ci/azure-loom.yml @@ -21,7 +21,7 @@ jobs: parameters: rust_version: ${{ parameters.rust }} - - script: RUSTFLAGS="--cfg loom" cargo test --lib --release --features "full" -- --nocapture $(scope) + - script: RUSTFLAGS="--cfg loom --cfg tokio_unstable" cargo test --lib --release --features "full" -- --nocapture $(scope) env: LOOM_MAX_PREEMPTIONS: 2 CI: 'True' diff --git a/ci/azure-miri.yml b/ci/azure-miri.yml index fb886edc7d4..05bc973b266 100644 --- a/ci/azure-miri.yml +++ b/ci/azure-miri.yml @@ -7,7 +7,7 @@ jobs: steps: - template: azure-install-rust.yml parameters: - rust_version: nightly + rust_version: ${{ parameters.rust }} - script: | rustup component add miri diff --git a/ci/azure-test-stable.yml b/ci/azure-test-stable.yml index ce22c942f38..55aa84e092d 100644 --- a/ci/azure-test-stable.yml +++ b/ci/azure-test-stable.yml @@ -21,6 +21,23 @@ jobs: - template: azure-is-release.yml + # Run `tokio` with only `full` + - script: cargo test --features full + env: + RUST_BACKTRACE: 1 + CI: 'True' + displayName: tokio - cargo test --features full + workingDirectory: $(Build.SourcesDirectory)/tokio + + # Run `tokio` with "unstable" cfg flag + - script: cargo test --all-features + env: + RUSTFLAGS: '--cfg tokio_unstable' + RUST_BACKTRACE: 1 + CI: 'True' + displayName: tokio - cargo test --features full + workingDirectory: $(Build.SourcesDirectory)/tokio + - ${{ each crate in parameters.crates }}: # Run with all crate features - script: cargo test --all-features diff --git a/ci/patch.toml b/ci/patch.toml index 22311cf9a76..1650e3b1da4 100644 --- a/ci/patch.toml +++ b/ci/patch.toml @@ -4,5 +4,4 @@ tokio = { path = "tokio" } tokio-macros = { path = "tokio-macros" } tokio-test = { path = "tokio-test" } -tokio-tls = { path = "tokio-tls" } tokio-util = { path = "tokio-util" } diff --git a/examples/Cargo.toml b/examples/Cargo.toml index c3dc6091518..fe3c90f9a56 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -7,7 +7,9 @@ edition = "2018" # If you copy one of the examples into a new project, you should be using # [dependencies] instead. [dev-dependencies] -tokio = { version = "0.2.0", path = "../tokio", features = ["full"] } +tokio = { version = "0.2.0", path = "../tokio", features = ["full", "tracing"] } +tracing = "0.1" +tracing-subscriber = { version = "0.2.7", default-features = false, features = ["fmt", "ansi", "env-filter", "chrono", "tracing-log"] } tokio-util = { version = "0.3.0", path = "../tokio-util", features = ["full"] } bytes = "0.5" futures = "0.3.0" diff --git a/examples/README.md b/examples/README.md index 8217d4c6bb0..15b06c092b5 100644 --- a/examples/README.md +++ b/examples/README.md @@ -13,8 +13,11 @@ A good starting point for the examples would be [`hello_world`](hello_world.rs) and [`echo`](echo.rs). Additionally [the tokio website][tokioweb] contains additional guides for some of the examples. +For a larger "real world" example, see the [`mini-redis`][redis] repository. + If you've got an example you'd like to see here, please feel free to open an issue. Otherwise if you've got an example you'd like to add, please feel free to make a PR! [tokioweb]: https://tokio.rs/docs/overview/ +[redis]: https://github.com/tokio-rs/mini-redis diff --git a/examples/chat.rs b/examples/chat.rs index b3fb727a2cc..c4b8c6a2afc 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -43,6 +43,26 @@ use std::task::{Context, Poll}; #[tokio::main] async fn main() -> Result<(), Box> { + use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; + // Configure a `tracing` subscriber that logs traces emitted by the chat + // server. + tracing_subscriber::fmt() + // Filter what traces are displayed based on the RUST_LOG environment + // variable. + // + // Traces emitted by the example code will always be displayed. You + // can set `RUST_LOG=tokio=trace` to enable additional traces emitted by + // Tokio itself. + .with_env_filter(EnvFilter::from_default_env().add_directive("chat=info".parse()?)) + // Log events when `tracing` spans are created, entered, exited, or + // closed. When Tokio's internal tracing support is enabled (as + // described above), this can be used to track the lifecycle of spawned + // tasks on the Tokio runtime. + .with_span_events(FmtSpan::FULL) + // Set this subscriber as the default, to collect all traces emitted by + // the program. + .init(); + // Create the shared state. This is how all the peers communicate. // // The server task will hold a handle to this. For every new client, the @@ -59,7 +79,7 @@ async fn main() -> Result<(), Box> { // Note that this is the Tokio TcpListener, which is fully async. let mut listener = TcpListener::bind(&addr).await?; - println!("server running on {}", addr); + tracing::info!("server running on {}", addr); loop { // Asynchronously wait for an inbound TcpStream. @@ -70,8 +90,9 @@ async fn main() -> Result<(), Box> { // Spawn our handler to be run asynchronously. tokio::spawn(async move { + tracing::debug!("accepted connection"); if let Err(e) = process(state, stream, addr).await { - println!("an error occurred; error = {:?}", e); + tracing::info!("an error occurred; error = {:?}", e); } }); } @@ -200,7 +221,7 @@ async fn process( Some(Ok(line)) => line, // We didn't get a line so we return early here. _ => { - println!("Failed to get username from {}. Client disconnected.", addr); + tracing::error!("Failed to get username from {}. Client disconnected.", addr); return Ok(()); } }; @@ -212,7 +233,7 @@ async fn process( { let mut state = state.lock().await; let msg = format!("{} has joined the chat", username); - println!("{}", msg); + tracing::info!("{}", msg); state.broadcast(addr, &msg).await; } @@ -233,9 +254,10 @@ async fn process( peer.lines.send(&msg).await?; } Err(e) => { - println!( + tracing::error!( "an error occurred while processing messages for {}; error = {:?}", - username, e + username, + e ); } } @@ -248,7 +270,7 @@ async fn process( state.peers.remove(&addr); let msg = format!("{} has left the chat", username); - println!("{}", msg); + tracing::info!("{}", msg); state.broadcast(addr, &msg).await; } diff --git a/examples/echo-udp.rs b/examples/echo-udp.rs index d8b2af9cbb4..bc688b9b79b 100644 --- a/examples/echo-udp.rs +++ b/examples/echo-udp.rs @@ -15,7 +15,6 @@ use std::error::Error; use std::net::SocketAddr; use std::{env, io}; -use tokio; use tokio::net::UdpSocket; struct Server { diff --git a/examples/echo.rs b/examples/echo.rs index 35b122794cc..f30680748db 100644 --- a/examples/echo.rs +++ b/examples/echo.rs @@ -21,7 +21,6 @@ #![warn(rust_2018_idioms)] -use tokio; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; diff --git a/examples/proxy.rs b/examples/proxy.rs index f7a9111f6a2..144f0179fc3 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -23,6 +23,7 @@ #![warn(rust_2018_idioms)] use tokio::io; +use tokio::io::AsyncWriteExt; use tokio::net::{TcpListener, TcpStream}; use futures::future::try_join; @@ -63,8 +64,15 @@ async fn transfer(mut inbound: TcpStream, proxy_addr: String) -> Result<(), Box< let (mut ri, mut wi) = inbound.split(); let (mut ro, mut wo) = outbound.split(); - let client_to_server = io::copy(&mut ri, &mut wo); - let server_to_client = io::copy(&mut ro, &mut wi); + let client_to_server = async { + io::copy(&mut ri, &mut wo).await?; + wo.shutdown().await + }; + + let server_to_client = async { + io::copy(&mut ro, &mut wi).await?; + wi.shutdown().await + }; try_join(client_to_server, server_to_client).await?; diff --git a/examples/tinyhttp.rs b/examples/tinyhttp.rs index 732da0d64ee..4870aea2f83 100644 --- a/examples/tinyhttp.rs +++ b/examples/tinyhttp.rs @@ -18,7 +18,6 @@ use futures::SinkExt; use http::{header::HeaderValue, Request, Response, StatusCode}; #[macro_use] extern crate serde_derive; -use serde_json; use std::{env, error::Error, fmt, io}; use tokio::net::{TcpListener, TcpStream}; use tokio::stream::StreamExt; diff --git a/tokio-macros/CHANGELOG.md b/tokio-macros/CHANGELOG.md index d9edc5d2e32..91844627ca3 100644 --- a/tokio-macros/CHANGELOG.md +++ b/tokio-macros/CHANGELOG.md @@ -1,15 +1,15 @@ # 0.2.5 (February 27, 2019) ### Fixed -- doc improvements (#2225). +- doc improvements ([#2225]). # 0.2.4 (January 27, 2019) ### Fixed -- generics on `#[tokio::main]` function (#2177). +- generics on `#[tokio::main]` function ([#2177]). ### Added -- support for `tokio::select!` (#2152). +- support for `tokio::select!` ([#2152]). # 0.2.3 (January 7, 2019) @@ -19,13 +19,20 @@ # 0.2.2 (January 7, 2019) ### Added -- General refactoring and inclusion of additional runtime options (#2022 and #2038) +- General refactoring and inclusion of additional runtime options ([#2022] and [#2038]) # 0.2.1 (December 18, 2019) ### Fixes -- inherit visibility when wrapping async fn (#1954). +- inherit visibility when wrapping async fn ([#1954]). # 0.2.0 (November 26, 2019) - Initial release + +[#2225]: https://github.com/tokio-rs/tokio/pull/2225 +[#2177]: https://github.com/tokio-rs/tokio/pull/2177 +[#2152]: https://github.com/tokio-rs/tokio/pull/2152 +[#2038]: https://github.com/tokio-rs/tokio/pull/2038 +[#2022]: https://github.com/tokio-rs/tokio/pull/2022 +[#1954]: https://github.com/tokio-rs/tokio/pull/1954 diff --git a/tokio-macros/src/entry.rs b/tokio-macros/src/entry.rs index 6a58b791ed8..2681f50d9c0 100644 --- a/tokio-macros/src/entry.rs +++ b/tokio-macros/src/entry.rs @@ -142,7 +142,7 @@ fn parse_knobs( let header = { if is_test { quote! { - #[test] + #[::core::prelude::v1::test] } } else { quote! {} @@ -334,14 +334,14 @@ pub(crate) mod old { let result = match runtime { Runtime::Threaded => quote! { - #[test] + #[::core::prelude::v1::test] #(#attrs)* #vis fn #name() #ret { tokio::runtime::Runtime::new().unwrap().block_on(async { #body }) } }, Runtime::Basic | Runtime::Auto => quote! { - #[test] + #[::core::prelude::v1::test] #(#attrs)* #vis fn #name() #ret { tokio::runtime::Builder::new() diff --git a/tokio-macros/src/lib.rs b/tokio-macros/src/lib.rs index 9fdfb5bd769..b1742d48662 100644 --- a/tokio-macros/src/lib.rs +++ b/tokio-macros/src/lib.rs @@ -55,6 +55,14 @@ use proc_macro::TokenStream; /// println!("Hello world"); /// } /// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub fn main_threaded(args: TokenStream, item: TokenStream) -> TokenStream { @@ -91,6 +99,14 @@ pub fn main_threaded(args: TokenStream, item: TokenStream) -> TokenStream { /// println!("Hello world"); /// } /// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub fn main(args: TokenStream, item: TokenStream) -> TokenStream { @@ -117,6 +133,14 @@ pub fn main(args: TokenStream, item: TokenStream) -> TokenStream { /// println!("Hello world"); /// } /// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub fn main_basic(args: TokenStream, item: TokenStream) -> TokenStream { @@ -149,6 +173,14 @@ pub fn main_basic(args: TokenStream, item: TokenStream) -> TokenStream { /// assert!(true); /// } /// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] pub fn test_threaded(args: TokenStream, item: TokenStream) -> TokenStream { entry::test(args, item, true) @@ -180,6 +212,14 @@ pub fn test_threaded(args: TokenStream, item: TokenStream) -> TokenStream { /// assert!(true); /// } /// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { entry::old::test(args, item) @@ -199,6 +239,14 @@ pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { /// assert!(true); /// } /// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] pub fn test_basic(args: TokenStream, item: TokenStream) -> TokenStream { entry::test(args, item, false) diff --git a/tokio-test/CHANGELOG.md b/tokio-test/CHANGELOG.md index e5a0093d51f..50371c44ba2 100644 --- a/tokio-test/CHANGELOG.md +++ b/tokio-test/CHANGELOG.md @@ -1,3 +1,7 @@ +# 0.2.1 (April 17, 2020) + +- Add `Future` and `Stream` implementations for `task::Spawn`. + # 0.2.0 (November 25, 2019) - Initial release diff --git a/tokio-test/Cargo.toml b/tokio-test/Cargo.toml index a1e60500db2..130035c297d 100644 --- a/tokio-test/Cargo.toml +++ b/tokio-test/Cargo.toml @@ -7,13 +7,13 @@ name = "tokio-test" # - Cargo.toml # - Update CHANGELOG.md. # - Create "v0.2.x" git tag. -version = "0.2.0" +version = "0.2.1" edition = "2018" authors = ["Tokio Contributors "] license = "MIT" repository = "https://github.com/tokio-rs/tokio" homepage = "https://tokio.rs" -documentation = "https://docs.rs/tokio-test/0.2.0/tokio_test" +documentation = "https://docs.rs/tokio-test/0.2.1/tokio_test" description = """ Testing utilities for Tokio- and futures-based code """ diff --git a/tokio-test/src/io.rs b/tokio-test/src/io.rs index 8af6a9f6490..26ef57e47a2 100644 --- a/tokio-test/src/io.rs +++ b/tokio-test/src/io.rs @@ -12,7 +12,7 @@ //! //! # Usage //! -//! Attempting to write data that the mock isn't expected will result in a +//! Attempting to write data that the mock isn't expecting will result in a //! panic. //! //! [`AsyncRead`]: tokio::io::AsyncRead diff --git a/tokio-test/src/lib.rs b/tokio-test/src/lib.rs index c109c148145..185f317b973 100644 --- a/tokio-test/src/lib.rs +++ b/tokio-test/src/lib.rs @@ -1,4 +1,4 @@ -#![doc(html_root_url = "https://docs.rs/tokio-test/0.2.0")] +#![doc(html_root_url = "https://docs.rs/tokio-test/0.2.1")] #![warn( missing_debug_implementations, missing_docs, diff --git a/tokio-test/src/task.rs b/tokio-test/src/task.rs index 04328e3d5a9..728117cc158 100644 --- a/tokio-test/src/task.rs +++ b/tokio-test/src/task.rs @@ -46,21 +46,11 @@ const SLEEP: usize = 2; impl Spawn { /// Consumes `self` returning the inner value - pub fn into_inner(mut self) -> T + pub fn into_inner(self) -> T where T: Unpin, { - drop(self.task); - - // Pin::into_inner is unstable, so we work around it - // - // Safety: `T` is bound by `Unpin`. - unsafe { - let ptr = Pin::get_mut(self.future.as_mut()) as *mut T; - let future = Box::from_raw(ptr); - mem::forget(self.future); - *future - } + *Pin::into_inner(self.future) } /// Returns `true` if the inner future has received a wake notification @@ -116,6 +106,22 @@ impl Spawn { } } +impl Future for Spawn { + type Output = T::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.future.as_mut().poll(cx) + } +} + +impl Stream for Spawn { + type Item = T::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.future.as_mut().poll_next(cx) + } +} + impl MockTask { /// Creates new mock task fn new() -> Self { diff --git a/tokio-tls/CHANGELOG.md b/tokio-tls/CHANGELOG.md deleted file mode 100644 index c1a3e3a7135..00000000000 --- a/tokio-tls/CHANGELOG.md +++ /dev/null @@ -1,36 +0,0 @@ -# 0.3.0 (November 26, 2019) - -- Updates for tokio 0.2 release - -# 0.3.0-alpha.6 (September 30, 2019) - -- Move to `futures-*-preview 0.3.0-alpha.19` -- Move to `pin-project 0.4` - -# 0.3.0-alpha.5 (September 19, 2019) - -### Added -- `TlsStream::get_ref` and `TlsStream::get_mut` (#1537). - -# 0.3.0-alpha.4 (August 30, 2019) - -### Changed -- Track `tokio` 0.2.0-alpha.4 - -# 0.3.0-alpha.2 (August 17, 2019) - -### Changed -- Update `futures` dependency to 0.3.0-alpha.18. - -# 0.3.0-alpha.1 (August 8, 2019) - -### Changed -- Switch to `async`, `await`, and `std::future`. - -# 0.2.1 (January 6, 2019) - -* Implement `Clone` for `TlsConnector` and `TlsAcceptor` (#777) - -# 0.2.0 (August 8, 2018) - -* Initial release with `tokio` support. diff --git a/tokio-tls/Cargo.toml b/tokio-tls/Cargo.toml deleted file mode 100644 index a9877926933..00000000000 --- a/tokio-tls/Cargo.toml +++ /dev/null @@ -1,63 +0,0 @@ -[package] -name = "tokio-tls" -# When releasing to crates.io: -# - Remove path dependencies -# - Update html_root_url. -# - Update doc url -# - Cargo.toml -# - README.md -# - Update CHANGELOG.md. -# - Create "v0.3.x" git tag. -version = "0.3.0" -edition = "2018" -authors = ["Tokio Contributors "] -license = "MIT" -repository = "https://github.com/tokio-rs/tokio" -homepage = "https://tokio.rs" -documentation = "https://docs.rs/tokio-tls/0.3.0-alpha.6/tokio_tls/" -description = """ -An implementation of TLS/SSL streams for Tokio giving an implementation of TLS -for nonblocking I/O streams. -""" -categories = ["asynchronous", "network-programming"] - -[badges] -travis-ci = { repository = "tokio-rs/tokio-tls" } - -[dependencies] -native-tls = "0.2" -tokio = { version = "0.2.0", path = "../tokio" } - -[dev-dependencies] -tokio = { version = "0.2.0", path = "../tokio", features = ["macros", "stream", "rt-core", "io-util", "net"] } -tokio-util = { version = "0.3.0", path = "../tokio-util", features = ["full"] } - -cfg-if = "0.1" -env_logger = { version = "0.6", default-features = false } -futures = { version = "0.3.0", features = ["async-await"] } - -[target.'cfg(all(not(target_os = "macos"), not(windows), not(target_os = "ios")))'.dev-dependencies] -openssl = "0.10" - -[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dev-dependencies] -security-framework = "0.2" - -[target.'cfg(windows)'.dev-dependencies] -schannel = "0.1" - -[target.'cfg(windows)'.dev-dependencies.winapi] -version = "0.3" -features = [ - "lmcons", - "basetsd", - "minwinbase", - "minwindef", - "ntdef", - "sysinfoapi", - "timezoneapi", - "wincrypt", - "winerror", -] - -[package.metadata.docs.rs] -all-features = true diff --git a/tokio-tls/LICENSE b/tokio-tls/LICENSE deleted file mode 100644 index cdb28b4b56a..00000000000 --- a/tokio-tls/LICENSE +++ /dev/null @@ -1,25 +0,0 @@ -Copyright (c) 2019 Tokio Contributors - -Permission is hereby granted, free of charge, to any -person obtaining a copy of this software and associated -documentation files (the "Software"), to deal in the -Software without restriction, including without -limitation the rights to use, copy, modify, merge, -publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software -is furnished to do so, subject to the following -conditions: - -The above copyright notice and this permission notice -shall be included in all copies or substantial portions -of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. diff --git a/tokio-tls/README.md b/tokio-tls/README.md deleted file mode 100644 index 455612be8b8..00000000000 --- a/tokio-tls/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# tokio-tls - -An implementation of TLS/SSL streams for Tokio built on top of the [`native-tls` -crate] - -## License - -This project is licensed under the [MIT license](./LICENSE). - -### Contribution - -Unless you explicitly state otherwise, any contribution intentionally submitted -for inclusion in Tokio by you, shall be licensed as MIT, without any additional -terms or conditions. diff --git a/tokio-tls/examples/download-rust-lang.rs b/tokio-tls/examples/download-rust-lang.rs deleted file mode 100644 index 324c077539a..00000000000 --- a/tokio-tls/examples/download-rust-lang.rs +++ /dev/null @@ -1,40 +0,0 @@ -// #![warn(rust_2018_idioms)] - -use native_tls::TlsConnector; -use std::error::Error; -use std::net::ToSocketAddrs; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; -use tokio_tls; - -#[tokio::main] -async fn main() -> Result<(), Box> { - let addr = "www.rust-lang.org:443" - .to_socket_addrs()? - .next() - .ok_or("failed to resolve www.rust-lang.org")?; - - let socket = TcpStream::connect(&addr).await?; - let cx = TlsConnector::builder().build()?; - let cx = tokio_tls::TlsConnector::from(cx); - - let mut socket = cx.connect("www.rust-lang.org", socket).await?; - - socket - .write_all( - "\ - GET / HTTP/1.0\r\n\ - Host: www.rust-lang.org\r\n\ - \r\n\ - " - .as_bytes(), - ) - .await?; - - let mut data = Vec::new(); - socket.read_to_end(&mut data).await?; - - // println!("data: {:?}", &data); - println!("{}", String::from_utf8_lossy(&data[..])); - Ok(()) -} diff --git a/tokio-tls/examples/identity.p12 b/tokio-tls/examples/identity.p12 deleted file mode 100644 index d16abb8c706..00000000000 Binary files a/tokio-tls/examples/identity.p12 and /dev/null differ diff --git a/tokio-tls/examples/tls-echo.rs b/tokio-tls/examples/tls-echo.rs deleted file mode 100644 index 96309567403..00000000000 --- a/tokio-tls/examples/tls-echo.rs +++ /dev/null @@ -1,55 +0,0 @@ -#![warn(rust_2018_idioms)] - -// A tiny async TLS echo server with Tokio -use native_tls; -use native_tls::Identity; -use tokio; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpListener; -use tokio_tls; - -/** -an example to setup a tls server. -how to test: -wget https://127.0.0.1:12345 --no-check-certificate -*/ -#[tokio::main] -async fn main() -> Result<(), Box> { - // Bind the server's socket - let addr = "127.0.0.1:12345".to_string(); - let mut tcp: TcpListener = TcpListener::bind(&addr).await?; - - // Create the TLS acceptor. - let der = include_bytes!("identity.p12"); - let cert = Identity::from_pkcs12(der, "mypass")?; - let tls_acceptor = - tokio_tls::TlsAcceptor::from(native_tls::TlsAcceptor::builder(cert).build()?); - loop { - // Asynchronously wait for an inbound socket. - let (socket, remote_addr) = tcp.accept().await?; - let tls_acceptor = tls_acceptor.clone(); - println!("accept connection from {}", remote_addr); - tokio::spawn(async move { - // Accept the TLS connection. - let mut tls_stream = tls_acceptor.accept(socket).await.expect("accept error"); - // In a loop, read data from the socket and write the data back. - - let mut buf = [0; 1024]; - let n = tls_stream - .read(&mut buf) - .await - .expect("failed to read data from socket"); - - if n == 0 { - return; - } - println!("read={}", unsafe { - String::from_utf8_unchecked(buf[0..n].into()) - }); - tls_stream - .write_all(&buf[0..n]) - .await - .expect("failed to write data to socket"); - }); - } -} diff --git a/tokio-tls/src/lib.rs b/tokio-tls/src/lib.rs deleted file mode 100644 index 2770650934b..00000000000 --- a/tokio-tls/src/lib.rs +++ /dev/null @@ -1,361 +0,0 @@ -#![doc(html_root_url = "https://docs.rs/tokio-tls/0.3.0")] -#![warn( - missing_debug_implementations, - missing_docs, - rust_2018_idioms, - unreachable_pub -)] -#![deny(intra_doc_link_resolution_failure)] -#![doc(test( - no_crate_inject, - attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) -))] - -//! Async TLS streams -//! -//! This library is an implementation of TLS streams using the most appropriate -//! system library by default for negotiating the connection. That is, on -//! Windows this library uses SChannel, on OSX it uses SecureTransport, and on -//! other platforms it uses OpenSSL. -//! -//! Each TLS stream implements the `Read` and `Write` traits to interact and -//! interoperate with the rest of the futures I/O ecosystem. Client connections -//! initiated from this crate verify hostnames automatically and by default. -//! -//! This crate primarily exports this ability through two newtypes, -//! `TlsConnector` and `TlsAcceptor`. These newtypes augment the -//! functionality provided by the `native-tls` crate, on which this crate is -//! built. Configuration of TLS parameters is still primarily done through the -//! `native-tls` crate. - -use tokio::io::{AsyncRead, AsyncWrite}; - -use native_tls::{Error, HandshakeError, MidHandshakeTlsStream}; -use std::fmt; -use std::future::Future; -use std::io::{self, Read, Write}; -use std::marker::Unpin; -use std::mem::MaybeUninit; -use std::pin::Pin; -use std::ptr::null_mut; -use std::task::{Context, Poll}; - -#[derive(Debug)] -struct AllowStd { - inner: S, - context: *mut (), -} - -/// A wrapper around an underlying raw stream which implements the TLS or SSL -/// protocol. -/// -/// A `TlsStream` represents a handshake that has been completed successfully -/// and both the server and the client are ready for receiving and sending -/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written -/// to a `TlsStream` are encrypted when passing through to `S`. -#[derive(Debug)] -pub struct TlsStream(native_tls::TlsStream>); - -/// A wrapper around a `native_tls::TlsConnector`, providing an async `connect` -/// method. -#[derive(Clone)] -pub struct TlsConnector(native_tls::TlsConnector); - -/// A wrapper around a `native_tls::TlsAcceptor`, providing an async `accept` -/// method. -#[derive(Clone)] -pub struct TlsAcceptor(native_tls::TlsAcceptor); - -struct MidHandshake(Option>>); - -enum StartedHandshake { - Done(TlsStream), - Mid(MidHandshakeTlsStream>), -} - -struct StartedHandshakeFuture(Option>); -struct StartedHandshakeFutureInner { - f: F, - stream: S, -} - -struct Guard<'a, S>(&'a mut TlsStream) -where - AllowStd: Read + Write; - -impl Drop for Guard<'_, S> -where - AllowStd: Read + Write, -{ - fn drop(&mut self) { - (self.0).0.get_mut().context = null_mut(); - } -} - -// *mut () context is neither Send nor Sync -unsafe impl Send for AllowStd {} -unsafe impl Sync for AllowStd {} - -impl AllowStd -where - S: Unpin, -{ - fn with_context(&mut self, f: F) -> R - where - F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, - { - unsafe { - assert!(!self.context.is_null()); - let waker = &mut *(self.context as *mut _); - f(waker, Pin::new(&mut self.inner)) - } - } -} - -impl Read for AllowStd -where - S: AsyncRead + Unpin, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } -} - -impl Write for AllowStd -where - S: AsyncWrite + Unpin, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self.with_context(|ctx, stream| stream.poll_flush(ctx)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } -} - -fn cvt(r: io::Result) -> Poll> { - match r { - Ok(v) => Poll::Ready(Ok(v)), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - Err(e) => Poll::Ready(Err(e)), - } -} - -impl TlsStream { - fn with_context(&mut self, ctx: &mut Context<'_>, f: F) -> R - where - F: FnOnce(&mut native_tls::TlsStream>) -> R, - AllowStd: Read + Write, - { - self.0.get_mut().context = ctx as *mut _ as *mut (); - let g = Guard(self); - f(&mut (g.0).0) - } - - /// Returns a shared reference to the inner stream. - pub fn get_ref(&self) -> &S - where - S: AsyncRead + AsyncWrite + Unpin, - { - &self.0.get_ref().inner - } - - /// Returns a mutable reference to the inner stream. - pub fn get_mut(&mut self) -> &mut S - where - S: AsyncRead + AsyncWrite + Unpin, - { - &mut self.0.get_mut().inner - } -} - -impl AsyncRead for TlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - // Note that this does not forward to `S` because the buffer is - // unconditionally filled in by OpenSSL, not the actual object `S`. - // We're decrypting bytes from `S` into the buffer above! - false - } - - fn poll_read( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.with_context(ctx, |s| cvt(s.read(buf))) - } -} - -impl AsyncWrite for TlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_write( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.with_context(ctx, |s| cvt(s.write(buf))) - } - - fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { - self.with_context(ctx, |s| cvt(s.flush())) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { - match self.with_context(ctx, |s| s.shutdown()) { - Ok(()) => Poll::Ready(Ok(())), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - Err(e) => Poll::Ready(Err(e)), - } - } -} - -async fn handshake(f: F, stream: S) -> Result, Error> -where - F: FnOnce( - AllowStd, - ) -> Result>, HandshakeError>> - + Unpin, - S: AsyncRead + AsyncWrite + Unpin, -{ - let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream })); - - match start.await { - Err(e) => Err(e), - Ok(StartedHandshake::Done(s)) => Ok(s), - Ok(StartedHandshake::Mid(s)) => MidHandshake(Some(s)).await, - } -} - -impl Future for StartedHandshakeFuture -where - F: FnOnce( - AllowStd, - ) -> Result>, HandshakeError>> - + Unpin, - S: Unpin, - AllowStd: Read + Write, -{ - type Output = Result, Error>; - - fn poll( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - ) -> Poll, Error>> { - let inner = self.0.take().expect("future polled after completion"); - let stream = AllowStd { - inner: inner.stream, - context: ctx as *mut _ as *mut (), - }; - - match (inner.f)(stream) { - Ok(mut s) => { - s.get_mut().context = null_mut(); - Poll::Ready(Ok(StartedHandshake::Done(TlsStream(s)))) - } - Err(HandshakeError::WouldBlock(mut s)) => { - s.get_mut().context = null_mut(); - Poll::Ready(Ok(StartedHandshake::Mid(s))) - } - Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), - } - } -} - -impl TlsConnector { - /// Connects the provided stream with this connector, assuming the provided - /// domain. - /// - /// This function will internally call `TlsConnector::connect` to connect - /// the stream and returns a future representing the resolution of the - /// connection operation. The returned future will resolve to either - /// `TlsStream` or `Error` depending if it's successful or not. - /// - /// This is typically used for clients who have already established, for - /// example, a TCP connection to a remote server. That stream is then - /// provided here to perform the client half of a connection to a - /// TLS-powered server. - pub async fn connect(&self, domain: &str, stream: S) -> Result, Error> - where - S: AsyncRead + AsyncWrite + Unpin, - { - handshake(move |s| self.0.connect(domain, s), stream).await - } -} - -impl fmt::Debug for TlsConnector { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TlsConnector").finish() - } -} - -impl From for TlsConnector { - fn from(inner: native_tls::TlsConnector) -> TlsConnector { - TlsConnector(inner) - } -} - -impl TlsAcceptor { - /// Accepts a new client connection with the provided stream. - /// - /// This function will internally call `TlsAcceptor::accept` to connect - /// the stream and returns a future representing the resolution of the - /// connection operation. The returned future will resolve to either - /// `TlsStream` or `Error` depending if it's successful or not. - /// - /// This is typically used after a new socket has been accepted from a - /// `TcpListener`. That socket is then passed to this function to perform - /// the server half of accepting a client connection. - pub async fn accept(&self, stream: S) -> Result, Error> - where - S: AsyncRead + AsyncWrite + Unpin, - { - handshake(move |s| self.0.accept(s), stream).await - } -} - -impl fmt::Debug for TlsAcceptor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TlsAcceptor").finish() - } -} - -impl From for TlsAcceptor { - fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor { - TlsAcceptor(inner) - } -} - -impl Future for MidHandshake { - type Output = Result, Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut_self = self.get_mut(); - let mut s = mut_self.0.take().expect("future polled after completion"); - - s.get_mut().context = cx as *mut _ as *mut (); - match s.handshake() { - Ok(stream) => Poll::Ready(Ok(TlsStream(stream))), - Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), - Err(HandshakeError::WouldBlock(mut s)) => { - s.get_mut().context = null_mut(); - mut_self.0 = Some(s); - Poll::Pending - } - } - } -} diff --git a/tokio-tls/tests/bad.rs b/tokio-tls/tests/bad.rs deleted file mode 100644 index 87d4ca2708e..00000000000 --- a/tokio-tls/tests/bad.rs +++ /dev/null @@ -1,124 +0,0 @@ -#![warn(rust_2018_idioms)] - -use cfg_if::cfg_if; -use env_logger; -use native_tls::TlsConnector; -use std::io::{self, Error}; -use std::net::ToSocketAddrs; -use tokio::net::TcpStream; -use tokio_tls; - -macro_rules! t { - ($e:expr) => { - match $e { - Ok(e) => e, - Err(e) => panic!("{} failed with {:?}", stringify!($e), e), - } - }; -} - -cfg_if! { - if #[cfg(feature = "force-rustls")] { - fn verify_failed(err: &Error, s: &str) { - let err = err.to_string(); - assert!(err.contains(s), "bad error: {}", err); - } - - fn assert_expired_error(err: &Error) { - verify_failed(err, "CertExpired"); - } - - fn assert_wrong_host(err: &Error) { - verify_failed(err, "CertNotValidForName"); - } - - fn assert_self_signed(err: &Error) { - verify_failed(err, "UnknownIssuer"); - } - - fn assert_untrusted_root(err: &Error) { - verify_failed(err, "UnknownIssuer"); - } - } else if #[cfg(any(feature = "force-openssl", - all(not(target_os = "macos"), - not(target_os = "windows"), - not(target_os = "ios"))))] { - fn verify_failed(err: &Error) { - assert!(format!("{}", err).contains("certificate verify failed")) - } - - use verify_failed as assert_expired_error; - use verify_failed as assert_wrong_host; - use verify_failed as assert_self_signed; - use verify_failed as assert_untrusted_root; - } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { - - fn assert_invalid_cert_chain(err: &Error) { - assert!(format!("{}", err).contains("was not trusted.")) - } - - use crate::assert_invalid_cert_chain as assert_expired_error; - use crate::assert_invalid_cert_chain as assert_wrong_host; - use crate::assert_invalid_cert_chain as assert_self_signed; - use crate::assert_invalid_cert_chain as assert_untrusted_root; - } else { - fn assert_expired_error(err: &Error) { - let s = err.to_string(); - assert!(s.contains("system clock"), "error = {:?}", s); - } - - fn assert_wrong_host(err: &Error) { - let s = err.to_string(); - assert!(s.contains("CN name"), "error = {:?}", s); - } - - fn assert_self_signed(err: &Error) { - let s = err.to_string(); - assert!(s.contains("root certificate which is not trusted"), "error = {:?}", s); - } - - use assert_self_signed as assert_untrusted_root; - } -} - -async fn get_host(host: &'static str) -> Error { - drop(env_logger::try_init()); - - let addr = format!("{}:443", host); - let addr = t!(addr.to_socket_addrs()).next().unwrap(); - - let socket = t!(TcpStream::connect(&addr).await); - let builder = TlsConnector::builder(); - let cx = t!(builder.build()); - let cx = tokio_tls::TlsConnector::from(cx); - let res = cx - .connect(host, socket) - .await - .map_err(|e| Error::new(io::ErrorKind::Other, e)); - - assert!(res.is_err()); - res.err().unwrap() -} - -#[tokio::test] -async fn expired() { - assert_expired_error(&get_host("expired.badssl.com").await) -} - -// TODO: the OSX builders on Travis apparently fail this tests spuriously? -// passes locally though? Seems... bad! -#[tokio::test] -#[cfg_attr(all(target_os = "macos", feature = "force-openssl"), ignore)] -async fn wrong_host() { - assert_wrong_host(&get_host("wrong.host.badssl.com").await) -} - -#[tokio::test] -async fn self_signed() { - assert_self_signed(&get_host("self-signed.badssl.com").await) -} - -#[tokio::test] -async fn untrusted_root() { - assert_untrusted_root(&get_host("untrusted-root.badssl.com").await) -} diff --git a/tokio-tls/tests/google.rs b/tokio-tls/tests/google.rs deleted file mode 100644 index 13b78d31831..00000000000 --- a/tokio-tls/tests/google.rs +++ /dev/null @@ -1,102 +0,0 @@ -#![warn(rust_2018_idioms)] - -use cfg_if::cfg_if; -use env_logger; -use native_tls; -use native_tls::TlsConnector; -use std::io; -use std::net::ToSocketAddrs; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; -use tokio_tls; - -macro_rules! t { - ($e:expr) => { - match $e { - Ok(e) => e, - Err(e) => panic!("{} failed with {:?}", stringify!($e), e), - } - }; -} - -cfg_if! { - if #[cfg(feature = "force-rustls")] { - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.to_string(); - assert!(err.contains("CertNotValidForName"), "bad error: {}", err); - } - } else if #[cfg(any(feature = "force-openssl", - all(not(target_os = "macos"), - not(target_os = "windows"), - not(target_os = "ios"))))] { - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.get_ref().unwrap(); - let err = err.downcast_ref::().unwrap(); - assert!(format!("{}", err).contains("certificate verify failed")); - } - } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.get_ref().unwrap(); - let err = err.downcast_ref::().unwrap(); - assert!(format!("{}", err).contains("was not trusted.")); - } - } else { - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.get_ref().unwrap(); - let err = err.downcast_ref::().unwrap(); - assert!(format!("{}", err).contains("CN name")); - } - } -} - -#[tokio::test] -async fn fetch_google() { - drop(env_logger::try_init()); - - // First up, resolve google.com - let addr = t!("google.com:443".to_socket_addrs()).next().unwrap(); - - let socket = TcpStream::connect(&addr).await.unwrap(); - - // Send off the request by first negotiating an SSL handshake, then writing - // of our request, then flushing, then finally read off the response. - let builder = TlsConnector::builder(); - let connector = t!(builder.build()); - let connector = tokio_tls::TlsConnector::from(connector); - let mut socket = t!(connector.connect("google.com", socket).await); - t!(socket.write_all(b"GET / HTTP/1.0\r\n\r\n").await); - let mut data = Vec::new(); - t!(socket.read_to_end(&mut data).await); - - // any response code is fine - assert!(data.starts_with(b"HTTP/1.0 ")); - - let data = String::from_utf8_lossy(&data); - let data = data.trim_end(); - assert!(data.ends_with("") || data.ends_with("")); -} - -fn native2io(e: native_tls::Error) -> io::Error { - io::Error::new(io::ErrorKind::Other, e) -} - -// see comment in bad.rs for ignore reason -#[cfg_attr(all(target_os = "macos", feature = "force-openssl"), ignore)] -#[tokio::test] -async fn wrong_hostname_error() { - drop(env_logger::try_init()); - - let addr = t!("google.com:443".to_socket_addrs()).next().unwrap(); - - let socket = t!(TcpStream::connect(&addr).await); - let builder = TlsConnector::builder(); - let connector = t!(builder.build()); - let connector = tokio_tls::TlsConnector::from(connector); - let res = connector - .connect("rust-lang.org", socket) - .await - .map_err(native2io); - - assert!(res.is_err()); - assert_bad_hostname_error(&res.err().unwrap()); -} diff --git a/tokio-tls/tests/smoke.rs b/tokio-tls/tests/smoke.rs deleted file mode 100644 index 8788dd6d467..00000000000 --- a/tokio-tls/tests/smoke.rs +++ /dev/null @@ -1,629 +0,0 @@ -#![warn(rust_2018_idioms)] - -use cfg_if::cfg_if; -use env_logger; -use futures::join; -use native_tls; -use native_tls::{Identity, TlsAcceptor, TlsConnector}; -use std::io::Write; -use std::marker::Unpin; -use std::process::Command; -use std::ptr; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, Error, ErrorKind}; -use tokio::net::{TcpListener, TcpStream}; -use tokio::stream::StreamExt; -use tokio_tls; - -macro_rules! t { - ($e:expr) => { - match $e { - Ok(e) => e, - Err(e) => panic!("{} failed with {:?}", stringify!($e), e), - } - }; -} - -#[allow(dead_code)] -struct Keys { - cert_der: Vec, - pkey_der: Vec, - pkcs12_der: Vec, -} - -#[allow(dead_code)] -fn openssl_keys() -> &'static Keys { - static INIT: Once = Once::new(); - static mut KEYS: *mut Keys = ptr::null_mut(); - - INIT.call_once(|| { - let path = t!(env::current_exe()); - let path = path.parent().unwrap(); - let keyfile = path.join("test.key"); - let certfile = path.join("test.crt"); - let config = path.join("openssl.config"); - - File::create(&config) - .unwrap() - .write_all( - b"\ - [req]\n\ - distinguished_name=dn\n\ - [ dn ]\n\ - CN=localhost\n\ - [ ext ]\n\ - basicConstraints=CA:FALSE,pathlen:0\n\ - subjectAltName = @alt_names - extendedKeyUsage=serverAuth,clientAuth - [alt_names] - DNS.1 = localhost - ", - ) - .unwrap(); - - let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost"; - let output = t!(Command::new("openssl") - .arg("req") - .arg("-nodes") - .arg("-x509") - .arg("-newkey") - .arg("rsa:2048") - .arg("-config") - .arg(&config) - .arg("-extensions") - .arg("ext") - .arg("-subj") - .arg(subj) - .arg("-keyout") - .arg(&keyfile) - .arg("-out") - .arg(&certfile) - .arg("-days") - .arg("1") - .output()); - assert!(output.status.success()); - - let crtout = t!(Command::new("openssl") - .arg("x509") - .arg("-outform") - .arg("der") - .arg("-in") - .arg(&certfile) - .output()); - assert!(crtout.status.success()); - let keyout = t!(Command::new("openssl") - .arg("rsa") - .arg("-outform") - .arg("der") - .arg("-in") - .arg(&keyfile) - .output()); - assert!(keyout.status.success()); - - let pkcs12out = t!(Command::new("openssl") - .arg("pkcs12") - .arg("-export") - .arg("-nodes") - .arg("-inkey") - .arg(&keyfile) - .arg("-in") - .arg(&certfile) - .arg("-password") - .arg("pass:foobar") - .output()); - assert!(pkcs12out.status.success()); - - let keys = Box::new(Keys { - cert_der: crtout.stdout, - pkey_der: keyout.stdout, - pkcs12_der: pkcs12out.stdout, - }); - unsafe { - KEYS = Box::into_raw(keys); - } - }); - unsafe { &*KEYS } -} - -cfg_if! { - if #[cfg(feature = "rustls")] { - use webpki; - use untrusted; - use std::env; - use std::fs::File; - use std::process::Command; - use std::sync::Once; - - use untrusted::Input; - use webpki::trust_anchor_util; - - fn server_cx() -> io::Result { - let mut cx = ServerContext::new(); - - let (cert, key) = keys(); - cx.config_mut() - .set_single_cert(vec![cert.to_vec()], key.to_vec()); - - Ok(cx) - } - - fn configure_client(cx: &mut ClientContext) { - let (cert, _key) = keys(); - let cert = Input::from(cert); - let anchor = trust_anchor_util::cert_der_as_trust_anchor(cert).unwrap(); - cx.config_mut().root_store.add_trust_anchors(&[anchor]); - } - - // Like OpenSSL we generate certificates on the fly, but for OSX we - // also have to put them into a specific keychain. We put both the - // certificates and the keychain next to our binary. - // - // Right now I don't know of a way to programmatically create a - // self-signed certificate, so we just fork out to the `openssl` binary. - fn keys() -> (&'static [u8], &'static [u8]) { - static INIT: Once = Once::new(); - static mut KEYS: *mut (Vec, Vec) = ptr::null_mut(); - - INIT.call_once(|| { - let (key, cert) = openssl_keys(); - let path = t!(env::current_exe()); - let path = path.parent().unwrap(); - let keyfile = path.join("test.key"); - let certfile = path.join("test.crt"); - let config = path.join("openssl.config"); - - File::create(&config).unwrap().write_all(b"\ - [req]\n\ - distinguished_name=dn\n\ - [ dn ]\n\ - CN=localhost\n\ - [ ext ]\n\ - basicConstraints=CA:FALSE,pathlen:0\n\ - subjectAltName = @alt_names - [alt_names] - DNS.1 = localhost - ").unwrap(); - - let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost"; - let output = t!(Command::new("openssl") - .arg("req") - .arg("-nodes") - .arg("-x509") - .arg("-newkey").arg("rsa:2048") - .arg("-config").arg(&config) - .arg("-extensions").arg("ext") - .arg("-subj").arg(subj) - .arg("-keyout").arg(&keyfile) - .arg("-out").arg(&certfile) - .arg("-days").arg("1") - .output()); - assert!(output.status.success()); - - let crtout = t!(Command::new("openssl") - .arg("x509") - .arg("-outform").arg("der") - .arg("-in").arg(&certfile) - .output()); - assert!(crtout.status.success()); - let keyout = t!(Command::new("openssl") - .arg("rsa") - .arg("-outform").arg("der") - .arg("-in").arg(&keyfile) - .output()); - assert!(keyout.status.success()); - - let cert = crtout.stdout; - let key = keyout.stdout; - unsafe { - KEYS = Box::into_raw(Box::new((cert, key))); - } - }); - unsafe { - (&(*KEYS).0, &(*KEYS).1) - } - } - } else if #[cfg(any(feature = "force-openssl", - all(not(target_os = "macos"), - not(target_os = "windows"), - not(target_os = "ios"))))] { - use std::fs::File; - use std::env; - use std::sync::Once; - - fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) { - let keys = openssl_keys(); - - let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar")); - let srv = TlsAcceptor::builder(pkcs12); - - let cert = t!(native_tls::Certificate::from_der(&keys.cert_der)); - - let mut client = TlsConnector::builder(); - t!(client.add_root_certificate(cert).build()); - - (t!(srv.build()).into(), t!(client.build()).into()) - } - } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { - use std::env; - use std::fs::File; - use std::sync::Once; - - fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) { - let keys = openssl_keys(); - - let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar")); - let srv = TlsAcceptor::builder(pkcs12); - - let cert = native_tls::Certificate::from_der(&keys.cert_der).unwrap(); - let mut client = TlsConnector::builder(); - client.add_root_certificate(cert); - - (t!(srv.build()).into(), t!(client.build()).into()) - } - } else { - use schannel; - use winapi; - - use std::env; - use std::fs::File; - use std::io; - use std::mem; - use std::sync::Once; - - use schannel::cert_context::CertContext; - use schannel::cert_store::{CertStore, CertAdd, Memory}; - use winapi::shared::basetsd::*; - use winapi::shared::lmcons::*; - use winapi::shared::minwindef::*; - use winapi::shared::ntdef::WCHAR; - use winapi::um::minwinbase::*; - use winapi::um::sysinfoapi::*; - use winapi::um::timezoneapi::*; - use winapi::um::wincrypt::*; - - const FRIENDLY_NAME: &str = "tokio-tls localhost testing cert"; - - fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) { - let cert = localhost_cert(); - let mut store = t!(Memory::new()).into_store(); - t!(store.add_cert(&cert, CertAdd::Always)); - let pkcs12_der = t!(store.export_pkcs12("foobar")); - let pkcs12 = t!(Identity::from_pkcs12(&pkcs12_der, "foobar")); - - let srv = TlsAcceptor::builder(pkcs12); - let client = TlsConnector::builder(); - (t!(srv.build()).into(), t!(client.build()).into()) - } - - // ==================================================================== - // Magic! - // - // Lots of magic is happening here to wrangle certificates for running - // these tests on Windows. For more information see the test suite - // in the schannel-rs crate as this is just coyping that. - // - // The general gist of this though is that the only way to add custom - // trusted certificates is to add it to the system store of trust. To - // do that we go through the whole rigamarole here to generate a new - // self-signed certificate and then insert that into the system store. - // - // This generates some dialogs, so we print what we're doing sometimes, - // and otherwise we just manage the ephemeral certificates. Because - // they're in the system store we always ensure that they're only valid - // for a small period of time (e.g. 1 day). - - fn localhost_cert() -> CertContext { - static INIT: Once = Once::new(); - INIT.call_once(|| { - for cert in local_root_store().certs() { - let name = match cert.friendly_name() { - Ok(name) => name, - Err(_) => continue, - }; - if name != FRIENDLY_NAME { - continue - } - if !cert.is_time_valid().unwrap() { - io::stdout().write_all(br#" - -The tokio-tls test suite is about to delete an old copy of one of its -certificates from your root trust store. This certificate was only valid for one -day and it is no longer needed. The host should be "localhost" and the -description should mention "tokio-tls". - - "#).unwrap(); - cert.delete().unwrap(); - } else { - return - } - } - - install_certificate().unwrap(); - }); - - for cert in local_root_store().certs() { - let name = match cert.friendly_name() { - Ok(name) => name, - Err(_) => continue, - }; - if name == FRIENDLY_NAME { - return cert - } - } - - panic!("couldn't find a cert"); - } - - fn local_root_store() -> CertStore { - if env::var("CI").is_ok() { - CertStore::open_local_machine("Root").unwrap() - } else { - CertStore::open_current_user("Root").unwrap() - } - } - - fn install_certificate() -> io::Result { - unsafe { - let mut provider = 0; - let mut hkey = 0; - - let mut buffer = "tokio-tls test suite".encode_utf16() - .chain(Some(0)) - .collect::>(); - let res = CryptAcquireContextW(&mut provider, - buffer.as_ptr(), - ptr::null_mut(), - PROV_RSA_FULL, - CRYPT_MACHINE_KEYSET); - if res != TRUE { - // create a new key container (since it does not exist) - let res = CryptAcquireContextW(&mut provider, - buffer.as_ptr(), - ptr::null_mut(), - PROV_RSA_FULL, - CRYPT_NEWKEYSET | CRYPT_MACHINE_KEYSET); - if res != TRUE { - return Err(Error::last_os_error()) - } - } - - // create a new keypair (RSA-2048) - let res = CryptGenKey(provider, - AT_SIGNATURE, - 0x0800<<16 | CRYPT_EXPORTABLE, - &mut hkey); - if res != TRUE { - return Err(Error::last_os_error()); - } - - // start creating the certificate - let name = "CN=localhost,O=tokio-tls,OU=tokio-tls,\ - G=tokio_tls".encode_utf16() - .chain(Some(0)) - .collect::>(); - let mut cname_buffer: [WCHAR; UNLEN as usize + 1] = mem::zeroed(); - let mut cname_len = cname_buffer.len() as DWORD; - let res = CertStrToNameW(X509_ASN_ENCODING, - name.as_ptr(), - CERT_X500_NAME_STR, - ptr::null_mut(), - cname_buffer.as_mut_ptr() as *mut u8, - &mut cname_len, - ptr::null_mut()); - if res != TRUE { - return Err(Error::last_os_error()); - } - - let mut subject_issuer = CERT_NAME_BLOB { - cbData: cname_len, - pbData: cname_buffer.as_ptr() as *mut u8, - }; - let mut key_provider = CRYPT_KEY_PROV_INFO { - pwszContainerName: buffer.as_mut_ptr(), - pwszProvName: ptr::null_mut(), - dwProvType: PROV_RSA_FULL, - dwFlags: CRYPT_MACHINE_KEYSET, - cProvParam: 0, - rgProvParam: ptr::null_mut(), - dwKeySpec: AT_SIGNATURE, - }; - let mut sig_algorithm = CRYPT_ALGORITHM_IDENTIFIER { - pszObjId: szOID_RSA_SHA256RSA.as_ptr() as *mut _, - Parameters: mem::zeroed(), - }; - let mut expiration_date: SYSTEMTIME = mem::zeroed(); - GetSystemTime(&mut expiration_date); - let mut file_time: FILETIME = mem::zeroed(); - let res = SystemTimeToFileTime(&expiration_date, - &mut file_time); - if res != TRUE { - return Err(Error::last_os_error()); - } - let mut timestamp: u64 = file_time.dwLowDateTime as u64 | - (file_time.dwHighDateTime as u64) << 32; - // one day, timestamp unit is in 100 nanosecond intervals - timestamp += (1E9 as u64) / 100 * (60 * 60 * 24); - file_time.dwLowDateTime = timestamp as u32; - file_time.dwHighDateTime = (timestamp >> 32) as u32; - let res = FileTimeToSystemTime(&file_time, - &mut expiration_date); - if res != TRUE { - return Err(Error::last_os_error()); - } - - // create a self signed certificate - let cert_context = CertCreateSelfSignCertificate( - 0 as ULONG_PTR, - &mut subject_issuer, - 0, - &mut key_provider, - &mut sig_algorithm, - ptr::null_mut(), - &mut expiration_date, - ptr::null_mut()); - if cert_context.is_null() { - return Err(Error::last_os_error()); - } - - // TODO: this is.. a terrible hack. Right now `schannel` - // doesn't provide a public method to go from a raw - // cert context pointer to the `CertContext` structure it - // has, so we just fake it here with a transmute. This'll - // probably break at some point, but hopefully by then - // it'll have a method to do this! - struct MyCertContext(T); - impl Drop for MyCertContext { - fn drop(&mut self) {} - } - - let cert_context = MyCertContext(cert_context); - let cert_context: CertContext = mem::transmute(cert_context); - - cert_context.set_friendly_name(FRIENDLY_NAME)?; - - // install the certificate to the machine's local store - io::stdout().write_all(br#" - -The tokio-tls test suite is about to add a certificate to your set of root -and trusted certificates. This certificate should be for the domain "localhost" -with the description related to "tokio-tls". This certificate is only valid -for one day and will be automatically deleted if you re-run the tokio-tls -test suite later. - - "#).unwrap(); - local_root_store().add_cert(&cert_context, - CertAdd::ReplaceExisting)?; - Ok(cert_context) - } - } - } -} - -const AMT: usize = 128 * 1024; - -async fn copy_data(mut w: W) -> Result { - let mut data = vec![9; AMT as usize]; - let mut amt = 0; - while !data.is_empty() { - let written = w.write(&data).await?; - if written <= data.len() { - amt += written; - data.resize(data.len() - written, 0); - } else { - w.write_all(&data).await?; - amt += data.len(); - break; - } - - println!("remaining: {}", data.len()); - } - Ok(amt) -} - -#[tokio::test] -async fn client_to_server() { - drop(env_logger::try_init()); - - // Create a server listening on a port, then figure out what that port is - let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); - let addr = t!(srv.local_addr()); - - let (server_cx, client_cx) = contexts(); - - // Create a future to accept one socket, connect the ssl stream, and then - // read all the data from it. - let server = async move { - let mut incoming = srv.incoming(); - let socket = t!(incoming.next().await.unwrap()); - let mut socket = t!(server_cx.accept(socket).await); - let mut data = Vec::new(); - t!(socket.read_to_end(&mut data).await); - data - }; - - // Create a future to connect to our server, connect the ssl stream, and - // then write a bunch of data to it. - let client = async move { - let socket = t!(TcpStream::connect(&addr).await); - let socket = t!(client_cx.connect("localhost", socket).await); - copy_data(socket).await - }; - - // Finally, run everything! - let (data, _) = join!(server, client); - // assert_eq!(amt, AMT); - assert!(data == vec![9; AMT]); -} - -#[tokio::test] -async fn server_to_client() { - drop(env_logger::try_init()); - - // Create a server listening on a port, then figure out what that port is - let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); - let addr = t!(srv.local_addr()); - - let (server_cx, client_cx) = contexts(); - - let server = async move { - let mut incoming = srv.incoming(); - let socket = t!(incoming.next().await.unwrap()); - let socket = t!(server_cx.accept(socket).await); - copy_data(socket).await - }; - - let client = async move { - let socket = t!(TcpStream::connect(&addr).await); - let mut socket = t!(client_cx.connect("localhost", socket).await); - let mut data = Vec::new(); - t!(socket.read_to_end(&mut data).await); - data - }; - - // Finally, run everything! - let (_, data) = join!(server, client); - // assert_eq!(amt, AMT); - assert!(data == vec![9; AMT]); -} - -#[tokio::test] -async fn one_byte_at_a_time() { - const AMT: usize = 1024; - drop(env_logger::try_init()); - - let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); - let addr = t!(srv.local_addr()); - - let (server_cx, client_cx) = contexts(); - - let server = async move { - let mut incoming = srv.incoming(); - let socket = t!(incoming.next().await.unwrap()); - let mut socket = t!(server_cx.accept(socket).await); - let mut amt = 0; - for b in std::iter::repeat(9).take(AMT) { - let data = [b as u8]; - t!(socket.write_all(&data).await); - amt += 1; - } - amt - }; - - let client = async move { - let socket = t!(TcpStream::connect(&addr).await); - let mut socket = t!(client_cx.connect("localhost", socket).await); - let mut data = Vec::new(); - loop { - let mut buf = [0; 1]; - match socket.read_exact(&mut buf).await { - Ok(_) => data.extend_from_slice(&buf), - Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => break, - Err(err) => panic!(err), - } - } - data - }; - - let (amt, data) = join!(server, client); - assert_eq!(amt, AMT); - assert!(data == vec![9; AMT as usize]); -} diff --git a/tokio-util/CHANGELOG.md b/tokio-util/CHANGELOG.md index 9847064a76e..eaafe263e07 100644 --- a/tokio-util/CHANGELOG.md +++ b/tokio-util/CHANGELOG.md @@ -3,24 +3,30 @@ ### Fixed - Adjust minimum-supported Tokio version to v0.2.5 to account for an internal - dependency on features in that version of Tokio. (#2326) + dependency on features in that version of Tokio. ([#2326]) # 0.3.0 (March 4, 2020) ### Changed - **Breaking Change**: Change `Encoder` trait to take a generic `Item` parameter, which allows - codec writers to pass references into `Framed` and `FramedWrite` types. (#1746) + codec writers to pass references into `Framed` and `FramedWrite` types. ([#1746]) ### Added -- Add futures-io/tokio::io compatibility layer. (#2117) -- Add `Framed::with_capacity`. (#2215) +- Add futures-io/tokio::io compatibility layer. ([#2117]) +- Add `Framed::with_capacity`. ([#2215]) ### Fixed -- Use advance over split_to when data is not needed. (#2198) +- Use advance over split_to when data is not needed. ([#2198]) # 0.2.0 (November 26, 2019) - Initial release + +[#2326]: https://github.com/tokio-rs/tokio/pull/2326 +[#2215]: https://github.com/tokio-rs/tokio/pull/2215 +[#2198]: https://github.com/tokio-rs/tokio/pull/2198 +[#2117]: https://github.com/tokio-rs/tokio/pull/2117 +[#1746]: https://github.com/tokio-rs/tokio/pull/1746 diff --git a/tokio-util/src/codec/framed.rs b/tokio-util/src/codec/framed.rs index d2e7659eda2..36370da2694 100644 --- a/tokio-util/src/codec/framed.rs +++ b/tokio-util/src/codec/framed.rs @@ -1,10 +1,9 @@ use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; -use crate::codec::framed_read::{framed_read2, framed_read2_with_buffer, FramedRead2}; -use crate::codec::framed_write::{framed_write2, framed_write2_with_buffer, FramedWrite2}; +use crate::codec::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; use tokio::{ - io::{AsyncBufRead, AsyncRead, AsyncWrite}, + io::{AsyncRead, AsyncWrite}, stream::Stream, }; @@ -12,8 +11,7 @@ use bytes::BytesMut; use futures_sink::Sink; use pin_project_lite::pin_project; use std::fmt; -use std::io::{self, BufRead, Read, Write}; -use std::mem::MaybeUninit; +use std::io; use std::pin::Pin; use std::task::{Context, Poll}; @@ -30,37 +28,7 @@ pin_project! { /// [`Decoder::framed`]: crate::codec::Decoder::framed() pub struct Framed { #[pin] - inner: FramedRead2>>, - } -} - -pin_project! { - pub(crate) struct Fuse { - #[pin] - pub(crate) io: T, - pub(crate) codec: U, - } -} - -/// Abstracts over `FramedRead2` being either `FramedRead2>>` or -/// `FramedRead2>` and lets the io and codec parts be extracted in either case. -pub(crate) trait ProjectFuse { - type Io; - type Codec; - - fn project(self: Pin<&mut Self>) -> Fuse, &mut Self::Codec>; -} - -impl ProjectFuse for Fuse { - type Io = T; - type Codec = U; - - fn project(self: Pin<&mut Self>) -> Fuse, &mut Self::Codec> { - let self_ = self.project(); - Fuse { - io: self_.io, - codec: self_.codec, - } + inner: FramedImpl } } @@ -93,7 +61,11 @@ where /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split pub fn new(inner: T, codec: U) -> Framed { Framed { - inner: framed_read2(framed_write2(Fuse { io: inner, codec })), + inner: FramedImpl { + inner, + codec, + state: Default::default(), + }, } } @@ -123,10 +95,18 @@ where /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Framed { Framed { - inner: framed_read2_with_buffer( - framed_write2(Fuse { io: inner, codec }), - BytesMut::with_capacity(capacity), - ), + inner: FramedImpl { + inner, + codec, + state: RWFrames { + read: ReadFrame { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(capacity), + }, + write: WriteFrame::default(), + }, + }, } } } @@ -161,16 +141,14 @@ impl Framed { /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split pub fn from_parts(parts: FramedParts) -> Framed { Framed { - inner: framed_read2_with_buffer( - framed_write2_with_buffer( - Fuse { - io: parts.io, - codec: parts.codec, - }, - parts.write_buf, - ), - parts.read_buf, - ), + inner: FramedImpl { + inner: parts.io, + codec: parts.codec, + state: RWFrames { + read: parts.read_buf.into(), + write: parts.write_buf.into(), + }, + }, } } @@ -181,7 +159,7 @@ impl Framed { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_ref(&self) -> &T { - &self.inner.get_ref().get_ref().io + &self.inner.inner } /// Returns a mutable reference to the underlying I/O stream wrapped by @@ -191,7 +169,7 @@ impl Framed { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_mut(&mut self) -> &mut T { - &mut self.inner.get_mut().get_mut().io + &mut self.inner.inner } /// Returns a reference to the underlying codec wrapped by @@ -200,7 +178,7 @@ impl Framed { /// Note that care should be taken to not tamper with the underlying codec /// as it may corrupt the stream of frames otherwise being worked with. pub fn codec(&self) -> &U { - &self.inner.get_ref().get_ref().codec + &self.inner.codec } /// Returns a mutable reference to the underlying codec wrapped by @@ -209,12 +187,17 @@ impl Framed { /// Note that care should be taken to not tamper with the underlying codec /// as it may corrupt the stream of frames otherwise being worked with. pub fn codec_mut(&mut self) -> &mut U { - &mut self.inner.get_mut().get_mut().codec + &mut self.inner.codec } /// Returns a reference to the read buffer. pub fn read_buffer(&self) -> &BytesMut { - self.inner.buffer() + &self.inner.state.read.buffer + } + + /// Returns a mutable reference to the read buffer. + pub fn read_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.read.buffer } /// Consumes the `Framed`, returning its underlying I/O stream. @@ -223,7 +206,7 @@ impl Framed { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_inner(self) -> T { - self.inner.into_inner().into_inner().io + self.inner.inner } /// Consumes the `Framed`, returning its underlying I/O stream, the buffer @@ -233,19 +216,17 @@ impl Framed { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_parts(self) -> FramedParts { - let (inner, read_buf) = self.inner.into_parts(); - let (inner, write_buf) = inner.into_parts(); - FramedParts { - io: inner.io, - codec: inner.codec, - read_buf, - write_buf, + io: self.inner.inner, + codec: self.inner.codec, + read_buf: self.inner.state.read.buffer, + write_buf: self.inner.state.write.buffer, _priv: (), } } } +// This impl just defers to the underlying FramedImpl impl Stream for Framed where T: AsyncRead, @@ -258,6 +239,7 @@ where } } +// This impl just defers to the underlying FramedImpl impl Sink for Framed where T: AsyncWrite, @@ -267,19 +249,19 @@ where type Error = U::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.get_pin_mut().poll_ready(cx) + self.project().inner.poll_ready(cx) } fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { - self.project().inner.get_pin_mut().start_send(item) + self.project().inner.start_send(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.get_pin_mut().poll_flush(cx) + self.project().inner.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.get_pin_mut().poll_close(cx) + self.project().inner.poll_close(cx) } } @@ -290,109 +272,19 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Framed") - .field("io", &self.inner.get_ref().get_ref().io) - .field("codec", &self.inner.get_ref().get_ref().codec) + .field("io", self.get_ref()) + .field("codec", self.codec()) .finish() } } -// ===== impl Fuse ===== - -impl Read for Fuse { - fn read(&mut self, dst: &mut [u8]) -> io::Result { - self.io.read(dst) - } -} - -impl BufRead for Fuse { - fn fill_buf(&mut self) -> io::Result<&[u8]> { - self.io.fill_buf() - } - - fn consume(&mut self, amt: usize) { - self.io.consume(amt) - } -} - -impl AsyncRead for Fuse { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - self.io.prepare_uninitialized_buffer(buf) - } - - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.project().io.poll_read(cx, buf) - } -} - -impl AsyncBufRead for Fuse { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().io.poll_fill_buf(cx) - } - - fn consume(self: Pin<&mut Self>, amt: usize) { - self.project().io.consume(amt) - } -} - -impl Write for Fuse { - fn write(&mut self, src: &[u8]) -> io::Result { - self.io.write(src) - } - - fn flush(&mut self) -> io::Result<()> { - self.io.flush() - } -} - -impl AsyncWrite for Fuse { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().io.poll_write(cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().io.poll_flush(cx) - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().io.poll_shutdown(cx) - } -} - -impl Decoder for Fuse { - type Item = U::Item; - type Error = U::Error; - - fn decode(&mut self, buffer: &mut BytesMut) -> Result, Self::Error> { - self.codec.decode(buffer) - } - - fn decode_eof(&mut self, buffer: &mut BytesMut) -> Result, Self::Error> { - self.codec.decode_eof(buffer) - } -} - -impl> Encoder for Fuse { - type Error = U::Error; - - fn encode(&mut self, item: I, dst: &mut BytesMut) -> Result<(), Self::Error> { - self.codec.encode(item, dst) - } -} - /// `FramedParts` contains an export of the data of a Framed transport. /// It can be used to construct a new [`Framed`] with a different codec. /// It contains all current buffers and the inner transport. /// /// [`Framed`]: crate::codec::Framed #[derive(Debug)] +#[allow(clippy::manual_non_exhaustive)] pub struct FramedParts { /// The inner transport used to read bytes to and write bytes to pub io: T, diff --git a/tokio-util/src/codec/framed_impl.rs b/tokio-util/src/codec/framed_impl.rs new file mode 100644 index 00000000000..eb2e0d38c6d --- /dev/null +++ b/tokio-util/src/codec/framed_impl.rs @@ -0,0 +1,225 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use tokio::{ + io::{AsyncRead, AsyncWrite}, + stream::Stream, +}; + +use bytes::{Buf, BytesMut}; +use futures_core::ready; +use futures_sink::Sink; +use log::trace; +use pin_project_lite::pin_project; +use std::borrow::{Borrow, BorrowMut}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + #[derive(Debug)] + pub(crate) struct FramedImpl { + #[pin] + pub(crate) inner: T, + pub(crate) state: State, + pub(crate) codec: U, + } +} + +const INITIAL_CAPACITY: usize = 8 * 1024; +const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; + +pub(crate) struct ReadFrame { + pub(crate) eof: bool, + pub(crate) is_readable: bool, + pub(crate) buffer: BytesMut, +} + +pub(crate) struct WriteFrame { + pub(crate) buffer: BytesMut, +} + +#[derive(Default)] +pub(crate) struct RWFrames { + pub(crate) read: ReadFrame, + pub(crate) write: WriteFrame, +} + +impl Default for ReadFrame { + fn default() -> Self { + Self { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + } + } +} + +impl Default for WriteFrame { + fn default() -> Self { + Self { + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + } + } +} + +impl From for ReadFrame { + fn from(mut buffer: BytesMut) -> Self { + let size = buffer.capacity(); + if size < INITIAL_CAPACITY { + buffer.reserve(INITIAL_CAPACITY - size); + } + + Self { + buffer, + is_readable: size > 0, + eof: false, + } + } +} + +impl From for WriteFrame { + fn from(mut buffer: BytesMut) -> Self { + let size = buffer.capacity(); + if size < INITIAL_CAPACITY { + buffer.reserve(INITIAL_CAPACITY - size); + } + + Self { buffer } + } +} + +impl Borrow for RWFrames { + fn borrow(&self) -> &ReadFrame { + &self.read + } +} +impl BorrowMut for RWFrames { + fn borrow_mut(&mut self) -> &mut ReadFrame { + &mut self.read + } +} +impl Borrow for RWFrames { + fn borrow(&self) -> &WriteFrame { + &self.write + } +} +impl BorrowMut for RWFrames { + fn borrow_mut(&mut self) -> &mut WriteFrame { + &mut self.write + } +} +impl Stream for FramedImpl +where + T: AsyncRead, + U: Decoder, + R: BorrowMut, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut pinned = self.project(); + let state: &mut ReadFrame = pinned.state.borrow_mut(); + loop { + // Repeatedly call `decode` or `decode_eof` as long as it is + // "readable". Readable is defined as not having returned `None`. If + // the upstream has returned EOF, and the decoder is no longer + // readable, it can be assumed that the decoder will never become + // readable again, at which point the stream is terminated. + if state.is_readable { + if state.eof { + let frame = pinned.codec.decode_eof(&mut state.buffer)?; + return Poll::Ready(frame.map(Ok)); + } + + trace!("attempting to decode a frame"); + + if let Some(frame) = pinned.codec.decode(&mut state.buffer)? { + trace!("frame decoded from buffer"); + return Poll::Ready(Some(Ok(frame))); + } + + state.is_readable = false; + } + + assert!(!state.eof); + + // Otherwise, try to read more data and try again. Make sure we've + // got room for at least one byte to read to ensure that we don't + // get a spurious 0 that looks like EOF + state.buffer.reserve(1); + let bytect = match pinned.inner.as_mut().poll_read_buf(cx, &mut state.buffer)? { + Poll::Ready(ct) => ct, + Poll::Pending => return Poll::Pending, + }; + if bytect == 0 { + state.eof = true; + } + + state.is_readable = true; + } + } +} + +impl Sink for FramedImpl +where + T: AsyncWrite, + U: Encoder, + U::Error: From, + W: BorrowMut, +{ + type Error = U::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.state.borrow().buffer.len() >= BACKPRESSURE_BOUNDARY { + self.as_mut().poll_flush(cx) + } else { + Poll::Ready(Ok(())) + } + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + let pinned = self.project(); + pinned + .codec + .encode(item, &mut pinned.state.borrow_mut().buffer)?; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + trace!("flushing framed transport"); + let mut pinned = self.project(); + + while !pinned.state.borrow_mut().buffer.is_empty() { + let WriteFrame { buffer } = pinned.state.borrow_mut(); + trace!("writing; remaining={}", buffer.len()); + + let buf = &buffer; + let n = ready!(pinned.inner.as_mut().poll_write(cx, &buf))?; + + if n == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to \ + write frame to transport", + ) + .into())); + } + + pinned.state.borrow_mut().buffer.advance(n); + } + + // Try flushing the underlying IO + ready!(pinned.inner.poll_flush(cx))?; + + trace!("framed transport flushed"); + Poll::Ready(Ok(())) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + ready!(self.project().inner.poll_shutdown(cx))?; + + Poll::Ready(Ok(())) + } +} diff --git a/tokio-util/src/codec/framed_read.rs b/tokio-util/src/codec/framed_read.rs index e7798c327ee..a6844b73557 100644 --- a/tokio-util/src/codec/framed_read.rs +++ b/tokio-util/src/codec/framed_read.rs @@ -1,11 +1,10 @@ -use crate::codec::framed::{Fuse, ProjectFuse}; +use crate::codec::framed_impl::{FramedImpl, ReadFrame}; use crate::codec::Decoder; use tokio::{io::AsyncRead, stream::Stream}; use bytes::BytesMut; use futures_sink::Sink; -use log::trace; use pin_project_lite::pin_project; use std::fmt; use std::pin::Pin; @@ -18,22 +17,10 @@ pin_project! { /// [`AsyncRead`]: tokio::io::AsyncRead pub struct FramedRead { #[pin] - inner: FramedRead2>, + inner: FramedImpl, } } -pin_project! { - pub(crate) struct FramedRead2 { - #[pin] - inner: T, - eof: bool, - is_readable: bool, - buffer: BytesMut, - } -} - -const INITIAL_CAPACITY: usize = 8 * 1024; - // ===== impl FramedRead ===== impl FramedRead @@ -44,10 +31,11 @@ where /// Creates a new `FramedRead` with the given `decoder`. pub fn new(inner: T, decoder: D) -> FramedRead { FramedRead { - inner: framed_read2(Fuse { - io: inner, + inner: FramedImpl { + inner, codec: decoder, - }), + state: Default::default(), + }, } } @@ -55,13 +43,15 @@ where /// initial size. pub fn with_capacity(inner: T, decoder: D, capacity: usize) -> FramedRead { FramedRead { - inner: framed_read2_with_buffer( - Fuse { - io: inner, - codec: decoder, + inner: FramedImpl { + inner, + codec: decoder, + state: ReadFrame { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(capacity), }, - BytesMut::with_capacity(capacity), - ), + }, } } } @@ -74,7 +64,7 @@ impl FramedRead { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_ref(&self) -> &T { - &self.inner.inner.io + &self.inner.inner } /// Returns a mutable reference to the underlying I/O stream wrapped by @@ -84,7 +74,7 @@ impl FramedRead { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_mut(&mut self) -> &mut T { - &mut self.inner.inner.io + &mut self.inner.inner } /// Consumes the `FramedRead`, returning its underlying I/O stream. @@ -93,25 +83,26 @@ impl FramedRead { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_inner(self) -> T { - self.inner.inner.io + self.inner.inner } /// Returns a reference to the underlying decoder. pub fn decoder(&self) -> &D { - &self.inner.inner.codec + &self.inner.codec } /// Returns a mutable reference to the underlying decoder. pub fn decoder_mut(&mut self) -> &mut D { - &mut self.inner.inner.codec + &mut self.inner.codec } /// Returns a reference to the read buffer. pub fn read_buffer(&self) -> &BytesMut { - &self.inner.buffer + &self.inner.state.buffer } } +// This impl just defers to the underlying FramedImpl impl Stream for FramedRead where T: AsyncRead, @@ -132,43 +123,19 @@ where type Error = T::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .inner - .project() - .inner - .project() - .io - .poll_ready(cx) + self.project().inner.project().inner.poll_ready(cx) } fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { - self.project() - .inner - .project() - .inner - .project() - .io - .start_send(item) + self.project().inner.project().inner.start_send(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .inner - .project() - .inner - .project() - .io - .poll_flush(cx) + self.project().inner.project().inner.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .inner - .project() - .inner - .project() - .io - .poll_close(cx) + self.project().inner.project().inner.poll_close(cx) } } @@ -179,126 +146,11 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FramedRead") - .field("inner", &self.inner.inner.io) - .field("decoder", &self.inner.inner.codec) - .field("eof", &self.inner.eof) - .field("is_readable", &self.inner.is_readable) - .field("buffer", &self.inner.buffer) + .field("inner", &self.get_ref()) + .field("decoder", &self.decoder()) + .field("eof", &self.inner.state.eof) + .field("is_readable", &self.inner.state.is_readable) + .field("buffer", &self.read_buffer()) .finish() } } - -// ===== impl FramedRead2 ===== - -pub(crate) fn framed_read2(inner: T) -> FramedRead2 { - FramedRead2 { - inner, - eof: false, - is_readable: false, - buffer: BytesMut::with_capacity(INITIAL_CAPACITY), - } -} - -pub(crate) fn framed_read2_with_buffer(inner: T, mut buf: BytesMut) -> FramedRead2 { - if buf.capacity() < INITIAL_CAPACITY { - let bytes_to_reserve = INITIAL_CAPACITY - buf.capacity(); - buf.reserve(bytes_to_reserve); - } - FramedRead2 { - inner, - eof: false, - is_readable: !buf.is_empty(), - buffer: buf, - } -} - -impl FramedRead2 { - pub(crate) fn get_ref(&self) -> &T { - &self.inner - } - - pub(crate) fn into_inner(self) -> T { - self.inner - } - - pub(crate) fn into_parts(self) -> (T, BytesMut) { - (self.inner, self.buffer) - } - - pub(crate) fn get_mut(&mut self) -> &mut T { - &mut self.inner - } - - pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { - self.project().inner - } - - pub(crate) fn buffer(&self) -> &BytesMut { - &self.buffer - } -} - -impl Stream for FramedRead2 -where - T: ProjectFuse + AsyncRead, - T::Codec: Decoder, -{ - type Item = Result<::Item, ::Error>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut pinned = self.project(); - loop { - // Repeatedly call `decode` or `decode_eof` as long as it is - // "readable". Readable is defined as not having returned `None`. If - // the upstream has returned EOF, and the decoder is no longer - // readable, it can be assumed that the decoder will never become - // readable again, at which point the stream is terminated. - if *pinned.is_readable { - if *pinned.eof { - let frame = pinned - .inner - .as_mut() - .project() - .codec - .decode_eof(&mut pinned.buffer)?; - return Poll::Ready(frame.map(Ok)); - } - - trace!("attempting to decode a frame"); - - if let Some(frame) = pinned - .inner - .as_mut() - .project() - .codec - .decode(&mut pinned.buffer)? - { - trace!("frame decoded from buffer"); - return Poll::Ready(Some(Ok(frame))); - } - - *pinned.is_readable = false; - } - - assert!(!*pinned.eof); - - // Otherwise, try to read more data and try again. Make sure we've - // got room for at least one byte to read to ensure that we don't - // get a spurious 0 that looks like EOF - pinned.buffer.reserve(1); - let bytect = match pinned - .inner - .as_mut() - .poll_read_buf(cx, &mut pinned.buffer)? - { - Poll::Ready(ct) => ct, - Poll::Pending => return Poll::Pending, - }; - if bytect == 0 { - *pinned.eof = true; - } - - *pinned.is_readable = true; - } - } -} diff --git a/tokio-util/src/codec/framed_write.rs b/tokio-util/src/codec/framed_write.rs index c0049b2d04a..834eb6ed0f8 100644 --- a/tokio-util/src/codec/framed_write.rs +++ b/tokio-util/src/codec/framed_write.rs @@ -1,20 +1,12 @@ -use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; -use crate::codec::framed::{Fuse, ProjectFuse}; +use crate::codec::framed_impl::{FramedImpl, WriteFrame}; -use tokio::{ - io::{AsyncBufRead, AsyncRead, AsyncWrite}, - stream::Stream, -}; +use tokio::{io::AsyncWrite, stream::Stream}; -use bytes::{Buf, BytesMut}; -use futures_core::ready; use futures_sink::Sink; -use log::trace; use pin_project_lite::pin_project; use std::fmt; -use std::io::{self, BufRead, Read}; -use std::mem::MaybeUninit; +use std::io; use std::pin::Pin; use std::task::{Context, Poll}; @@ -24,21 +16,10 @@ pin_project! { /// [`Sink`]: futures_sink::Sink pub struct FramedWrite { #[pin] - inner: FramedWrite2>, + inner: FramedImpl, } } -pin_project! { - pub(crate) struct FramedWrite2 { - #[pin] - inner: T, - buffer: BytesMut, - } -} - -const INITIAL_CAPACITY: usize = 8 * 1024; -const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; - impl FramedWrite where T: AsyncWrite, @@ -46,10 +27,11 @@ where /// Creates a new `FramedWrite` with the given `encoder`. pub fn new(inner: T, encoder: E) -> FramedWrite { FramedWrite { - inner: framed_write2(Fuse { - io: inner, + inner: FramedImpl { + inner, codec: encoder, - }), + state: WriteFrame::default(), + }, } } } @@ -62,7 +44,7 @@ impl FramedWrite { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_ref(&self) -> &T { - &self.inner.inner.io + &self.inner.inner } /// Returns a mutable reference to the underlying I/O stream wrapped by @@ -72,7 +54,7 @@ impl FramedWrite { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_mut(&mut self) -> &mut T { - &mut self.inner.inner.io + &mut self.inner.inner } /// Consumes the `FramedWrite`, returning its underlying I/O stream. @@ -81,21 +63,21 @@ impl FramedWrite { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_inner(self) -> T { - self.inner.inner.io + self.inner.inner } - /// Returns a reference to the underlying decoder. + /// Returns a reference to the underlying encoder. pub fn encoder(&self) -> &E { - &self.inner.inner.codec + &self.inner.codec } - /// Returns a mutable reference to the underlying decoder. + /// Returns a mutable reference to the underlying encoder. pub fn encoder_mut(&mut self) -> &mut E { - &mut self.inner.inner.codec + &mut self.inner.codec } } -// This impl just defers to the underlying FramedWrite2 +// This impl just defers to the underlying FramedImpl impl Sink for FramedWrite where T: AsyncWrite, @@ -121,6 +103,7 @@ where } } +// This impl just defers to the underlying T: Stream impl Stream for FramedWrite where T: Stream, @@ -128,13 +111,7 @@ where type Item = T::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .inner - .project() - .inner - .project() - .io - .poll_next(cx) + self.project().inner.project().inner.poll_next(cx) } } @@ -145,180 +122,9 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FramedWrite") - .field("inner", &self.inner.get_ref().io) - .field("encoder", &self.inner.get_ref().codec) - .field("buffer", &self.inner.buffer) + .field("inner", &self.get_ref()) + .field("encoder", &self.encoder()) + .field("buffer", &self.inner.state.buffer) .finish() } } - -// ===== impl FramedWrite2 ===== - -pub(crate) fn framed_write2(inner: T) -> FramedWrite2 { - FramedWrite2 { - inner, - buffer: BytesMut::with_capacity(INITIAL_CAPACITY), - } -} - -pub(crate) fn framed_write2_with_buffer(inner: T, mut buf: BytesMut) -> FramedWrite2 { - if buf.capacity() < INITIAL_CAPACITY { - let bytes_to_reserve = INITIAL_CAPACITY - buf.capacity(); - buf.reserve(bytes_to_reserve); - } - FramedWrite2 { inner, buffer: buf } -} - -impl FramedWrite2 { - pub(crate) fn get_ref(&self) -> &T { - &self.inner - } - - pub(crate) fn into_inner(self) -> T { - self.inner - } - - pub(crate) fn into_parts(self) -> (T, BytesMut) { - (self.inner, self.buffer) - } - - pub(crate) fn get_mut(&mut self) -> &mut T { - &mut self.inner - } -} - -impl Sink for FramedWrite2 -where - T: ProjectFuse + AsyncWrite, - T::Codec: Encoder, -{ - type Error = >::Error; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // If the buffer is already over 8KiB, then attempt to flush it. If after flushing it's - // *still* over 8KiB, then apply backpressure (reject the send). - if self.buffer.len() >= BACKPRESSURE_BOUNDARY { - match self.as_mut().poll_flush(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(())) => (), - }; - - if self.buffer.len() >= BACKPRESSURE_BOUNDARY { - return Poll::Pending; - } - } - Poll::Ready(Ok(())) - } - - fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { - let mut pinned = self.project(); - pinned - .inner - .project() - .codec - .encode(item, &mut pinned.buffer)?; - Ok(()) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - trace!("flushing framed transport"); - let mut pinned = self.project(); - - while !pinned.buffer.is_empty() { - trace!("writing; remaining={}", pinned.buffer.len()); - - let buf = &pinned.buffer; - let n = ready!(pinned.inner.as_mut().poll_write(cx, &buf))?; - - if n == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to \ - write frame to transport", - ) - .into())); - } - - pinned.buffer.advance(n); - } - - // Try flushing the underlying IO - ready!(pinned.inner.poll_flush(cx))?; - - trace!("framed transport flushed"); - Poll::Ready(Ok(())) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().poll_flush(cx))?; - ready!(self.project().inner.poll_shutdown(cx))?; - - Poll::Ready(Ok(())) - } -} - -impl Decoder for FramedWrite2 { - type Item = T::Item; - type Error = T::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, T::Error> { - self.inner.decode(src) - } - - fn decode_eof(&mut self, src: &mut BytesMut) -> Result, T::Error> { - self.inner.decode_eof(src) - } -} - -impl Read for FramedWrite2 { - fn read(&mut self, dst: &mut [u8]) -> io::Result { - self.inner.read(dst) - } -} - -impl BufRead for FramedWrite2 { - fn fill_buf(&mut self) -> io::Result<&[u8]> { - self.inner.fill_buf() - } - - fn consume(&mut self, amt: usize) { - self.inner.consume(amt) - } -} - -impl AsyncRead for FramedWrite2 { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.project().inner.poll_read(cx, buf) - } -} - -impl AsyncBufRead for FramedWrite2 { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_fill_buf(cx) - } - - fn consume(self: Pin<&mut Self>, amt: usize) { - self.project().inner.consume(amt) - } -} - -impl ProjectFuse for FramedWrite2 -where - T: ProjectFuse, -{ - type Io = T::Io; - type Codec = T::Codec; - - fn project(self: Pin<&mut Self>) -> Fuse, &mut Self::Codec> { - self.project().inner.project() - } -} diff --git a/tokio-util/src/codec/length_delimited.rs b/tokio-util/src/codec/length_delimited.rs index f11fb3ea465..2426b771ae5 100644 --- a/tokio-util/src/codec/length_delimited.rs +++ b/tokio-util/src/codec/length_delimited.rs @@ -364,13 +364,13 @@ //! +------------+--------------+ //! ``` //! -//! [`LengthDelimitedCodec::new()`]: struct.LengthDelimitedCodec.html#method.new -//! [`FramedRead`]: struct.FramedRead.html -//! [`FramedWrite`]: struct.FramedWrite.html -//! [`AsyncRead`]: ../../trait.AsyncRead.html -//! [`AsyncWrite`]: ../../trait.AsyncWrite.html -//! [`Encoder`]: ../trait.Encoder.html -//! [`BytesMut`]: https://docs.rs/bytes/0.4/bytes/struct.BytesMut.html +//! [`LengthDelimitedCodec::new()`]: method@LengthDelimitedCodec::new +//! [`FramedRead`]: struct@FramedRead +//! [`FramedWrite`]: struct@FramedWrite +//! [`AsyncRead`]: trait@tokio::io::AsyncRead +//! [`AsyncWrite`]: trait@tokio::io::AsyncWrite +//! [`Encoder`]: trait@Encoder +//! [`BytesMut`]: bytes::BytesMut use crate::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite}; diff --git a/tokio-util/src/codec/mod.rs b/tokio-util/src/codec/mod.rs index ec76a6419f0..e89aa7c9ac8 100644 --- a/tokio-util/src/codec/mod.rs +++ b/tokio-util/src/codec/mod.rs @@ -1,8 +1,13 @@ -//! Utilities for encoding and decoding frames. +//! Adaptors from AsyncRead/AsyncWrite to Stream/Sink //! -//! Contains adapters to go from streams of bytes, [`AsyncRead`] and -//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`]. -//! Framed streams are also known as transports. +//! Raw I/O objects work with byte sequences, but higher-level code +//! usually wants to batch these into meaningful chunks, called +//! "frames". +//! +//! This module contains adapters to go from streams of bytes, +//! [`AsyncRead`] and [`AsyncWrite`], to framed streams implementing +//! [`Sink`] and [`Stream`]. Framed streams are also known as +//! transports. //! //! [`AsyncRead`]: tokio::io::AsyncRead //! [`AsyncWrite`]: tokio::io::AsyncWrite @@ -18,6 +23,10 @@ pub use self::decoder::Decoder; mod encoder; pub use self::encoder::Encoder; +mod framed_impl; +#[allow(unused_imports)] +pub(crate) use self::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; + mod framed; pub use self::framed::{Framed, FramedParts}; diff --git a/tokio/CHANGELOG.md b/tokio/CHANGELOG.md index 92228f51f35..b852b393da4 100644 --- a/tokio/CHANGELOG.md +++ b/tokio/CHANGELOG.md @@ -1,7 +1,49 @@ +# 0.2.21 (May 13, 2020) + +### Fixes + +- macros: disambiguate built-in `#[test]` attribute in macro expansion (#2503) +- rt: `LocalSet` and task budgeting (#2462). +- rt: task budgeting with `block_in_place` (#2502). +- sync: release `broadcast` channel memory without sending a value (#2509). +- time: notify when resetting a `Delay` to a time in the past (#2290). + +### Added +- io: `get_mut`, `get_ref`, and `into_inner` to `Lines` (#2450). +- io: `mio::Ready` argument to `PollEvented` (#2419). +- os: illumos support (#2486). +- rt: `Handle::spawn_blocking` (#2501). +- sync: `OwnedMutexGuard` for `Arc>` (#2455). + +# 0.2.20 (April 28, 2020) + +### Fixes +- sync: `broadcast` closing the channel no longer requires capacity (#2448). +- rt: regression when configuring runtime with `max_threads` less than number of CPUs (#2457). + +# 0.2.19 (April 24, 2020) + +### Fixes +- docs: misc improvements (#2400, #2405, #2414, #2420, #2423, #2426, #2427, #2434, #2436, #2440). +- rt: support `block_in_place` in more contexts (#2409, #2410). +- stream: no panic in `merge()` and `chain()` when using `size_hint()` (#2430). +- task: include visibility modifier when defining a task-local (#2416). + +### Added +- rt: `runtime::Handle::block_on` (#2437). +- sync: owned `Semaphore` permit (#2421). +- tcp: owned split (#2270). + +# 0.2.18 (April 12, 2020) + +### Fixes +- task: `LocalSet` was incorrectly marked as `Send` (#2398) +- io: correctly report `WriteZero` failure in `write_int` (#2334) + # 0.2.17 (April 9, 2020) ### Fixes -- rt: bug in work-stealing queue (#2387) +- rt: bug in work-stealing queue (#2387) ### Changes - rt: threadpool uses logical CPU count instead of physical by default (#2391) @@ -11,121 +53,121 @@ ### Fixes - sync: fix a regression where `Mutex`, `Semaphore`, and `RwLock` futures no - longer implement `Sync` (#2375) -- fs: fix `fs::copy` not copying file permissions (#2354) + longer implement `Sync` ([#2375]) +- fs: fix `fs::copy` not copying file permissions ([#2354]) ### Added -- time: added `deadline` method to `delay_queue::Expired` (#2300) -- io: added `StreamReader` (#2052) +- time: added `deadline` method to `delay_queue::Expired` ([#2300]) +- io: added `StreamReader` ([#2052]) # 0.2.15 (April 2, 2020) ### Fixes -- rt: fix queue regression (#2362). +- rt: fix queue regression ([#2362]). ### Added -- sync: Add disarm to `mpsc::Sender` (#2358). +- sync: Add disarm to `mpsc::Sender` ([#2358]). # 0.2.14 (April 1, 2020) ### Fixes -- rt: concurrency bug in scheduler (#2273). -- rt: concurrency bug with shell runtime (#2333). -- test-util: correct pause/resume of time (#2253). -- time: `DelayQueue` correct wakeup after `insert` (#2285). +- rt: concurrency bug in scheduler ([#2273]). +- rt: concurrency bug with shell runtime ([#2333]). +- test-util: correct pause/resume of time ([#2253]). +- time: `DelayQueue` correct wakeup after `insert` ([#2285]). ### Added -- io: impl `RawFd`, `AsRawHandle` for std io types (#2335). +- io: impl `RawFd`, `AsRawHandle` for std io types ([#2335]). - rt: automatic cooperative task yielding (#2160, #2343, #2349). -- sync: `RwLock::into_inner` (#2321). +- sync: `RwLock::into_inner` ([#2321]). ### Changed -- sync: semaphore, mutex internals rewritten to avoid allocations (#2325). +- sync: semaphore, mutex internals rewritten to avoid allocations ([#2325]). # 0.2.13 (February 28, 2020) ### Fixes -- macros: unresolved import in `pin!` (#2281). +- macros: unresolved import in `pin!` ([#2281]). # 0.2.12 (February 27, 2020) ### Fixes -- net: `UnixStream::poll_shutdown` should call `shutdown(Write)` (#2245). -- process: Wake up read and write on `EPOLLERR` (#2218). +- net: `UnixStream::poll_shutdown` should call `shutdown(Write)` ([#2245]). +- process: Wake up read and write on `EPOLLERR` ([#2218]). - rt: potential deadlock when using `block_in_place` and shutting down the - runtime (#2119). -- rt: only detect number of CPUs if `core_threads` not specified (#2238). -- sync: reduce `watch::Receiver` struct size (#2191). -- time: succeed when setting delay of `$MAX-1` (#2184). -- time: avoid having to poll `DelayQueue` after inserting new delay (#2217). + runtime ([#2119]). +- rt: only detect number of CPUs if `core_threads` not specified ([#2238]). +- sync: reduce `watch::Receiver` struct size ([#2191]). +- time: succeed when setting delay of `$MAX-1` ([#2184]). +- time: avoid having to poll `DelayQueue` after inserting new delay ([#2217]). ### Added -- macros: `pin!` variant that assigns to identifier and pins (#2274). -- net: impl `Stream` for `Listener` types (#2275). +- macros: `pin!` variant that assigns to identifier and pins ([#2274]). +- net: impl `Stream` for `Listener` types ([#2275]). - rt: `Runtime::shutdown_timeout` waits for runtime to shutdown for specified - duration (#2186). + duration ([#2186]). - stream: `StreamMap` merges streams and can insert / remove streams at - runtime (#2185). -- stream: `StreamExt::skip()` skips a fixed number of items (#2204). -- stream: `StreamExt::skip_while()` skips items based on a predicate (#2205). -- sync: `Notify` provides basic `async` / `await` task notification (#2210). -- sync: `Mutex::into_inner` retrieves guarded data (#2250). + runtime ([#2185]). +- stream: `StreamExt::skip()` skips a fixed number of items ([#2204]). +- stream: `StreamExt::skip_while()` skips items based on a predicate ([#2205]). +- sync: `Notify` provides basic `async` / `await` task notification ([#2210]). +- sync: `Mutex::into_inner` retrieves guarded data ([#2250]). - sync: `mpsc::Sender::send_timeout` sends, waiting for up to specified duration - for channel capacity (#2227). -- time: impl `Ord` and `Hash` for `Instant` (#2239). + for channel capacity ([#2227]). +- time: impl `Ord` and `Hash` for `Instant` ([#2239]). # 0.2.11 (January 27, 2020) ### Fixes - docs: misc fixes and tweaks (#2155, #2103, #2027, #2167, #2175). -- macros: handle generics in `#[tokio::main]` method (#2177). -- sync: `broadcast` potential lost notifications (#2135). -- rt: improve "no runtime" panic messages (#2145). +- macros: handle generics in `#[tokio::main]` method ([#2177]). +- sync: `broadcast` potential lost notifications ([#2135]). +- rt: improve "no runtime" panic messages ([#2145]). ### Added -- optional support for using `parking_lot` internally (#2164). -- fs: `fs::copy`, an async version of `std::fs::copy` (#2079). -- macros: `select!` waits for the first branch to complete (#2152). -- macros: `join!` waits for all branches to complete (#2158). -- macros: `try_join!` waits for all branches to complete or the first error (#2169). -- macros: `pin!` pins a value to the stack (#2163). -- net: `ReadHalf::poll()` and `ReadHalf::poll_peak` (#2151) -- stream: `StreamExt::timeout()` sets a per-item max duration (#2149). -- stream: `StreamExt::fold()` applies a function, producing a single value. (#2122). -- sync: impl `Eq`, `PartialEq` for `oneshot::RecvError` (#2168). -- task: methods for inspecting the `JoinError` cause (#2051). +- optional support for using `parking_lot` internally ([#2164]). +- fs: `fs::copy`, an async version of `std::fs::copy` ([#2079]). +- macros: `select!` waits for the first branch to complete ([#2152]). +- macros: `join!` waits for all branches to complete ([#2158]). +- macros: `try_join!` waits for all branches to complete or the first error ([#2169]). +- macros: `pin!` pins a value to the stack ([#2163]). +- net: `ReadHalf::poll()` and `ReadHalf::poll_peak` ([#2151]) +- stream: `StreamExt::timeout()` sets a per-item max duration ([#2149]). +- stream: `StreamExt::fold()` applies a function, producing a single value. ([#2122]). +- sync: impl `Eq`, `PartialEq` for `oneshot::RecvError` ([#2168]). +- task: methods for inspecting the `JoinError` cause ([#2051]). # 0.2.10 (January 21, 2020) ### Fixes -- `#[tokio::main]` when `rt-core` feature flag is not enabled (#2139). -- remove `AsyncBufRead` from `BufStream` impl block (#2108). -- potential undefined behavior when implementing `AsyncRead` incorrectly (#2030). +- `#[tokio::main]` when `rt-core` feature flag is not enabled ([#2139]). +- remove `AsyncBufRead` from `BufStream` impl block ([#2108]). +- potential undefined behavior when implementing `AsyncRead` incorrectly ([#2030]). ### Added -- `BufStream::with_capacity` (#2125). -- impl `From` and `Default` for `RwLock` (#2089). +- `BufStream::with_capacity` ([#2125]). +- impl `From` and `Default` for `RwLock` ([#2089]). - `io::ReadHalf::is_pair_of` checks if provided `WriteHalf` is for the same underlying object (#1762, #2144). -- `runtime::Handle::try_current()` returns a handle to the current runtime (#2118). -- `stream::empty()` returns an immediately ready empty stream (#2092). -- `stream::once(val)` returns a stream that yields a single value: `val` (#2094). -- `stream::pending()` returns a stream that never becomes ready (#2092). -- `StreamExt::chain()` sequences a second stream after the first completes (#2093). -- `StreamExt::collect()` transform a stream into a collection (#2109). -- `StreamExt::fuse` ends the stream after the first `None` (#2085). -- `StreamExt::merge` combines two streams, yielding values as they become ready (#2091). -- Task-local storage (#2126). +- `runtime::Handle::try_current()` returns a handle to the current runtime ([#2118]). +- `stream::empty()` returns an immediately ready empty stream ([#2092]). +- `stream::once(val)` returns a stream that yields a single value: `val` ([#2094]). +- `stream::pending()` returns a stream that never becomes ready ([#2092]). +- `StreamExt::chain()` sequences a second stream after the first completes ([#2093]). +- `StreamExt::collect()` transform a stream into a collection ([#2109]). +- `StreamExt::fuse` ends the stream after the first `None` ([#2085]). +- `StreamExt::merge` combines two streams, yielding values as they become ready ([#2091]). +- Task-local storage ([#2126]). # 0.2.9 (January 9, 2020) ### Fixes -- `AsyncSeek` impl for `File` (#1986). +- `AsyncSeek` impl for `File` ([#1986]). - rt: shutdown deadlock in `threaded_scheduler` (#2074, #2082). -- rt: memory ordering when dropping `JoinHandle` (#2044). +- rt: memory ordering when dropping `JoinHandle` ([#2044]). - docs: misc API documentation fixes and improvements. # 0.2.8 (January 7, 2020) @@ -137,108 +179,108 @@ ### Fixes - potential deadlock when dropping `basic_scheduler` Runtime. -- calling `spawn_blocking` from within a `spawn_blocking` (#2006). -- storing a `Runtime` instance in a thread-local (#2011). +- calling `spawn_blocking` from within a `spawn_blocking` ([#2006]). +- storing a `Runtime` instance in a thread-local ([#2011]). - miscellaneous documentation fixes. -- rt: fix `Waker::will_wake` to return true when tasks match (#2045). -- test-util: `time::advance` runs pending tasks before changing the time (#2059). +- rt: fix `Waker::will_wake` to return true when tasks match ([#2045]). +- test-util: `time::advance` runs pending tasks before changing the time ([#2059]). ### Added -- `net::lookup_host` maps a `T: ToSocketAddrs` to a stream of `SocketAddrs` (#1870). -- `process::Child` fields are made public to match `std` (#2014). -- impl `Stream` for `sync::broadcast::Receiver` (#2012). -- `sync::RwLock` provides an asynchonous read-write lock (#1699). -- `runtime::Handle::current` returns the handle for the current runtime (#2040). -- `StreamExt::filter` filters stream values according to a predicate (#2001). -- `StreamExt::filter_map` simultaneously filter and map stream values (#2001). -- `StreamExt::try_next` convenience for streams of `Result` (#2005). -- `StreamExt::take` limits a stream to a specified number of values (#2025). -- `StreamExt::take_while` limits a stream based on a predicate (#2029). -- `StreamExt::all` tests if every element of the stream matches a predicate (#2035). -- `StreamExt::any` tests if any element of the stream matches a predicate (#2034). -- `task::LocalSet.await` runs spawned tasks until the set is idle (#1971). -- `time::DelayQueue::len` returns the number entries in the queue (#1755). -- expose runtime options from the `#[tokio::main]` and `#[tokio::test]` (#2022). +- `net::lookup_host` maps a `T: ToSocketAddrs` to a stream of `SocketAddrs` ([#1870]). +- `process::Child` fields are made public to match `std` ([#2014]). +- impl `Stream` for `sync::broadcast::Receiver` ([#2012]). +- `sync::RwLock` provides an asynchonous read-write lock ([#1699]). +- `runtime::Handle::current` returns the handle for the current runtime ([#2040]). +- `StreamExt::filter` filters stream values according to a predicate ([#2001]). +- `StreamExt::filter_map` simultaneously filter and map stream values ([#2001]). +- `StreamExt::try_next` convenience for streams of `Result` ([#2005]). +- `StreamExt::take` limits a stream to a specified number of values ([#2025]). +- `StreamExt::take_while` limits a stream based on a predicate ([#2029]). +- `StreamExt::all` tests if every element of the stream matches a predicate ([#2035]). +- `StreamExt::any` tests if any element of the stream matches a predicate ([#2034]). +- `task::LocalSet.await` runs spawned tasks until the set is idle ([#1971]). +- `time::DelayQueue::len` returns the number entries in the queue ([#1755]). +- expose runtime options from the `#[tokio::main]` and `#[tokio::test]` ([#2022]). # 0.2.6 (December 19, 2019) ### Fixes -- `fs::File::seek` API regression (#1991). +- `fs::File::seek` API regression ([#1991]). # 0.2.5 (December 18, 2019) ### Added -- `io::AsyncSeek` trait (#1924). -- `Mutex::try_lock` (#1939) -- `mpsc::Receiver::try_recv` and `mpsc::UnboundedReceiver::try_recv` (#1939). -- `writev` support for `TcpStream` (#1956). -- `time::throttle` for throttling streams (#1949). -- implement `Stream` for `time::DelayQueue` (#1975). -- `sync::broadcast` provides a fan-out channel (#1943). -- `sync::Semaphore` provides an async semaphore (#1973). -- `stream::StreamExt` provides stream utilities (#1962). +- `io::AsyncSeek` trait ([#1924]). +- `Mutex::try_lock` ([#1939]) +- `mpsc::Receiver::try_recv` and `mpsc::UnboundedReceiver::try_recv` ([#1939]). +- `writev` support for `TcpStream` ([#1956]). +- `time::throttle` for throttling streams ([#1949]). +- implement `Stream` for `time::DelayQueue` ([#1975]). +- `sync::broadcast` provides a fan-out channel ([#1943]). +- `sync::Semaphore` provides an async semaphore ([#1973]). +- `stream::StreamExt` provides stream utilities ([#1962]). ### Fixes -- deadlock risk while shutting down the runtime (#1972). -- panic while shutting down the runtime (#1978). -- `sync::MutexGuard` debug output (#1961). +- deadlock risk while shutting down the runtime ([#1972]). +- panic while shutting down the runtime ([#1978]). +- `sync::MutexGuard` debug output ([#1961]). - misc doc improvements (#1933, #1934, #1940, #1942). ### Changes - runtime threads are configured with `runtime::Builder::core_threads` and `runtime::Builder::max_threads`. `runtime::Builder::num_threads` is - deprecated (#1977). + deprecated ([#1977]). # 0.2.4 (December 6, 2019) ### Fixes -- `sync::Mutex` deadlock when `lock()` future is dropped early (#1898). +- `sync::Mutex` deadlock when `lock()` future is dropped early ([#1898]). # 0.2.3 (December 6, 2019) ### Added -- read / write integers using `AsyncReadExt` and `AsyncWriteExt` (#1863). -- `read_buf` / `write_buf` for reading / writing `Buf` / `BufMut` (#1881). -- `TcpStream::poll_peek` - pollable API for performing TCP peek (#1864). +- read / write integers using `AsyncReadExt` and `AsyncWriteExt` ([#1863]). +- `read_buf` / `write_buf` for reading / writing `Buf` / `BufMut` ([#1881]). +- `TcpStream::poll_peek` - pollable API for performing TCP peek ([#1864]). - `sync::oneshot::error::TryRecvError` provides variants to detect the error - kind (#1874). -- `LocalSet::block_on` accepts `!'static` task (#1882). -- `task::JoinError` is now `Sync` (#1888). + kind ([#1874]). +- `LocalSet::block_on` accepts `!'static` task ([#1882]). +- `task::JoinError` is now `Sync` ([#1888]). - impl conversions between `tokio::time::Instant` and - `std::time::Instant` (#1904). + `std::time::Instant` ([#1904]). ### Fixes -- calling `spawn_blocking` after runtime shutdown (#1875). -- `LocalSet` drop inifinite loop (#1892). -- `LocalSet` hang under load (#1905). +- calling `spawn_blocking` after runtime shutdown ([#1875]). +- `LocalSet` drop inifinite loop ([#1892]). +- `LocalSet` hang under load ([#1905]). - improved documentation (#1865, #1866, #1868, #1874, #1876, #1911). # 0.2.2 (November 29, 2019) ### Fixes -- scheduling with `basic_scheduler` (#1861). -- update `spawn` panic message to specify that a task scheduler is required (#1839). -- API docs example for `runtime::Builder` to include a task scheduler (#1841). -- general documentation (#1834). -- building on illumos/solaris (#1772). -- panic when dropping `LocalSet` (#1843). -- API docs mention the required Cargo features for `Builder::{basic, threaded}_scheduler` (#1858). +- scheduling with `basic_scheduler` ([#1861]). +- update `spawn` panic message to specify that a task scheduler is required ([#1839]). +- API docs example for `runtime::Builder` to include a task scheduler ([#1841]). +- general documentation ([#1834]). +- building on illumos/solaris ([#1772]). +- panic when dropping `LocalSet` ([#1843]). +- API docs mention the required Cargo features for `Builder::{basic, threaded}_scheduler` ([#1858]). ### Added -- impl `Stream` for `signal::unix::Signal` (#1849). -- API docs for platform specific behavior of `signal::ctrl_c` and `signal::unix::Signal` (#1854). -- API docs for `signal::unix::Signal::{recv, poll_recv}` and `signal::windows::CtrlBreak::{recv, poll_recv}` (#1854). -- `File::into_std` and `File::try_into_std` methods (#1856). +- impl `Stream` for `signal::unix::Signal` ([#1849]). +- API docs for platform specific behavior of `signal::ctrl_c` and `signal::unix::Signal` ([#1854]). +- API docs for `signal::unix::Signal::{recv, poll_recv}` and `signal::windows::CtrlBreak::{recv, poll_recv}` ([#1854]). +- `File::into_std` and `File::try_into_std` methods ([#1856]). # 0.2.1 (November 26, 2019) ### Fixes -- API docs for `TcpListener::incoming`, `UnixListener::incoming` (#1831). +- API docs for `TcpListener::incoming`, `UnixListener::incoming` ([#1831]). ### Added -- `tokio::task::LocalSet` provides a strategy for spawning `!Send` tasks (#1733). -- export `tokio::time::Elapsed` (#1826). -- impl `AsRawFd`, `AsRawHandle` for `tokio::fs::File` (#1827). +- `tokio::task::LocalSet` provides a strategy for spawning `!Send` tasks ([#1733]). +- export `tokio::time::Elapsed` ([#1826]). +- impl `AsRawFd`, `AsRawHandle` for `tokio::fs::File` ([#1827]). # 0.2.0 (November 26, 2019) @@ -264,69 +306,69 @@ another. This changelog entry contains a highlight # 0.1.21 (May 30, 2019) ### Changed -- Bump `tokio-trace-core` version to 0.2 (#1111). +- Bump `tokio-trace-core` version to 0.2 ([#1111]). # 0.1.20 (May 14, 2019) ### Added - `tokio::runtime::Builder::panic_handler` allows configuring handling - panics on the runtime (#1055). + panics on the runtime ([#1055]). # 0.1.19 (April 22, 2019) ### Added -- Re-export `tokio::sync::Mutex` primitive (#964). +- Re-export `tokio::sync::Mutex` primitive ([#964]). # 0.1.18 (March 22, 2019) ### Added -- `TypedExecutor` re-export and implementations (#993). +- `TypedExecutor` re-export and implementations ([#993]). # 0.1.17 (March 13, 2019) ### Added -- Propagate trace subscriber in the runtime (#966). +- Propagate trace subscriber in the runtime ([#966]). # 0.1.16 (March 1, 2019) ### Fixed -- async-await: track latest nightly changes (#940). +- async-await: track latest nightly changes ([#940]). ### Added -- `sync::Watch`, a single value broadcast channel (#922). -- Async equivalent of read / write file helpers being added to `std` (#896). +- `sync::Watch`, a single value broadcast channel ([#922]). +- Async equivalent of read / write file helpers being added to `std` ([#896]). # 0.1.15 (January 24, 2019) ### Added -- Re-export tokio-sync APIs (#839). -- Stream enumerate combinator (#832). +- Re-export tokio-sync APIs ([#839]). +- Stream enumerate combinator ([#832]). # 0.1.14 (January 6, 2019) * Use feature flags to break up the crate, allowing users to pick & choose - components (#808). -* Export `UnixDatagram` and `UnixDatagramFramed` (#772). + components ([#808]). +* Export `UnixDatagram` and `UnixDatagramFramed` ([#772]). # 0.1.13 (November 21, 2018) -* Fix `Runtime::reactor()` when no tasks are spawned (#721). -* `runtime::Builder` no longer uses deprecated methods (#749). +* Fix `Runtime::reactor()` when no tasks are spawned ([#721]). +* `runtime::Builder` no longer uses deprecated methods ([#749]). * Provide `after_start` and `before_stop` configuration settings for - `Runtime` (#756). -* Implement throttle stream combinator (#736). + `Runtime` ([#756]). +* Implement throttle stream combinator ([#736]). # 0.1.12 (October 23, 2018) -* runtime: expose `keep_alive` on runtime builder (#676). -* runtime: create a reactor per worker thread (#660). -* codec: fix panic in `LengthDelimitedCodec` (#682). -* io: re-export `tokio_io::io::read` function (#689). -* runtime: check for executor re-entry in more places (#708). +* runtime: expose `keep_alive` on runtime builder ([#676]). +* runtime: create a reactor per worker thread ([#660]). +* codec: fix panic in `LengthDelimitedCodec` ([#682]). +* io: re-export `tokio_io::io::read` function ([#689]). +* runtime: check for executor re-entry in more places ([#708]). # 0.1.11 (September 28, 2018) -* Fix `tokio-async-await` dependency (#675). +* Fix `tokio-async-await` dependency ([#675]). # 0.1.10 (September 27, 2018) @@ -334,65 +376,65 @@ another. This changelog entry contains a highlight # 0.1.9 (September 27, 2018) -* Experimental async/await improvements (#661). -* Re-export `TaskExecutor` from `tokio-current-thread` (#652). -* Improve `Runtime` builder API (#645). +* Experimental async/await improvements ([#661]). +* Re-export `TaskExecutor` from `tokio-current-thread` ([#652]). +* Improve `Runtime` builder API ([#645]). * `tokio::run` panics when called from the context of an executor - (#646). -* Introduce `StreamExt` with a `timeout` helper (#573). -* Move `length_delimited` into `tokio` (#575). -* Re-organize `tokio::net` module (#548). + ([#646]). +* Introduce `StreamExt` with a `timeout` helper ([#573]). +* Move `length_delimited` into `tokio` ([#575]). +* Re-organize `tokio::net` module ([#548]). * Re-export `tokio-current-thread::spawn` in current_thread runtime - (#579). + ([#579]). # 0.1.8 (August 23, 2018) -* Extract tokio::executor::current_thread to a sub crate (#370) -* Add `Runtime::block_on` (#398) -* Add `runtime::current_thread::block_on_all` (#477) -* Misc documentation improvements (#450) -* Implement `std::error::Error` for error types (#501) +* Extract tokio::executor::current_thread to a sub crate ([#370]) +* Add `Runtime::block_on` ([#398]) +* Add `runtime::current_thread::block_on_all` ([#477]) +* Misc documentation improvements ([#450]) +* Implement `std::error::Error` for error types ([#501]) # 0.1.7 (June 6, 2018) -* Add `Runtime::block_on` for concurrent runtime (#391). +* Add `Runtime::block_on` for concurrent runtime ([#391]). * Provide handle to `current_thread::Runtime` that allows spawning tasks from - other threads (#340). -* Provide `clock::now()`, a configurable source of time (#381). + other threads ([#340]). +* Provide `clock::now()`, a configurable source of time ([#381]). # 0.1.6 (May 2, 2018) -* Add asynchronous filesystem APIs (#323). -* Add "current thread" runtime variant (#308). +* Add asynchronous filesystem APIs ([#323]). +* Add "current thread" runtime variant ([#308]). * `CurrentThread`: Expose inner `Park` instance. -* Improve fairness of `CurrentThread` executor (#313). +* Improve fairness of `CurrentThread` executor ([#313]). # 0.1.5 (March 30, 2018) -* Provide timer API (#266) +* Provide timer API ([#266]) # 0.1.4 (March 22, 2018) -* Fix build on FreeBSD (#218) -* Shutdown the Runtime when the handle is dropped (#214) -* Set Runtime thread name prefix for worker threads (#232) -* Add builder for Runtime (#234) -* Extract TCP and UDP types into separate crates (#224) +* Fix build on FreeBSD ([#218]) +* Shutdown the Runtime when the handle is dropped ([#214]) +* Set Runtime thread name prefix for worker threads ([#232]) +* Add builder for Runtime ([#234]) +* Extract TCP and UDP types into separate crates ([#224]) * Optionally support futures 0.2. # 0.1.3 (March 09, 2018) -* Fix `CurrentThread::turn` to block on idle (#212). +* Fix `CurrentThread::turn` to block on idle ([#212]). # 0.1.2 (March 09, 2018) -* Introduce Tokio Runtime (#141) -* Provide `CurrentThread` for more flexible usage of current thread executor (#141). -* Add Lio for platforms that support it (#142). -* I/O resources now lazily bind to the reactor (#160). -* Extract Reactor to dedicated crate (#169) -* Add facade to sub crates and add prelude (#166). -* Switch TCP/UDP fns to poll_ -> Poll<...> style (#175) +* Introduce Tokio Runtime ([#141]) +* Provide `CurrentThread` for more flexible usage of current thread executor ([#141]). +* Add Lio for platforms that support it ([#142]). +* I/O resources now lazily bind to the reactor ([#160]). +* Extract Reactor to dedicated crate ([#169]) +* Add facade to sub crates and add prelude ([#166]). +* Switch TCP/UDP fns to poll_ -> Poll<...> style ([#175]) # 0.1.1 (February 09, 2018) @@ -401,3 +443,174 @@ another. This changelog entry contains a highlight # 0.1.0 (February 07, 2018) * Initial crate released based on [RFC](https://github.com/tokio-rs/tokio-rfcs/pull/3). + +[#2375]: https://github.com/tokio-rs/tokio/pull/2375 +[#2362]: https://github.com/tokio-rs/tokio/pull/2362 +[#2358]: https://github.com/tokio-rs/tokio/pull/2358 +[#2354]: https://github.com/tokio-rs/tokio/pull/2354 +[#2335]: https://github.com/tokio-rs/tokio/pull/2335 +[#2333]: https://github.com/tokio-rs/tokio/pull/2333 +[#2325]: https://github.com/tokio-rs/tokio/pull/2325 +[#2321]: https://github.com/tokio-rs/tokio/pull/2321 +[#2300]: https://github.com/tokio-rs/tokio/pull/2300 +[#2285]: https://github.com/tokio-rs/tokio/pull/2285 +[#2281]: https://github.com/tokio-rs/tokio/pull/2281 +[#2275]: https://github.com/tokio-rs/tokio/pull/2275 +[#2274]: https://github.com/tokio-rs/tokio/pull/2274 +[#2273]: https://github.com/tokio-rs/tokio/pull/2273 +[#2253]: https://github.com/tokio-rs/tokio/pull/2253 +[#2250]: https://github.com/tokio-rs/tokio/pull/2250 +[#2245]: https://github.com/tokio-rs/tokio/pull/2245 +[#2239]: https://github.com/tokio-rs/tokio/pull/2239 +[#2238]: https://github.com/tokio-rs/tokio/pull/2238 +[#2227]: https://github.com/tokio-rs/tokio/pull/2227 +[#2218]: https://github.com/tokio-rs/tokio/pull/2218 +[#2217]: https://github.com/tokio-rs/tokio/pull/2217 +[#2210]: https://github.com/tokio-rs/tokio/pull/2210 +[#2205]: https://github.com/tokio-rs/tokio/pull/2205 +[#2204]: https://github.com/tokio-rs/tokio/pull/2204 +[#2191]: https://github.com/tokio-rs/tokio/pull/2191 +[#2186]: https://github.com/tokio-rs/tokio/pull/2186 +[#2185]: https://github.com/tokio-rs/tokio/pull/2185 +[#2184]: https://github.com/tokio-rs/tokio/pull/2184 +[#2177]: https://github.com/tokio-rs/tokio/pull/2177 +[#2169]: https://github.com/tokio-rs/tokio/pull/2169 +[#2168]: https://github.com/tokio-rs/tokio/pull/2168 +[#2164]: https://github.com/tokio-rs/tokio/pull/2164 +[#2163]: https://github.com/tokio-rs/tokio/pull/2163 +[#2158]: https://github.com/tokio-rs/tokio/pull/2158 +[#2152]: https://github.com/tokio-rs/tokio/pull/2152 +[#2151]: https://github.com/tokio-rs/tokio/pull/2151 +[#2149]: https://github.com/tokio-rs/tokio/pull/2149 +[#2145]: https://github.com/tokio-rs/tokio/pull/2145 +[#2139]: https://github.com/tokio-rs/tokio/pull/2139 +[#2135]: https://github.com/tokio-rs/tokio/pull/2135 +[#2126]: https://github.com/tokio-rs/tokio/pull/2126 +[#2125]: https://github.com/tokio-rs/tokio/pull/2125 +[#2122]: https://github.com/tokio-rs/tokio/pull/2122 +[#2119]: https://github.com/tokio-rs/tokio/pull/2119 +[#2118]: https://github.com/tokio-rs/tokio/pull/2118 +[#2109]: https://github.com/tokio-rs/tokio/pull/2109 +[#2108]: https://github.com/tokio-rs/tokio/pull/2108 +[#2094]: https://github.com/tokio-rs/tokio/pull/2094 +[#2093]: https://github.com/tokio-rs/tokio/pull/2093 +[#2092]: https://github.com/tokio-rs/tokio/pull/2092 +[#2091]: https://github.com/tokio-rs/tokio/pull/2091 +[#2089]: https://github.com/tokio-rs/tokio/pull/2089 +[#2085]: https://github.com/tokio-rs/tokio/pull/2085 +[#2079]: https://github.com/tokio-rs/tokio/pull/2079 +[#2059]: https://github.com/tokio-rs/tokio/pull/2059 +[#2052]: https://github.com/tokio-rs/tokio/pull/2052 +[#2051]: https://github.com/tokio-rs/tokio/pull/2051 +[#2045]: https://github.com/tokio-rs/tokio/pull/2045 +[#2044]: https://github.com/tokio-rs/tokio/pull/2044 +[#2040]: https://github.com/tokio-rs/tokio/pull/2040 +[#2035]: https://github.com/tokio-rs/tokio/pull/2035 +[#2034]: https://github.com/tokio-rs/tokio/pull/2034 +[#2030]: https://github.com/tokio-rs/tokio/pull/2030 +[#2029]: https://github.com/tokio-rs/tokio/pull/2029 +[#2025]: https://github.com/tokio-rs/tokio/pull/2025 +[#2022]: https://github.com/tokio-rs/tokio/pull/2022 +[#2014]: https://github.com/tokio-rs/tokio/pull/2014 +[#2012]: https://github.com/tokio-rs/tokio/pull/2012 +[#2011]: https://github.com/tokio-rs/tokio/pull/2011 +[#2006]: https://github.com/tokio-rs/tokio/pull/2006 +[#2005]: https://github.com/tokio-rs/tokio/pull/2005 +[#2001]: https://github.com/tokio-rs/tokio/pull/2001 +[#1991]: https://github.com/tokio-rs/tokio/pull/1991 +[#1986]: https://github.com/tokio-rs/tokio/pull/1986 +[#1978]: https://github.com/tokio-rs/tokio/pull/1978 +[#1977]: https://github.com/tokio-rs/tokio/pull/1977 +[#1975]: https://github.com/tokio-rs/tokio/pull/1975 +[#1973]: https://github.com/tokio-rs/tokio/pull/1973 +[#1972]: https://github.com/tokio-rs/tokio/pull/1972 +[#1971]: https://github.com/tokio-rs/tokio/pull/1971 +[#1962]: https://github.com/tokio-rs/tokio/pull/1962 +[#1961]: https://github.com/tokio-rs/tokio/pull/1961 +[#1956]: https://github.com/tokio-rs/tokio/pull/1956 +[#1949]: https://github.com/tokio-rs/tokio/pull/1949 +[#1943]: https://github.com/tokio-rs/tokio/pull/1943 +[#1939]: https://github.com/tokio-rs/tokio/pull/1939 +[#1924]: https://github.com/tokio-rs/tokio/pull/1924 +[#1905]: https://github.com/tokio-rs/tokio/pull/1905 +[#1904]: https://github.com/tokio-rs/tokio/pull/1904 +[#1898]: https://github.com/tokio-rs/tokio/pull/1898 +[#1892]: https://github.com/tokio-rs/tokio/pull/1892 +[#1888]: https://github.com/tokio-rs/tokio/pull/1888 +[#1882]: https://github.com/tokio-rs/tokio/pull/1882 +[#1881]: https://github.com/tokio-rs/tokio/pull/1881 +[#1875]: https://github.com/tokio-rs/tokio/pull/1875 +[#1874]: https://github.com/tokio-rs/tokio/pull/1874 +[#1870]: https://github.com/tokio-rs/tokio/pull/1870 +[#1864]: https://github.com/tokio-rs/tokio/pull/1864 +[#1863]: https://github.com/tokio-rs/tokio/pull/1863 +[#1861]: https://github.com/tokio-rs/tokio/pull/1861 +[#1858]: https://github.com/tokio-rs/tokio/pull/1858 +[#1856]: https://github.com/tokio-rs/tokio/pull/1856 +[#1854]: https://github.com/tokio-rs/tokio/pull/1854 +[#1849]: https://github.com/tokio-rs/tokio/pull/1849 +[#1843]: https://github.com/tokio-rs/tokio/pull/1843 +[#1841]: https://github.com/tokio-rs/tokio/pull/1841 +[#1839]: https://github.com/tokio-rs/tokio/pull/1839 +[#1834]: https://github.com/tokio-rs/tokio/pull/1834 +[#1831]: https://github.com/tokio-rs/tokio/pull/1831 +[#1827]: https://github.com/tokio-rs/tokio/pull/1827 +[#1826]: https://github.com/tokio-rs/tokio/pull/1826 +[#1772]: https://github.com/tokio-rs/tokio/pull/1772 +[#1755]: https://github.com/tokio-rs/tokio/pull/1755 +[#1733]: https://github.com/tokio-rs/tokio/pull/1733 +[#1699]: https://github.com/tokio-rs/tokio/pull/1699 +[#1111]: https://github.com/tokio-rs/tokio/pull/1111 +[#1055]: https://github.com/tokio-rs/tokio/pull/1055 +[#993]: https://github.com/tokio-rs/tokio/pull/993 +[#966]: https://github.com/tokio-rs/tokio/pull/966 +[#964]: https://github.com/tokio-rs/tokio/pull/964 +[#940]: https://github.com/tokio-rs/tokio/pull/940 +[#922]: https://github.com/tokio-rs/tokio/pull/922 +[#896]: https://github.com/tokio-rs/tokio/pull/896 +[#839]: https://github.com/tokio-rs/tokio/pull/839 +[#832]: https://github.com/tokio-rs/tokio/pull/832 +[#808]: https://github.com/tokio-rs/tokio/pull/808 +[#772]: https://github.com/tokio-rs/tokio/pull/772 +[#756]: https://github.com/tokio-rs/tokio/pull/756 +[#749]: https://github.com/tokio-rs/tokio/pull/749 +[#736]: https://github.com/tokio-rs/tokio/pull/736 +[#721]: https://github.com/tokio-rs/tokio/pull/721 +[#708]: https://github.com/tokio-rs/tokio/pull/708 +[#689]: https://github.com/tokio-rs/tokio/pull/689 +[#682]: https://github.com/tokio-rs/tokio/pull/682 +[#676]: https://github.com/tokio-rs/tokio/pull/676 +[#675]: https://github.com/tokio-rs/tokio/pull/675 +[#661]: https://github.com/tokio-rs/tokio/pull/661 +[#660]: https://github.com/tokio-rs/tokio/pull/660 +[#652]: https://github.com/tokio-rs/tokio/pull/652 +[#646]: https://github.com/tokio-rs/tokio/pull/646 +[#645]: https://github.com/tokio-rs/tokio/pull/645 +[#579]: https://github.com/tokio-rs/tokio/pull/579 +[#575]: https://github.com/tokio-rs/tokio/pull/575 +[#573]: https://github.com/tokio-rs/tokio/pull/573 +[#548]: https://github.com/tokio-rs/tokio/pull/548 +[#501]: https://github.com/tokio-rs/tokio/pull/501 +[#477]: https://github.com/tokio-rs/tokio/pull/477 +[#450]: https://github.com/tokio-rs/tokio/pull/450 +[#398]: https://github.com/tokio-rs/tokio/pull/398 +[#391]: https://github.com/tokio-rs/tokio/pull/391 +[#381]: https://github.com/tokio-rs/tokio/pull/381 +[#370]: https://github.com/tokio-rs/tokio/pull/370 +[#340]: https://github.com/tokio-rs/tokio/pull/340 +[#323]: https://github.com/tokio-rs/tokio/pull/323 +[#313]: https://github.com/tokio-rs/tokio/pull/313 +[#308]: https://github.com/tokio-rs/tokio/pull/308 +[#266]: https://github.com/tokio-rs/tokio/pull/266 +[#234]: https://github.com/tokio-rs/tokio/pull/234 +[#232]: https://github.com/tokio-rs/tokio/pull/232 +[#224]: https://github.com/tokio-rs/tokio/pull/224 +[#218]: https://github.com/tokio-rs/tokio/pull/218 +[#214]: https://github.com/tokio-rs/tokio/pull/214 +[#212]: https://github.com/tokio-rs/tokio/pull/212 +[#175]: https://github.com/tokio-rs/tokio/pull/175 +[#169]: https://github.com/tokio-rs/tokio/pull/169 +[#166]: https://github.com/tokio-rs/tokio/pull/166 +[#160]: https://github.com/tokio-rs/tokio/pull/160 +[#142]: https://github.com/tokio-rs/tokio/pull/142 +[#141]: https://github.com/tokio-rs/tokio/pull/141 diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index 3d075bbc2d5..969fbcf01c5 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -8,12 +8,12 @@ name = "tokio" # - README.md # - Update CHANGELOG.md. # - Create "v0.2.x" git tag. -version = "0.2.17" +version = "0.2.21" edition = "2018" authors = ["Tokio Contributors "] license = "MIT" readme = "README.md" -documentation = "https://docs.rs/tokio/0.2.17/tokio/" +documentation = "https://docs.rs/tokio/0.2.21/tokio/" repository = "https://github.com/tokio-rs/tokio" homepage = "https://tokio.rs" description = """ @@ -106,6 +106,7 @@ iovec = { version = "0.1.4", optional = true } num_cpus = { version = "1.8.0", optional = true } parking_lot = { version = "0.10.0", optional = true } # Not in full slab = { version = "0.4.1", optional = true } # Backs `DelayQueue` +tracing = { version = "0.1.16", default-features = false, features = ["std"], optional = true } # Not in full [target.'cfg(unix)'.dependencies] mio-uds = { version = "0.6.5", optional = true } @@ -121,15 +122,14 @@ default-features = false optional = true [dev-dependencies] -tokio-test = { version = "0.2.0" } +tokio-test = { version = "0.2.0", path = "../tokio-test" } futures = { version = "0.3.0", features = ["async-await"] } +futures-test = "0.3.0" proptest = "0.9.4" tempfile = "3.1.0" -# loom is currently not compiling on windows. -# See: https://github.com/Xudong-Huang/generator-rs/issues/19 -[target.'cfg(not(windows))'.dev-dependencies] -loom = { version = "0.3.1", features = ["futures", "checkpoint"] } +[target.'cfg(loom)'.dev-dependencies] +loom = { version = "0.3.4", features = ["futures", "checkpoint"] } [package.metadata.docs.rs] all-features = true diff --git a/tokio/README.md b/tokio/README.md index 080181f8975..31b2ae12a6c 100644 --- a/tokio/README.md +++ b/tokio/README.md @@ -20,16 +20,17 @@ the Rust programming language. It is: [crates-badge]: https://img.shields.io/crates/v/tokio.svg [crates-url]: https://crates.io/crates/tokio [mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg -[mit-url]: LICENSE +[mit-url]: https://github.com/tokio-rs/tokio/blob/master/LICENSE [azure-badge]: https://dev.azure.com/tokio-rs/Tokio/_apis/build/status/tokio-rs.tokio?branchName=master [azure-url]: https://dev.azure.com/tokio-rs/Tokio/_build/latest?definitionId=1&branchName=master [discord-badge]: https://img.shields.io/discord/500028886025895936.svg?logo=discord&style=flat-square -[discord-url]: https://discord.gg/6yGkFeN +[discord-url]: https://discord.gg/tokio [Website](https://tokio.rs) | -[Guides](https://tokio.rs/docs/) | -[API Docs](https://docs.rs/tokio/0.2/tokio) | -[Chat](https://discord.gg/6yGkFeN) +[Guides](https://tokio.rs/docs/overview/) | +[API Docs](https://docs.rs/tokio/latest/tokio) | +[Roadmap](https://github.com/tokio-rs/tokio/blob/master/ROADMAP.md) | +[Chat](https://discord.gg/tokio) ## Overview @@ -45,20 +46,11 @@ level, it provides a few major components: These components provide the runtime components necessary for building an asynchronous application. -[net]: https://docs.rs/tokio/0.2/tokio/net/index.html -[scheduler]: https://docs.rs/tokio/0.2/tokio/runtime/index.html +[net]: https://docs.rs/tokio/latest/tokio/net/index.html +[scheduler]: https://docs.rs/tokio/latest/tokio/runtime/index.html ## Example -To get started, add the following to `Cargo.toml`. - -```toml -tokio = { version = "0.2", features = ["full"] } -``` - -Tokio requires components to be explicitly enabled using feature flags. As a -shorthand, the `full` feature enables all components. - A basic TCP echo server with Tokio: ```rust,no_run @@ -98,19 +90,23 @@ async fn main() -> Result<(), Box> { } ``` -More examples can be found [here](../examples). +More examples can be found [here][examples]. For a larger "real world" example, see the +[mini-redis] repository. + +[examples]: https://github.com/tokio-rs/tokio/tree/master/examples +[mini-redis]: https://github.com/tokio-rs/mini-redis/ ## Getting Help First, see if the answer to your question can be found in the [Guides] or the [API documentation]. If the answer is not there, there is an active community in the [Tokio Discord server][chat]. We would be happy to try to answer your -question. Last, if that doesn't work, try opening an [issue] with the question. +question. You can also ask your question on [the discussions page][discussions]. -[Guides]: https://tokio.rs/docs/ -[API documentation]: https://docs.rs/tokio/0.2 -[chat]: https://discord.gg/6yGkFeN -[issue]: https://github.com/tokio-rs/tokio/issues/new +[Guides]: https://tokio.rs/docs/overview/ +[API documentation]: https://docs.rs/tokio/latest/tokio +[chat]: https://discord.gg/tokio +[discussions]: https://github.com/tokio-rs/tokio/discussions ## Contributing @@ -118,36 +114,54 @@ question. Last, if that doesn't work, try opening an [issue] with the question. you! We have a [contributing guide][guide] to help you get involved in the Tokio project. -[guide]: CONTRIBUTING.md +[guide]: https://github.com/tokio-rs/tokio/blob/master/CONTRIBUTING.md ## Related Projects In addition to the crates in this repository, the Tokio project also maintains several other libraries, including: +* [`hyper`]: A fast and correct HTTP/1.1 and HTTP/2 implementation for Rust. + +* [`tonic`]: A gRPC over HTTP/2 implementation focused on high performance, interoperability, and flexibility. + +* [`warp`]: A super-easy, composable, web server framework for warp speeds. + +* [`tower`]: A library of modular and reusable components for building robust networking clients and servers. + * [`tracing`] (formerly `tokio-trace`): A framework for application-level tracing and async-aware diagnostics. +* [`rdbc`]: A Rust database connectivity library for MySQL, Postgres and SQLite. + * [`mio`]: A low-level, cross-platform abstraction over OS I/O APIs that powers `tokio`. * [`bytes`]: Utilities for working with bytes, including efficient byte buffers. +* [`loom`]: A testing tool for concurrent Rust code + +[`warp`]: https://github.com/seanmonstar/warp +[`hyper`]: https://github.com/hyperium/hyper +[`tonic`]: https://github.com/hyperium/tonic +[`tower`]: https://github.com/tower-rs/tower +[`loom`]: https://github.com/tokio-rs/loom +[`rdbc`]: https://github.com/tokio-rs/rdbc [`tracing`]: https://github.com/tokio-rs/tracing [`mio`]: https://github.com/tokio-rs/mio [`bytes`]: https://github.com/tokio-rs/bytes ## Supported Rust Versions -Tokio is built against the latest stable, nightly, and beta Rust releases. The -minimum version supported is the stable release from three months before the -current stable release version. For example, if the latest stable Rust is 1.29, -the minimum version supported is 1.26. The current Tokio version is not -guaranteed to build on Rust versions earlier than the minimum supported version. +Tokio is built against the latest stable release. The minimum supported version is 1.39. +The current Tokio version is not guaranteed to build on Rust versions earlier than the +minimum supported version. ## License -This project is licensed under the [MIT license](LICENSE). +This project is licensed under the [MIT license]. + +[MIT license]: https://github.com/tokio-rs/tokio/blob/master/LICENSE ### Contribution diff --git a/tokio/src/coop.rs b/tokio/src/coop.rs index 78905e3ccfc..27e969c59d4 100644 --- a/tokio/src/coop.rs +++ b/tokio/src/coop.rs @@ -1,11 +1,12 @@ //! Opt-in yield points for improved cooperative scheduling. //! -//! A single call to [`poll`] on a top-level task may potentially do a lot of work before it -//! returns `Poll::Pending`. If a task runs for a long period of time without yielding back to the -//! executor, it can starve other tasks waiting on that executor to execute them, or drive -//! underlying resources. Since Rust does not have a runtime, it is difficult to forcibly preempt a -//! long-running task. Instead, this module provides an opt-in mechanism for futures to collaborate -//! with the executor to avoid starvation. +//! A single call to [`poll`] on a top-level task may potentially do a lot of +//! work before it returns `Poll::Pending`. If a task runs for a long period of +//! time without yielding back to the executor, it can starve other tasks +//! waiting on that executor to execute them, or drive underlying resources. +//! Since Rust does not have a runtime, it is difficult to forcibly preempt a +//! long-running task. Instead, this module provides an opt-in mechanism for +//! futures to collaborate with the executor to avoid starvation. //! //! Consider a future like this one: //! @@ -16,9 +17,10 @@ //! } //! ``` //! -//! It may look harmless, but consider what happens under heavy load if the input stream is -//! _always_ ready. If we spawn `drop_all`, the task will never yield, and will starve other tasks -//! and resources on the same executor. With opt-in yield points, this problem is alleviated: +//! It may look harmless, but consider what happens under heavy load if the +//! input stream is _always_ ready. If we spawn `drop_all`, the task will never +//! yield, and will starve other tasks and resources on the same executor. With +//! opt-in yield points, this problem is alleviated: //! //! ```ignore //! # use tokio::stream::{Stream, StreamExt}; @@ -29,71 +31,99 @@ //! } //! ``` //! -//! The `proceed` future will coordinate with the executor to make sure that every so often control -//! is yielded back to the executor so it can run other tasks. +//! The `proceed` future will coordinate with the executor to make sure that +//! every so often control is yielded back to the executor so it can run other +//! tasks. //! //! # Placing yield points //! -//! Voluntary yield points should be placed _after_ at least some work has been done. If they are -//! not, a future sufficiently deep in the task hierarchy may end up _never_ getting to run because -//! of the number of yield points that inevitably appear before it is reached. In general, you will -//! want yield points to only appear in "leaf" futures -- those that do not themselves poll other -//! futures. By doing this, you avoid double-counting each iteration of the outer future against -//! the cooperating budget. +//! Voluntary yield points should be placed _after_ at least some work has been +//! done. If they are not, a future sufficiently deep in the task hierarchy may +//! end up _never_ getting to run because of the number of yield points that +//! inevitably appear before it is reached. In general, you will want yield +//! points to only appear in "leaf" futures -- those that do not themselves poll +//! other futures. By doing this, you avoid double-counting each iteration of +//! the outer future against the cooperating budget. //! -//! [`poll`]: https://doc.rust-lang.org/std/future/trait.Future.html#tymethod.poll +//! [`poll`]: method@std::future::Future::poll // NOTE: The doctests in this module are ignored since the whole module is (currently) private. use std::cell::Cell; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; - -/// Constant used to determine how much "work" a task is allowed to do without yielding. -/// -/// The value itself is chosen somewhat arbitrarily. It needs to be high enough to amortize wakeup -/// and scheduling costs, but low enough that we do not starve other tasks for too long. The value -/// also needs to be high enough that particularly deep tasks are able to do at least some useful -/// work at all. -/// -/// Note that as more yield points are added in the ecosystem, this value will probably also have -/// to be raised. -const BUDGET: usize = 128; - -/// Constant used to determine if budgeting has been disabled. -const UNCONSTRAINED: usize = usize::max_value(); thread_local! { - static HITS: Cell = Cell::new(UNCONSTRAINED); + static CURRENT: Cell = Cell::new(Budget::unconstrained()); } -/// Run the given closure with a cooperative task budget. -/// -/// Enabling budgeting when it is already enabled is a no-op. -#[inline(always)] -pub(crate) fn budget(f: F) -> R -where - F: FnOnce() -> R, -{ - HITS.with(move |hits| { - if hits.get() != UNCONSTRAINED { - // We are already being budgeted. - // - // Arguably this should be an error, but it can happen "correctly" - // such as with block_on + LocalSet, so we make it a no-op. - return f(); +/// Opaque type tracking the amount of "work" a task may still do before +/// yielding back to the scheduler. +#[derive(Debug, Copy, Clone)] +pub(crate) struct Budget(Option); + +impl Budget { + /// Budget assigned to a task on each poll. + /// + /// The value itself is chosen somewhat arbitrarily. It needs to be high + /// enough to amortize wakeup and scheduling costs, but low enough that we + /// do not starve other tasks for too long. The value also needs to be high + /// enough that particularly deep tasks are able to do at least some useful + /// work at all. + /// + /// Note that as more yield points are added in the ecosystem, this value + /// will probably also have to be raised. + const fn initial() -> Budget { + Budget(Some(128)) + } + + /// Returns an unconstrained budget. Operations will not be limited. + const fn unconstrained() -> Budget { + Budget(None) + } +} + +cfg_rt_threaded! { + impl Budget { + fn has_remaining(self) -> bool { + self.0.map(|budget| budget > 0).unwrap_or(true) } + } +} - struct Guard<'a>(&'a Cell); - impl<'a> Drop for Guard<'a> { - fn drop(&mut self) { - self.0.set(UNCONSTRAINED); - } +/// Run the given closure with a cooperative task budget. When the function +/// returns, the budget is reset to the value prior to calling the function. +#[inline(always)] +pub(crate) fn budget(f: impl FnOnce() -> R) -> R { + with_budget(Budget::initial(), f) +} + +cfg_rt_threaded! { + /// Set the current task's budget + #[cfg(feature = "blocking")] + pub(crate) fn set(budget: Budget) { + CURRENT.with(|cell| cell.set(budget)) + } +} + +#[inline(always)] +fn with_budget(budget: Budget, f: impl FnOnce() -> R) -> R { + struct ResetGuard<'a> { + cell: &'a Cell, + prev: Budget, + } + + impl<'a> Drop for ResetGuard<'a> { + fn drop(&mut self) { + self.cell.set(self.prev); } + } + + CURRENT.with(move |cell| { + let prev = cell.get(); + + cell.set(budget); + + let _guard = ResetGuard { cell, prev }; - hits.set(BUDGET); - let _guard = Guard(hits); f() }) } @@ -101,279 +131,171 @@ where cfg_rt_threaded! { #[inline(always)] pub(crate) fn has_budget_remaining() -> bool { - HITS.with(|hits| hits.get() > 0) + CURRENT.with(|cell| cell.get().has_remaining()) } } cfg_blocking_impl! { /// Forcibly remove the budgeting constraints early. - pub(crate) fn stop() { - HITS.with(|hits| { - hits.set(UNCONSTRAINED); - }); + /// + /// Returns the remaining budget + pub(crate) fn stop() -> Budget { + CURRENT.with(|cell| { + let prev = cell.get(); + cell.set(Budget::unconstrained()); + prev + }) } } -/// Invoke `f` with a subset of the remaining budget. -/// -/// This is useful if you have sub-futures that you need to poll, but that you want to restrict -/// from using up your entire budget. For example, imagine the following future: -/// -/// ```rust -/// # use std::{future::Future, pin::Pin, task::{Context, Poll}}; -/// use futures::stream::FuturesUnordered; -/// struct MyFuture { -/// big: FuturesUnordered, -/// small: F2, -/// } -/// -/// use tokio::stream::Stream; -/// impl Future for MyFuture -/// where F1: Future, F2: Future -/// # , F1: Unpin, F2: Unpin -/// { -/// type Output = F2::Output; -/// -/// // fn poll(...) -/// # fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { -/// # let this = &mut *self; -/// let mut big = // something to pin self.big -/// # Pin::new(&mut this.big); -/// let small = // something to pin self.small -/// # Pin::new(&mut this.small); -/// -/// // see if any of the big futures have finished -/// while let Some(e) = futures::ready!(big.as_mut().poll_next(cx)) { -/// // do something with e -/// # let _ = e; -/// } -/// -/// // see if the small future has finished -/// small.poll(cx) -/// } -/// # } -/// ``` -/// -/// It could be that every time `poll` gets called, `big` ends up spending the entire budget, and -/// `small` never gets polled. That would be sad. If you want to stick up for the little future, -/// that's what `limit` is for. It lets you portion out a smaller part of the yield budget to a -/// particular segment of your code. In the code above, you would write -/// -/// ```rust,ignore -/// # use std::{future::Future, pin::Pin, task::{Context, Poll}}; -/// # use futures::stream::FuturesUnordered; -/// # struct MyFuture { -/// # big: FuturesUnordered, -/// # small: F2, -/// # } -/// # -/// # use tokio::stream::Stream; -/// # impl Future for MyFuture -/// # where F1: Future, F2: Future -/// # , F1: Unpin, F2: Unpin -/// # { -/// # type Output = F2::Output; -/// # fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { -/// # let this = &mut *self; -/// # let mut big = Pin::new(&mut this.big); -/// # let small = Pin::new(&mut this.small); -/// # -/// // see if any of the big futures have finished -/// while let Some(e) = futures::ready!(tokio::coop::limit(64, || big.as_mut().poll_next(cx))) { -/// # // do something with e -/// # let _ = e; -/// # } -/// # small.poll(cx) -/// # } -/// # } -/// ``` -/// -/// Now, even if `big` spends its entire budget, `small` will likely be left with some budget left -/// to also do useful work. In particular, if the remaining budget was `N` at the start of `poll`, -/// `small` will have at least a budget of `N - 64`. It may be more if `big` did not spend its -/// entire budget. -/// -/// Note that you cannot _increase_ your budget by calling `limit`. The budget provided to the code -/// inside the buget is the _minimum_ of the _current_ budget and the bound. -/// -#[allow(unreachable_pub, dead_code)] -pub fn limit(bound: usize, f: impl FnOnce() -> R) -> R { - HITS.with(|hits| { - let budget = hits.get(); - // with_bound cannot _increase_ the remaining budget - let bound = std::cmp::min(budget, bound); - // When f() exits, how much should we add to what is left? - let floor = budget.saturating_sub(bound); - // Make sure we restore the remaining budget even on panic - struct RestoreBudget<'a>(&'a Cell, usize); - impl<'a> Drop for RestoreBudget<'a> { - fn drop(&mut self) { - let left = self.0.get(); - self.0.set(self.1 + left); - } - } - // Time to restrict! - hits.set(bound); - let _restore = RestoreBudget(&hits, floor); - f() - }) -} +cfg_coop! { + use std::task::{Context, Poll}; + + #[must_use] + pub(crate) struct RestoreOnPending(Cell); -/// Returns `Poll::Pending` if the current task has exceeded its budget and should yield. -#[allow(unreachable_pub, dead_code)] -#[inline] -pub fn poll_proceed(cx: &mut Context<'_>) -> Poll<()> { - HITS.with(|hits| { - let n = hits.get(); - if n == UNCONSTRAINED { - // opted out of budgeting - Poll::Ready(()) - } else if n == 0 { - cx.waker().wake_by_ref(); - Poll::Pending - } else { - hits.set(n.saturating_sub(1)); - Poll::Ready(()) + impl RestoreOnPending { + pub(crate) fn made_progress(&self) { + self.0.set(Budget::unconstrained()); } - }) -} + } -/// Resolves immediately unless the current task has already exceeded its budget. -/// -/// This should be placed after at least some work has been done. Otherwise a future sufficiently -/// deep in the task hierarchy may end up never getting to run because of the number of yield -/// points that inevitably appear before it is even reached. For example: -/// -/// ```ignore -/// # use tokio::stream::{Stream, StreamExt}; -/// async fn drop_all(mut input: I) { -/// while let Some(_) = input.next().await { -/// tokio::coop::proceed().await; -/// } -/// } -/// ``` -#[allow(unreachable_pub, dead_code)] -#[inline] -pub async fn proceed() { - use crate::future::poll_fn; - poll_fn(|cx| poll_proceed(cx)).await; -} + impl Drop for RestoreOnPending { + fn drop(&mut self) { + // Don't reset if budget was unconstrained or if we made progress. + // They are both represented as the remembered budget being unconstrained. + let budget = self.0.get(); + if !budget.is_unconstrained() { + CURRENT.with(|cell| { + cell.set(budget); + }); + } + } + } -pin_project_lite::pin_project! { - /// A future that cooperatively yields to the task scheduler when polling, - /// if the task's budget is exhausted. - /// - /// Internally, this is simply a future combinator which calls - /// [`poll_proceed`] in its `poll` implementation before polling the wrapped - /// future. - /// - /// # Examples + /// Returns `Poll::Pending` if the current task has exceeded its budget and should yield. /// - /// ```rust,ignore - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::coop::CoopFutureExt; + /// When you call this method, the current budget is decremented. However, to ensure that + /// progress is made every time a task is polled, the budget is automatically restored to its + /// former value if the returned `RestoreOnPending` is dropped. It is the caller's + /// responsibility to call `RestoreOnPending::made_progress` if it made progress, to ensure + /// that the budget empties appropriately. /// - /// async { /* ... */ } - /// .cooperate() - /// .await; - /// # } - /// ``` - /// - /// [`poll_proceed`]: fn.poll_proceed.html - #[derive(Debug)] - #[allow(unreachable_pub, dead_code)] - pub struct CoopFuture { - #[pin] - future: F, - } -} + /// Note that `RestoreOnPending` restores the budget **as it was before `poll_proceed`**. + /// Therefore, if the budget is _further_ adjusted between when `poll_proceed` returns and + /// `RestRestoreOnPending` is dropped, those adjustments are erased unless the caller indicates + /// that progress was made. + #[inline] + pub(crate) fn poll_proceed(cx: &mut Context<'_>) -> Poll { + CURRENT.with(|cell| { + let mut budget = cell.get(); -impl Future for CoopFuture { - type Output = F::Output; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - ready!(poll_proceed(cx)); - self.project().future.poll(cx) + if budget.decrement() { + let restore = RestoreOnPending(Cell::new(cell.get())); + cell.set(budget); + Poll::Ready(restore) + } else { + cx.waker().wake_by_ref(); + Poll::Pending + } + }) } -} -impl CoopFuture { - /// Returns a new `CoopFuture` wrapping the given future. - /// - #[allow(unreachable_pub, dead_code)] - pub fn new(future: F) -> Self { - Self { future } - } -} + impl Budget { + /// Decrement the budget. Returns `true` if successful. Decrementing fails + /// when there is not enough remaining budget. + fn decrement(&mut self) -> bool { + if let Some(num) = &mut self.0 { + if *num > 0 { + *num -= 1; + true + } else { + false + } + } else { + true + } + } -// Currently only used by `tokio::sync`; and if we make this combinator public, -// it should probably be on the `FutureExt` trait instead. -cfg_sync! { - /// Extension trait providing `Future::cooperate` extension method. - /// - /// Note: if/when the co-op API becomes public, this method should probably be - /// provided by `FutureExt`, instead. - pub(crate) trait CoopFutureExt: Future { - /// Wrap `self` to cooperatively yield to the scheduler when polling, if the - /// task's budget is exhausted. - fn cooperate(self) -> CoopFuture - where - Self: Sized, - { - CoopFuture::new(self) + fn is_unconstrained(self) -> bool { + self.0.is_none() } } - - impl CoopFutureExt for F where F: Future {} } #[cfg(all(test, not(loom)))] mod test { use super::*; - fn get() -> usize { - HITS.with(|hits| hits.get()) + fn get() -> Budget { + CURRENT.with(|cell| cell.get()) } #[test] fn bugeting() { + use futures::future::poll_fn; use tokio_test::*; - assert_eq!(get(), UNCONSTRAINED); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), UNCONSTRAINED); + assert!(get().0.is_none()); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + + assert!(get().0.is_none()); + drop(coop); + assert!(get().0.is_none()); + budget(|| { - assert_eq!(get(), BUDGET); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), BUDGET - 1); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), BUDGET - 2); + assert_eq!(get().0, Budget::initial().0); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + drop(coop); + // we didn't make progress + assert_eq!(get().0, Budget::initial().0); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + coop.made_progress(); + drop(coop); + // we _did_ make progress + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2); + coop.made_progress(); + drop(coop); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2); + + budget(|| { + assert_eq!(get().0, Budget::initial().0); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + coop.made_progress(); + drop(coop); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + }); + + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2); }); - assert_eq!(get(), UNCONSTRAINED); + + assert!(get().0.is_none()); budget(|| { - limit(3, || { - assert_eq!(get(), 3); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), 2); - limit(4, || { - assert_eq!(get(), 2); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), 1); - }); - assert_eq!(get(), 1); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), 0); - assert_pending!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), 0); - assert_pending!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), 0); - }); - assert_eq!(get(), BUDGET - 3); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), BUDGET - 4); - assert_ready!(task::spawn(proceed()).poll()); - assert_eq!(get(), BUDGET - 5); + let n = get().0.unwrap(); + + for _ in 0..n { + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + coop.made_progress(); + } + + let mut task = task::spawn(poll_fn(|cx| { + let coop = ready!(poll_proceed(cx)); + coop.made_progress(); + Poll::Ready(()) + })); + + assert_pending!(task.poll()); }); } } diff --git a/tokio/src/fs/dir_builder.rs b/tokio/src/fs/dir_builder.rs new file mode 100644 index 00000000000..8752a3716aa --- /dev/null +++ b/tokio/src/fs/dir_builder.rs @@ -0,0 +1,117 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// A builder for creating directories in various manners. +/// +/// Additional Unix-specific options are available via importing the +/// [`DirBuilderExt`] trait. +/// +/// This is a specialized version of [`std::fs::DirBuilder`] for usage on +/// the Tokio runtime. +/// +/// [std::fs::DirBuilder]: std::fs::DirBuilder +/// [`DirBuilderExt`]: crate::fs::os::unix::DirBuilderExt +#[derive(Debug, Default)] +pub struct DirBuilder { + /// Indicates whether to create parent directories if they are missing. + recursive: bool, + + /// Set the Unix mode for newly created directories. + #[cfg(unix)] + pub(super) mode: Option, +} + +impl DirBuilder { + /// Creates a new set of options with default mode/security settings for all + /// platforms and also non-recursive. + /// + /// This is an async version of [`std::fs::DirBuilder::new`][std] + /// + /// [std]: std::fs::DirBuilder::new + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// + /// let builder = DirBuilder::new(); + /// ``` + pub fn new() -> Self { + Default::default() + } + + /// Indicates whether to create directories recursively (including all parent directories). + /// Parents that do not exist are created with the same security and permissions settings. + /// + /// This option defaults to `false`. + /// + /// This is an async version of [`std::fs::DirBuilder::recursive`][std] + /// + /// [std]: std::fs::DirBuilder::recursive + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// + /// let mut builder = DirBuilder::new(); + /// builder.recursive(true); + /// ``` + pub fn recursive(&mut self, recursive: bool) -> &mut Self { + self.recursive = recursive; + self + } + + /// Creates the specified directory with the configured options. + /// + /// It is considered an error if the directory already exists unless + /// recursive mode is enabled. + /// + /// This is an async version of [`std::fs::DirBuilder::create`][std] + /// + /// [std]: std::fs::DirBuilder::create + /// + /// # Errors + /// + /// An error will be returned under the following circumstances: + /// + /// * Path already points to an existing file. + /// * Path already points to an existing directory and the mode is + /// non-recursive. + /// * The calling process doesn't have permissions to create the directory + /// or its missing parents. + /// * Other I/O error occurred. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// DirBuilder::new() + /// .recursive(true) + /// .create("/tmp/foo/bar/baz") + /// .await?; + /// + /// Ok(()) + /// } + /// ``` + pub async fn create>(&self, path: P) -> io::Result<()> { + let path = path.as_ref().to_owned(); + let mut builder = std::fs::DirBuilder::new(); + builder.recursive(self.recursive); + + #[cfg(unix)] + { + if let Some(mode) = self.mode { + std::os::unix::fs::DirBuilderExt::mode(&mut builder, mode); + } + } + + asyncify(move || builder.create(path)).await + } +} diff --git a/tokio/src/fs/file.rs b/tokio/src/fs/file.rs index a1f22fc9b6c..f3bc98546a9 100644 --- a/tokio/src/fs/file.rs +++ b/tokio/src/fs/file.rs @@ -24,12 +24,28 @@ use std::task::Poll::*; /// Tokio runtime. /// /// An instance of a `File` can be read and/or written depending on what options -/// it was opened with. Files also implement Seek to alter the logical cursor -/// that the file contains internally. +/// it was opened with. Files also implement [`AsyncSeek`] to alter the logical +/// cursor that the file contains internally. /// -/// Files are automatically closed when they go out of scope. +/// A file will not be closed immediately when it goes out of scope if there +/// are any IO operations that have not yet completed. To ensure that a file is +/// closed immediately when it is dropped, you should call [`flush`] before +/// dropping it. Note that this does not ensure that the file has been fully +/// written to disk; the operating system might keep the changes around in an +/// in-memory buffer. See the [`sync_all`] method for telling the OS to write +/// the data to disk. /// -/// [std]: std::fs::File +/// Reading and writing to a `File` is usually done using the convenience +/// methods found on the [`AsyncReadExt`] and [`AsyncWriteExt`] traits. Examples +/// import these traits through [the prelude]. +/// +/// [std]: struct@std::fs::File +/// [`AsyncSeek`]: trait@crate::io::AsyncSeek +/// [`flush`]: fn@crate::io::AsyncWriteExt::flush +/// [`sync_all`]: fn@crate::fs::File::sync_all +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +/// [the prelude]: crate::prelude /// /// # Examples /// @@ -37,7 +53,7 @@ use std::task::Poll::*; /// /// ```no_run /// use tokio::fs::File; -/// use tokio::prelude::*; +/// use tokio::prelude::*; // for write_all() /// /// # async fn dox() -> std::io::Result<()> { /// let mut file = File::create("foo.txt").await?; @@ -50,7 +66,7 @@ use std::task::Poll::*; /// /// ```no_run /// use tokio::fs::File; -/// use tokio::prelude::*; +/// use tokio::prelude::*; // for read_to_end() /// /// # async fn dox() -> std::io::Result<()> { /// let mut file = File::open("foo.txt").await?; @@ -114,6 +130,11 @@ impl File { /// # Ok(()) /// # } /// ``` + /// + /// The [`read_to_end`] method is defined on the [`AsyncReadExt`] trait. + /// + /// [`read_to_end`]: fn@crate::io::AsyncReadExt::read_to_end + /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt pub async fn open(path: impl AsRef) -> io::Result { let path = path.as_ref().to_owned(); let std = asyncify(|| sys::File::open(path)).await?; @@ -149,6 +170,11 @@ impl File { /// # Ok(()) /// # } /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt pub async fn create(path: impl AsRef) -> io::Result { let path = path.as_ref().to_owned(); let std_file = asyncify(move || sys::File::create(path)).await?; @@ -195,6 +221,11 @@ impl File { /// # Ok(()) /// # } /// ``` + /// + /// The [`read_exact`] method is defined on the [`AsyncReadExt`] trait. + /// + /// [`read_exact`]: fn@crate::io::AsyncReadExt::read_exact + /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt pub async fn seek(&mut self, mut pos: SeekFrom) -> io::Result { self.complete_inflight().await; @@ -251,6 +282,11 @@ impl File { /// # Ok(()) /// # } /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt pub async fn sync_all(&mut self) -> io::Result<()> { self.complete_inflight().await; @@ -280,6 +316,11 @@ impl File { /// # Ok(()) /// # } /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt pub async fn sync_data(&mut self) -> io::Result<()> { self.complete_inflight().await; @@ -312,6 +353,11 @@ impl File { /// # Ok(()) /// # } /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt pub async fn set_len(&mut self, size: u64) -> io::Result<()> { self.complete_inflight().await; @@ -491,6 +537,11 @@ impl File { } impl AsyncRead for File { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/fs.rs#L668 + false + } + fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/tokio/src/fs/mod.rs b/tokio/src/fs/mod.rs index 3eb03764631..a2b062b1a30 100644 --- a/tokio/src/fs/mod.rs +++ b/tokio/src/fs/mod.rs @@ -33,6 +33,9 @@ pub use self::create_dir::create_dir; mod create_dir_all; pub use self::create_dir_all::create_dir_all; +mod dir_builder; +pub use self::dir_builder::DirBuilder; + mod file; pub use self::file::File; diff --git a/tokio/src/fs/open_options.rs b/tokio/src/fs/open_options.rs index 3210f4b7b5d..ba3d9a6cf67 100644 --- a/tokio/src/fs/open_options.rs +++ b/tokio/src/fs/open_options.rs @@ -382,6 +382,12 @@ impl OpenOptions { let std = asyncify(move || opts.open(path)).await?; Ok(File::from_std(std)) } + + /// Returns a mutable reference to the the underlying std::fs::OpenOptions + #[cfg(unix)] + pub(super) fn as_inner_mut(&mut self) -> &mut std::fs::OpenOptions { + &mut self.0 + } } impl From for OpenOptions { diff --git a/tokio/src/fs/os/unix/dir_builder_ext.rs b/tokio/src/fs/os/unix/dir_builder_ext.rs new file mode 100644 index 00000000000..e9a25b959c6 --- /dev/null +++ b/tokio/src/fs/os/unix/dir_builder_ext.rs @@ -0,0 +1,29 @@ +use crate::fs::dir_builder::DirBuilder; + +/// Unix-specific extensions to [`DirBuilder`]. +/// +/// [`DirBuilder`]: crate::fs::DirBuilder +pub trait DirBuilderExt { + /// Sets the mode to create new directories with. + /// + /// This option defaults to 0o777. + /// + /// # Examples + /// + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// use tokio::fs::os::unix::DirBuilderExt; + /// + /// let mut builder = DirBuilder::new(); + /// builder.mode(0o775); + /// ``` + fn mode(&mut self, mode: u32) -> &mut Self; +} + +impl DirBuilderExt for DirBuilder { + fn mode(&mut self, mode: u32) -> &mut Self { + self.mode = Some(mode); + self + } +} diff --git a/tokio/src/fs/os/unix/mod.rs b/tokio/src/fs/os/unix/mod.rs index 3b0bec38bd5..826222ebf23 100644 --- a/tokio/src/fs/os/unix/mod.rs +++ b/tokio/src/fs/os/unix/mod.rs @@ -2,3 +2,9 @@ mod symlink; pub use self::symlink::symlink; + +mod open_options_ext; +pub use self::open_options_ext::OpenOptionsExt; + +mod dir_builder_ext; +pub use self::dir_builder_ext::DirBuilderExt; diff --git a/tokio/src/fs/os/unix/open_options_ext.rs b/tokio/src/fs/os/unix/open_options_ext.rs new file mode 100644 index 00000000000..ff892758804 --- /dev/null +++ b/tokio/src/fs/os/unix/open_options_ext.rs @@ -0,0 +1,79 @@ +use crate::fs::open_options::OpenOptions; +use std::os::unix::fs::OpenOptionsExt as StdOpenOptionsExt; + +/// Unix-specific extensions to [`fs::OpenOptions`]. +/// +/// This mirrors the definition of [`std::os::unix::fs::OpenOptionsExt`]. +/// +/// +/// [`fs::OpenOptions`]: crate::fs::OpenOptions +/// [`std::os::unix::fs::OpenOptionsExt`]: std::os::unix::fs::OpenOptionsExt +pub trait OpenOptionsExt { + /// Sets the mode bits that a new file will be created with. + /// + /// If a new file is created as part of an `OpenOptions::open` call then this + /// specified `mode` will be used as the permission bits for the new file. + /// If no `mode` is set, the default of `0o666` will be used. + /// The operating system masks out bits with the system's `umask`, to produce + /// the final permissions. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use tokio::fs::os::unix::OpenOptionsExt; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut options = OpenOptions::new(); + /// options.mode(0o644); // Give read/write for owner and read for others. + /// let file = options.open("foo.txt").await?; + /// + /// Ok(()) + /// } + /// ``` + fn mode(&mut self, mode: u32) -> &mut Self; + + /// Pass custom flags to the `flags` argument of `open`. + /// + /// The bits that define the access mode are masked out with `O_ACCMODE`, to + /// ensure they do not interfere with the access mode set by Rusts options. + /// + /// Custom flags can only set flags, not remove flags set by Rusts options. + /// This options overwrites any previously set custom flags. + /// + /// # Examples + /// + /// ```no_run + /// use libc; + /// use tokio::fs::OpenOptions; + /// use tokio::fs::os::unix::OpenOptionsExt; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut options = OpenOptions::new(); + /// options.write(true); + /// if cfg!(unix) { + /// options.custom_flags(libc::O_NOFOLLOW); + /// } + /// let file = options.open("foo.txt").await?; + /// + /// Ok(()) + /// } + /// ``` + fn custom_flags(&mut self, flags: i32) -> &mut Self; +} + +impl OpenOptionsExt for OpenOptions { + fn mode(&mut self, mode: u32) -> &mut OpenOptions { + self.as_inner_mut().mode(mode); + self + } + + fn custom_flags(&mut self, flags: i32) -> &mut OpenOptions { + self.as_inner_mut().custom_flags(flags); + self + } +} diff --git a/tokio/src/fs/read_dir.rs b/tokio/src/fs/read_dir.rs index fbc006df8d1..f9b16c66c5d 100644 --- a/tokio/src/fs/read_dir.rs +++ b/tokio/src/fs/read_dir.rs @@ -99,7 +99,7 @@ impl crate::stream::Stream for ReadDir { /// Entries returned by the [`ReadDir`] stream. /// -/// [`ReadDir`]: struct.ReadDir.html +/// [`ReadDir`]: struct@ReadDir /// /// This is a specialized version of [`std::fs::DirEntry`] for usage from the /// Tokio runtime. diff --git a/tokio/src/fs/remove_dir_all.rs b/tokio/src/fs/remove_dir_all.rs index 3b2b2e0453e..0a237550f9c 100644 --- a/tokio/src/fs/remove_dir_all.rs +++ b/tokio/src/fs/remove_dir_all.rs @@ -7,7 +7,7 @@ use std::path::Path; /// /// This is an async version of [`std::fs::remove_dir_all`][std] /// -/// [std]: https://doc.rust-lang.org/std/fs/fn.remove_dir_all.html +/// [std]: fn@std::fs::remove_dir_all pub async fn remove_dir_all(path: impl AsRef) -> io::Result<()> { let path = path.as_ref().to_owned(); asyncify(move || std::fs::remove_dir_all(path)).await diff --git a/tokio/src/fs/set_permissions.rs b/tokio/src/fs/set_permissions.rs index b6249d13f00..09be02ea013 100644 --- a/tokio/src/fs/set_permissions.rs +++ b/tokio/src/fs/set_permissions.rs @@ -8,7 +8,7 @@ use std::path::Path; /// /// This is an async version of [`std::fs::set_permissions`][std] /// -/// [std]: https://doc.rust-lang.org/std/fs/fn.set_permissions.html +/// [std]: fn@std::fs::set_permissions pub async fn set_permissions(path: impl AsRef, perm: Permissions) -> io::Result<()> { let path = path.as_ref().to_owned(); asyncify(|| std::fs::set_permissions(path, perm)).await diff --git a/tokio/src/fs/symlink_metadata.rs b/tokio/src/fs/symlink_metadata.rs index 682b43a70eb..1d0df125760 100644 --- a/tokio/src/fs/symlink_metadata.rs +++ b/tokio/src/fs/symlink_metadata.rs @@ -8,7 +8,7 @@ use std::path::Path; /// /// This is an async version of [`std::fs::symlink_metadata`][std] /// -/// [std]: https://doc.rust-lang.org/std/fs/fn.symlink_metadata.html +/// [std]: fn@std::fs::symlink_metadata pub async fn symlink_metadata(path: impl AsRef) -> io::Result { let path = path.as_ref().to_owned(); asyncify(|| std::fs::symlink_metadata(path)).await diff --git a/tokio/src/future/ready.rs b/tokio/src/future/ready.rs index d74f999e5de..de2d60c13a2 100644 --- a/tokio/src/future/ready.rs +++ b/tokio/src/future/ready.rs @@ -2,7 +2,7 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -/// Future for the [`ready`](ready()) function. +/// Future for the [`ok`](ok()) function. /// /// `pub` in order to use the future as an associated type in a sealed trait. #[derive(Debug)] diff --git a/tokio/src/io/async_buf_read.rs b/tokio/src/io/async_buf_read.rs index 1ab73cd9b70..ecaafba4c27 100644 --- a/tokio/src/io/async_buf_read.rs +++ b/tokio/src/io/async_buf_read.rs @@ -7,14 +7,18 @@ use std::task::{Context, Poll}; /// Reads bytes asynchronously. /// -/// This trait inherits from [`std::io::BufRead`] and indicates that an I/O object is -/// **non-blocking**. All non-blocking I/O objects must return an error when -/// bytes are unavailable instead of blocking the current thread. +/// This trait is analogous to [`std::io::BufRead`], but integrates with +/// the asynchronous task system. In particular, the [`poll_fill_buf`] method, +/// unlike [`BufRead::fill_buf`], will automatically queue the current task for wakeup +/// and return if data is not yet available, rather than blocking the calling +/// thread. /// /// Utilities for working with `AsyncBufRead` values are provided by /// [`AsyncBufReadExt`]. /// /// [`std::io::BufRead`]: std::io::BufRead +/// [`poll_fill_buf`]: AsyncBufRead::poll_fill_buf +/// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf /// [`AsyncBufReadExt`]: crate::io::AsyncBufReadExt pub trait AsyncBufRead: AsyncRead { /// Attempts to return the contents of the internal buffer, filling it with more data @@ -60,16 +64,14 @@ pub trait AsyncBufRead: AsyncRead { macro_rules! deref_async_buf_read { () => { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll> - { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut **self.get_mut()).poll_fill_buf(cx) } fn consume(mut self: Pin<&mut Self>, amt: usize) { Pin::new(&mut **self).consume(amt) } - } + }; } impl AsyncBufRead for Box { diff --git a/tokio/src/io/async_read.rs b/tokio/src/io/async_read.rs index de08d65810b..1aef4150166 100644 --- a/tokio/src/io/async_read.rs +++ b/tokio/src/io/async_read.rs @@ -73,10 +73,10 @@ pub trait AsyncRead { /// that they did not write to. /// /// [`io::Read`]: std::io::Read - /// [`poll_read_buf`]: #method.poll_read_buf + /// [`poll_read_buf`]: method@Self::poll_read_buf unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { for x in buf { - *x.as_mut_ptr() = 0; + *x = MaybeUninit::new(0); } true @@ -140,12 +140,14 @@ macro_rules! deref_async_read { (**self).prepare_uninitialized_buffer(buf) } - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) - -> Poll> - { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { Pin::new(&mut **self).poll_read(cx, buf) } - } + }; } impl AsyncRead for Box { diff --git a/tokio/src/io/async_seek.rs b/tokio/src/io/async_seek.rs index 0be9c90d562..32ed0a22ab9 100644 --- a/tokio/src/io/async_seek.rs +++ b/tokio/src/io/async_seek.rs @@ -55,13 +55,10 @@ macro_rules! deref_async_seek { Pin::new(&mut **self).start_seek(cx, pos) } - fn poll_complete( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut **self).poll_complete(cx) } - } + }; } impl AsyncSeek for Box { diff --git a/tokio/src/io/async_write.rs b/tokio/src/io/async_write.rs index 0bfed056ef6..ecf7575b128 100644 --- a/tokio/src/io/async_write.rs +++ b/tokio/src/io/async_write.rs @@ -51,7 +51,7 @@ pub trait AsyncWrite { /// If the object is not ready for writing, the method returns /// `Poll::Pending` and arranges for the current task (via /// `cx.waker()`) to receive a notification when the object becomes - /// readable or is closed. + /// writable or is closed. fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -153,9 +153,11 @@ pub trait AsyncWrite { macro_rules! deref_async_write { () => { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) - -> Poll> - { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { Pin::new(&mut **self).poll_write(cx, buf) } @@ -166,7 +168,7 @@ macro_rules! deref_async_write { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut **self).poll_shutdown(cx) } - } + }; } impl AsyncWrite for Box { diff --git a/tokio/src/io/driver/mod.rs b/tokio/src/io/driver/mod.rs index d8535d9ab27..dbfb6e16e3c 100644 --- a/tokio/src/io/driver/mod.rs +++ b/tokio/src/io/driver/mod.rs @@ -237,10 +237,14 @@ impl fmt::Debug for Handle { // ===== impl Inner ===== impl Inner { - /// Registers an I/O resource with the reactor. + /// Registers an I/O resource with the reactor for a given `mio::Ready` state. /// /// The registration token is returned. - pub(super) fn add_source(&self, source: &dyn Evented) -> io::Result
{ + pub(super) fn add_source( + &self, + source: &dyn Evented, + ready: mio::Ready, + ) -> io::Result
{ let address = self.io_dispatch.alloc().ok_or_else(|| { io::Error::new( io::ErrorKind::Other, @@ -253,7 +257,7 @@ impl Inner { self.io.register( source, mio::Token(address.to_usize()), - mio::Ready::all(), + ready, mio::PollOpt::edge(), )?; @@ -339,12 +343,12 @@ mod tests { let inner = reactor.inner; let inner2 = inner.clone(); - let token_1 = inner.add_source(&NotEvented).unwrap(); + let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); let thread = thread::spawn(move || { inner2.drop_source(token_1); }); - let token_2 = inner.add_source(&NotEvented).unwrap(); + let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); thread.join().unwrap(); assert!(token_1 != token_2); @@ -360,15 +364,15 @@ mod tests { // add sources to fill up the first page so that the dropped index // may be reused. for _ in 0..31 { - inner.add_source(&NotEvented).unwrap(); + inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); } - let token_1 = inner.add_source(&NotEvented).unwrap(); + let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); let thread = thread::spawn(move || { inner2.drop_source(token_1); }); - let token_2 = inner.add_source(&NotEvented).unwrap(); + let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); thread.join().unwrap(); assert!(token_1 != token_2); @@ -383,11 +387,11 @@ mod tests { let inner2 = inner.clone(); let thread = thread::spawn(move || { - let token_2 = inner2.add_source(&NotEvented).unwrap(); + let token_2 = inner2.add_source(&NotEvented, mio::Ready::all()).unwrap(); token_2 }); - let token_1 = inner.add_source(&NotEvented).unwrap(); + let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); let token_2 = thread.join().unwrap(); assert!(token_1 != token_2); diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index 6d6edf558f6..7b005560db9 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -15,19 +15,19 @@ //! type will _yield_ to the Tokio scheduler when IO is not ready, rather than //! blocking. This allows other tasks to run while waiting on IO. //! -//! Another difference is that [`AsyncRead`] and [`AsyncWrite`] only contain +//! Another difference is that `AsyncRead` and `AsyncWrite` only contain //! core methods needed to provide asynchronous reading and writing //! functionality. Instead, utility methods are defined in the [`AsyncReadExt`] //! and [`AsyncWriteExt`] extension traits. These traits are automatically -//! implemented for all values that implement [`AsyncRead`] and [`AsyncWrite`] +//! implemented for all values that implement `AsyncRead` and `AsyncWrite` //! respectively. //! -//! End users will rarely interact directly with [`AsyncRead`] and -//! [`AsyncWrite`]. Instead, they will use the async functions defined in the -//! extension traits. Library authors are expected to implement [`AsyncRead`] -//! and [`AsyncWrite`] in order to provide types that behave like byte streams. +//! End users will rarely interact directly with `AsyncRead` and +//! `AsyncWrite`. Instead, they will use the async functions defined in the +//! extension traits. Library authors are expected to implement `AsyncRead` +//! and `AsyncWrite` in order to provide types that behave like byte streams. //! -//! Even with these differences, Tokio's [`AsyncRead`] and [`AsyncWrite`] traits +//! Even with these differences, Tokio's `AsyncRead` and `AsyncWrite` traits //! can be used in almost exactly the same manner as the standard library's //! `Read` and `Write`. Most types in the standard library that implement `Read` //! and `Write` have asynchronous equivalents in `tokio` that implement @@ -57,7 +57,7 @@ //! [`File`]: crate::fs::File //! [`TcpStream`]: crate::net::TcpStream //! [`std::fs::File`]: std::fs::File -//! [std_example]: https://doc.rust-lang.org/std/io/index.html#read-and-write +//! [std_example]: std::io#read-and-write //! //! ## Buffered Readers and Writers //! @@ -93,7 +93,8 @@ //! ``` //! //! [`BufWriter`] doesn't add any new ways of writing; it just buffers every call -//! to [`write`](crate::io::AsyncWriteExt::write): +//! to [`write`](crate::io::AsyncWriteExt::write). However, you **must** flush +//! [`BufWriter`] to ensure that any buffered data is written. //! //! ```no_run //! use tokio::io::{self, BufWriter, AsyncWriteExt}; @@ -105,16 +106,19 @@ //! { //! let mut writer = BufWriter::new(f); //! -//! // write a byte to the buffer +//! // Write a byte to the buffer. //! writer.write(&[42u8]).await?; //! -//! } // the buffer is flushed once writer goes out of scope +//! // Flush the buffer before it goes out of scope. +//! writer.flush().await?; +//! +//! } // Unless flushed or shut down, the contents of the buffer is discarded on drop. //! //! Ok(()) //! } //! ``` //! -//! [stdbuf]: https://doc.rust-lang.org/std/io/index.html#bufreader-and-bufwriter +//! [stdbuf]: std::io#bufreader-and-bufwriter //! [`std::io::BufRead`]: std::io::BufRead //! [`AsyncBufRead`]: crate::io::AsyncBufRead //! [`BufReader`]: crate::io::BufReader @@ -122,12 +126,26 @@ //! //! ## Implementing AsyncRead and AsyncWrite //! -//! Because they are traits, we can implement `AsyncRead` and `AsyncWrite` for +//! Because they are traits, we can implement [`AsyncRead`] and [`AsyncWrite`] for //! our own types, as well. Note that these traits must only be implemented for //! non-blocking I/O types that integrate with the futures type system. In //! other words, these types must never block the thread, and instead the //! current task is notified when the I/O resource is ready. //! +//! ## Conversion to and from Sink/Stream +//! +//! It is often convenient to encapsulate the reading and writing of +//! bytes and instead work with a [`Sink`] or [`Stream`] of some data +//! type that is encoded as bytes and/or decoded from bytes. Tokio +//! provides some utility traits in the [tokio-util] crate that +//! abstract the asynchronous buffering that is required and allows +//! you to write [`Encoder`] and [`Decoder`] functions working with a +//! buffer of bytes, and then use that ["codec"] to transform anything +//! that implements [`AsyncRead`] and [`AsyncWrite`] into a `Sink`/`Stream` of +//! your structured data. +//! +//! [tokio-util]: https://docs.rs/tokio-util/0.3/tokio_util/codec/index.html +//! //! # Standard input and output //! //! Tokio provides asynchronous APIs to standard [input], [output], and [error]. @@ -138,21 +156,28 @@ //! context of the Tokio runtime, as they require Tokio-specific features to //! function. Calling these functions outside of a Tokio runtime will panic. //! -//! [input]: fn.stdin.html -//! [output]: fn.stdout.html -//! [error]: fn.stderr.html +//! [input]: fn@stdin +//! [output]: fn@stdout +//! [error]: fn@stderr //! //! # `std` re-exports //! //! Additionally, [`Error`], [`ErrorKind`], and [`Result`] are re-exported //! from `std::io` for ease of use. //! -//! [`AsyncRead`]: trait.AsyncRead.html -//! [`AsyncWrite`]: trait.AsyncWrite.html -//! [`Error`]: struct.Error.html -//! [`ErrorKind`]: enum.ErrorKind.html -//! [`Result`]: type.Result.html +//! [`AsyncRead`]: trait@AsyncRead +//! [`AsyncWrite`]: trait@AsyncWrite +//! [`AsyncReadExt`]: trait@AsyncReadExt +//! [`AsyncWriteExt`]: trait@AsyncWriteExt +//! ["codec"]: https://docs.rs/tokio-util/0.3/tokio_util/codec/index.html +//! [`Encoder`]: https://docs.rs/tokio-util/0.3/tokio_util/codec/trait.Encoder.html +//! [`Decoder`]: https://docs.rs/tokio-util/0.3/tokio_util/codec/trait.Decoder.html +//! [`Error`]: struct@Error +//! [`ErrorKind`]: enum@ErrorKind +//! [`Result`]: type@Result //! [`Read`]: std::io::Read +//! [`Sink`]: https://docs.rs/futures/0.3/futures/sink/trait.Sink.html +//! [`Stream`]: crate::stream::Stream //! [`Write`]: std::io::Write cfg_io_blocking! { pub(crate) mod blocking; @@ -162,6 +187,7 @@ mod async_buf_read; pub use self::async_buf_read::AsyncBufRead; mod async_read; + pub use self::async_read::AsyncRead; mod async_seek; diff --git a/tokio/src/io/poll_evented.rs b/tokio/src/io/poll_evented.rs index 2cf9a043bb6..5295bd71ad8 100644 --- a/tokio/src/io/poll_evented.rs +++ b/tokio/src/io/poll_evented.rs @@ -90,17 +90,17 @@ cfg_io_driver! { /// These events are included as part of the read readiness event stream. The /// write readiness event stream is only for `Ready::writable()` events. /// - /// [`std::io::Read`]: https://doc.rust-lang.org/std/io/trait.Read.html - /// [`std::io::Write`]: https://doc.rust-lang.org/std/io/trait.Write.html - /// [`AsyncRead`]: ../io/trait.AsyncRead.html - /// [`AsyncWrite`]: ../io/trait.AsyncWrite.html - /// [`mio::Evented`]: https://docs.rs/mio/0.6/mio/trait.Evented.html - /// [`Registration`]: struct.Registration.html - /// [`TcpListener`]: ../net/struct.TcpListener.html - /// [`clear_read_ready`]: #method.clear_read_ready - /// [`clear_write_ready`]: #method.clear_write_ready - /// [`poll_read_ready`]: #method.poll_read_ready - /// [`poll_write_ready`]: #method.poll_write_ready + /// [`std::io::Read`]: trait@std::io::Read + /// [`std::io::Write`]: trait@std::io::Write + /// [`AsyncRead`]: trait@AsyncRead + /// [`AsyncWrite`]: trait@AsyncWrite + /// [`mio::Evented`]: trait@mio::Evented + /// [`Registration`]: struct@Registration + /// [`TcpListener`]: struct@crate::net::TcpListener + /// [`clear_read_ready`]: method@Self::clear_read_ready + /// [`clear_write_ready`]: method@Self::clear_write_ready + /// [`poll_read_ready`]: method@Self::poll_read_ready + /// [`poll_write_ready`]: method@Self::poll_write_ready pub struct PollEvented { io: Option, inner: Inner, @@ -175,7 +175,35 @@ where /// from a future driven by a tokio runtime, otherwise runtime can be set /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. pub fn new(io: E) -> io::Result { - let registration = Registration::new(&io)?; + PollEvented::new_with_ready(io, mio::Ready::all()) + } + + /// Creates a new `PollEvented` associated with the default reactor, for specific `mio::Ready` + /// state. `new_with_ready` should be used over `new` when you need control over the readiness + /// state, such as when a file descriptor only allows reads. This does not add `hup` or `error` + /// so if you are interested in those states, you will need to add them to the readiness state + /// passed to this function. + /// + /// An example to listen to read only + /// + /// ```rust + /// ##[cfg(unix)] + /// mio::Ready::from_usize( + /// mio::Ready::readable().as_usize() + /// | mio::unix::UnixReady::error().as_usize() + /// | mio::unix::UnixReady::hup().as_usize() + /// ); + /// ``` + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. + pub fn new_with_ready(io: E, ready: mio::Ready) -> io::Result { + let registration = Registration::new_with_ready(&io, ready)?; Ok(Self { io: Some(io), inner: Inner { @@ -225,7 +253,7 @@ where /// The I/O resource will remain in a read-ready state until readiness is /// cleared by calling [`clear_read_ready`]. /// - /// [`clear_read_ready`]: #method.clear_read_ready + /// [`clear_read_ready`]: method@Self::clear_read_ready /// /// # Panics /// @@ -233,6 +261,11 @@ where /// /// * `ready` includes writable. /// * called from outside of a task context. + /// + /// # Warning + /// + /// This method may not be called concurrently. It takes `&self` to allow + /// calling it concurrently with `poll_write_ready`. pub fn poll_read_ready( &self, cx: &mut Context<'_>, @@ -291,7 +324,7 @@ where /// The I/O resource will remain in a write-ready state until readiness is /// cleared by calling [`clear_write_ready`]. /// - /// [`clear_write_ready`]: #method.clear_write_ready + /// [`clear_write_ready`]: method@Self::clear_write_ready /// /// # Panics /// @@ -299,6 +332,11 @@ where /// /// * `ready` contains bits besides `writable` and `hup`. /// * called from outside of a task context. + /// + /// # Warning + /// + /// This method may not be called concurrently. It takes `&self` to allow + /// calling it concurrently with `poll_read_ready`. pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { poll_ready!( self, diff --git a/tokio/src/io/registration.rs b/tokio/src/io/registration.rs index 4df11999f5e..77fe6dbc723 100644 --- a/tokio/src/io/registration.rs +++ b/tokio/src/io/registration.rs @@ -34,9 +34,9 @@ cfg_io_driver! { /// stream. The write readiness event stream is only for `Ready::writable()` /// events. /// - /// [`new`]: #method.new - /// [`poll_read_ready`]: #method.poll_read_ready`] - /// [`poll_write_ready`]: #method.poll_write_ready`] + /// [`new`]: method@Self::new + /// [`poll_read_ready`]: method@Self::poll_read_ready` + /// [`poll_write_ready`]: method@Self::poll_write_ready` #[derive(Debug)] pub struct Registration { handle: Handle, @@ -63,12 +63,49 @@ impl Registration { /// from a future driven by a tokio runtime, otherwise runtime can be set /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. pub fn new(io: &T) -> io::Result + where + T: Evented, + { + Registration::new_with_ready(io, mio::Ready::all()) + } + + /// Registers the I/O resource with the default reactor, for a specific `mio::Ready` state. + /// `new_with_ready` should be used over `new` when you need control over the readiness state, + /// such as when a file descriptor only allows reads. This does not add `hup` or `error` so if + /// you are interested in those states, you will need to add them to the readiness state passed + /// to this function. + /// + /// An example to listen to read only + /// + /// ```rust + /// ##[cfg(unix)] + /// mio::Ready::from_usize( + /// mio::Ready::readable().as_usize() + /// | mio::unix::UnixReady::error().as_usize() + /// | mio::unix::UnixReady::hup().as_usize() + /// ); + /// ``` + /// + /// # Return + /// + /// - `Ok` if the registration happened successfully + /// - `Err` if an error was encountered during registration + /// + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. + pub fn new_with_ready(io: &T, ready: mio::Ready) -> io::Result where T: Evented, { let handle = Handle::current(); let address = if let Some(inner) = handle.inner() { - inner.add_source(io)? + inner.add_source(io, ready)? } else { return Err(io::Error::new( io::ErrorKind::Other, @@ -116,8 +153,6 @@ impl Registration { /// the function will always return `Ready(HUP)`. This should be treated as /// the end of the readiness stream. /// - /// Ensure that [`register`] has been called first. - /// /// # Return value /// /// There are several possible return values: @@ -129,22 +164,26 @@ impl Registration { /// since the last call to `poll_read_ready`. /// /// * `Poll::Ready(Err(err))` means that the registration has encountered an - /// error. This error either represents a permanent internal error **or** - /// the fact that [`register`] was not called first. + /// error. This could represent a permanent internal error for example. /// - /// [`register`]: #method.register - /// [edge-triggered]: https://docs.rs/mio/0.6/mio/struct.Poll.html#edge-triggered-and-level-triggered + /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered /// /// # Panics /// /// This function will panic if called from outside of a task context. pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); - let v = self.poll_ready(Direction::Read, Some(cx))?; + let v = self.poll_ready(Direction::Read, Some(cx)).map_err(|e| { + coop.made_progress(); + e + })?; match v { - Some(v) => Poll::Ready(Ok(v)), + Some(v) => { + coop.made_progress(); + Poll::Ready(Ok(v)) + } None => Poll::Pending, } } @@ -155,7 +194,7 @@ impl Registration { /// will not notify the current task when a new event is received. As such, /// it is safe to call this function from outside of a task context. /// - /// [`poll_read_ready`]: #method.poll_read_ready + /// [`poll_read_ready`]: method@Self::poll_read_ready pub fn take_read_ready(&self) -> io::Result> { self.poll_ready(Direction::Read, None) } @@ -170,8 +209,6 @@ impl Registration { /// the function will always return `Ready(HUP)`. This should be treated as /// the end of the readiness stream. /// - /// Ensure that [`register`] has been called first. - /// /// # Return value /// /// There are several possible return values: @@ -183,22 +220,26 @@ impl Registration { /// since the last call to `poll_write_ready`. /// /// * `Poll::Ready(Err(err))` means that the registration has encountered an - /// error. This error either represents a permanent internal error **or** - /// the fact that [`register`] was not called first. + /// error. This could represent a permanent internal error for example. /// - /// [`register`]: #method.register - /// [edge-triggered]: https://docs.rs/mio/0.6/mio/struct.Poll.html#edge-triggered-and-level-triggered + /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered /// /// # Panics /// /// This function will panic if called from outside of a task context. pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); - let v = self.poll_ready(Direction::Write, Some(cx))?; + let v = self.poll_ready(Direction::Write, Some(cx)).map_err(|e| { + coop.made_progress(); + e + })?; match v { - Some(v) => Poll::Ready(Ok(v)), + Some(v) => { + coop.made_progress(); + Poll::Ready(Ok(v)) + } None => Poll::Pending, } } @@ -209,7 +250,7 @@ impl Registration { /// will not notify the current task when a new event is received. As such, /// it is safe to call this function from outside of a task context. /// - /// [`poll_write_ready`]: #method.poll_write_ready + /// [`poll_write_ready`]: method@Self::poll_write_ready pub fn take_write_ready(&self) -> io::Result> { self.poll_ready(Direction::Write, None) } diff --git a/tokio/src/io/stdin.rs b/tokio/src/io/stdin.rs index 634bd484967..325b8757ec1 100644 --- a/tokio/src/io/stdin.rs +++ b/tokio/src/io/stdin.rs @@ -12,16 +12,19 @@ cfg_io_std! { /// The handle implements the [`AsyncRead`] trait, but beware that concurrent /// reads of `Stdin` must be executed with care. /// - /// As an additional caveat, reading from the handle may block the calling - /// future indefinitely if there is not enough data available. This makes this - /// handle unsuitable for use in any circumstance where immediate reaction to - /// available data is required, e.g. interactive use or when implementing a - /// subprocess driven by requests on the standard input. + /// This handle is best used for non-interactive uses, such as when a file + /// is piped into the application. For technical reasons, `stdin` is + /// implemented by using an ordinary blocking read on a separate thread, and + /// it is impossible to cancel that read. This can make shutdown of the + /// runtime hang until the user presses enter. + /// + /// For interactive uses, it is recommended to spawn a thread dedicated to + /// user input and use blocking IO directly in that thread. /// /// Created by the [`stdin`] function. /// - /// [`stdin`]: fn.stdin.html - /// [`AsyncRead`]: trait.AsyncRead.html + /// [`stdin`]: fn@stdin + /// [`AsyncRead`]: trait@AsyncRead #[derive(Debug)] pub struct Stdin { std: Blocking, @@ -29,14 +32,14 @@ cfg_io_std! { /// Constructs a new handle to the standard input of the current process. /// - /// The returned handle allows reading from standard input from the within the - /// Tokio runtime. + /// This handle is best used for non-interactive uses, such as when a file + /// is piped into the application. For technical reasons, `stdin` is + /// implemented by using an ordinary blocking read on a separate thread, and + /// it is impossible to cancel that read. This can make shutdown of the + /// runtime hang until the user presses enter. /// - /// As an additional caveat, reading from the handle may block the calling - /// future indefinitely if there is not enough data available. This makes this - /// handle unsuitable for use in any circumstance where immediate reaction to - /// available data is required, e.g. interactive use or when implementing a - /// subprocess driven by requests on the standard input. + /// For interactive uses, it is recommended to spawn a thread dedicated to + /// user input and use blocking IO directly in that thread. pub fn stdin() -> Stdin { let std = io::stdin(); Stdin { @@ -60,6 +63,11 @@ impl std::os::windows::io::AsRawHandle for Stdin { } impl AsyncRead for Stdin { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/io/stdio.rs#L97 + false + } + fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/tokio/src/io/util/async_read_ext.rs b/tokio/src/io/util/async_read_ext.rs index d4402db621b..e848a5d2243 100644 --- a/tokio/src/io/util/async_read_ext.rs +++ b/tokio/src/io/util/async_read_ext.rs @@ -2,8 +2,12 @@ use crate::io::util::chain::{chain, Chain}; use crate::io::util::read::{read, Read}; use crate::io::util::read_buf::{read_buf, ReadBuf}; use crate::io::util::read_exact::{read_exact, ReadExact}; -use crate::io::util::read_int::{ReadI128, ReadI16, ReadI32, ReadI64, ReadI8}; -use crate::io::util::read_int::{ReadU128, ReadU16, ReadU32, ReadU64, ReadU8}; +use crate::io::util::read_int::{ + ReadI128, ReadI128Le, ReadI16, ReadI16Le, ReadI32, ReadI32Le, ReadI64, ReadI64Le, ReadI8, +}; +use crate::io::util::read_int::{ + ReadU128, ReadU128Le, ReadU16, ReadU16Le, ReadU32, ReadU32Le, ReadU64, ReadU64Le, ReadU8, +}; use crate::io::util::read_to_end::{read_to_end, ReadToEnd}; use crate::io::util::read_to_string::{read_to_string, ReadToString}; use crate::io::util::take::{take, Take}; @@ -221,7 +225,7 @@ cfg_io_util! { /// ``` fn read_buf<'a, B>(&'a mut self, buf: &'a mut B) -> ReadBuf<'a, Self, B> where - Self: Sized, + Self: Sized + Unpin, B: BufMut, { read_buf(self, buf) @@ -238,12 +242,6 @@ cfg_io_util! { /// This function reads as many bytes as necessary to completely fill /// the specified buffer `buf`. /// - /// No guarantees are provided about the contents of `buf` when this - /// function is called, implementations cannot rely on any property of - /// the contents of `buf` being `true`. It is recommended that - /// implementations only write data to `buf` instead of reading its - /// contents. - /// /// # Errors /// /// If the operation encounters an "end of file" before completely @@ -669,6 +667,313 @@ cfg_io_util! { /// } /// ``` fn read_i128(&mut self) -> ReadI128; + + /// Reads an unsigned 16-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u16_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 16 bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![2, 5, 3, 0]); + /// + /// assert_eq!(1282, reader.read_u16_le().await?); + /// assert_eq!(3, reader.read_u16_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u16_le(&mut self) -> ReadU16Le; + + /// Reads a signed 16-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i16_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 16 bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x00, 0xc1, 0xff, 0x7c]); + /// + /// assert_eq!(-16128, reader.read_i16_le().await?); + /// assert_eq!(31999, reader.read_i16_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i16_le(&mut self) -> ReadI16Le; + + /// Reads an unsigned 32-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u32_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 32-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x00, 0x00, 0x01, 0x0b]); + /// + /// assert_eq!(184614912, reader.read_u32_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u32_le(&mut self) -> ReadU32Le; + + /// Reads a signed 32-bit integer in little-endian order from the + /// underlying reader. + /// + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i32_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 32-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0xff, 0xff, 0x7a, 0x33]); + /// + /// assert_eq!(863698943, reader.read_i32_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i32_le(&mut self) -> ReadI32Le; + + /// Reads an unsigned 64-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u64_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 64-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83 + /// ]); + /// + /// assert_eq!(9477368352180732672, reader.read_u64_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u64_le(&mut self) -> ReadU64Le; + + /// Reads an signed 64-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i64_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 64-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x80, 0, 0, 0, 0, 0, 0, 0]); + /// + /// assert_eq!(128, reader.read_i64_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i64_le(&mut self) -> ReadI64Le; + + /// Reads an unsigned 128-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u128_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 128-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83, + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83 + /// ]); + /// + /// assert_eq!(174826588484952389081207917399662330624, reader.read_u128_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u128_le(&mut self) -> ReadU128Le; + + /// Reads an signed 128-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i128_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 128-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x80, 0, 0, 0, 0, 0, 0, 0, + /// 0, 0, 0, 0, 0, 0, 0, 0 + /// ]); + /// + /// assert_eq!(128, reader.read_i128_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i128_le(&mut self) -> ReadI128Le; } /// Reads all bytes until EOF in this source, placing them into `buf`. diff --git a/tokio/src/io/util/async_seek_ext.rs b/tokio/src/io/util/async_seek_ext.rs index c7243c7f3ef..c7a0f72fb81 100644 --- a/tokio/src/io/util/async_seek_ext.rs +++ b/tokio/src/io/util/async_seek_ext.rs @@ -2,7 +2,9 @@ use crate::io::seek::{seek, Seek}; use crate::io::AsyncSeek; use std::io::SeekFrom; -/// An extension trait which adds utility methods to `AsyncSeek` types. +/// An extension trait which adds utility methods to [`AsyncSeek`] types. +/// +/// As a convenience, this trait may be imported using the [`prelude`]: /// /// # Examples /// @@ -25,6 +27,11 @@ use std::io::SeekFrom; /// Ok(()) /// } /// ``` +/// +/// See [module][crate::io] documentation for more details. +/// +/// [`AsyncSeek`]: AsyncSeek +/// [`prelude`]: crate::prelude pub trait AsyncSeekExt: AsyncSeek { /// Creates a future which will seek an IO object, and then yield the /// new position in the object and the object itself. diff --git a/tokio/src/io/util/async_write_ext.rs b/tokio/src/io/util/async_write_ext.rs index 377f4ecaf80..fa41097472a 100644 --- a/tokio/src/io/util/async_write_ext.rs +++ b/tokio/src/io/util/async_write_ext.rs @@ -3,8 +3,14 @@ use crate::io::util::shutdown::{shutdown, Shutdown}; use crate::io::util::write::{write, Write}; use crate::io::util::write_all::{write_all, WriteAll}; use crate::io::util::write_buf::{write_buf, WriteBuf}; -use crate::io::util::write_int::{WriteI128, WriteI16, WriteI32, WriteI64, WriteI8}; -use crate::io::util::write_int::{WriteU128, WriteU16, WriteU32, WriteU64, WriteU8}; +use crate::io::util::write_int::{ + WriteI128, WriteI128Le, WriteI16, WriteI16Le, WriteI32, WriteI32Le, WriteI64, WriteI64Le, + WriteI8, +}; +use crate::io::util::write_int::{ + WriteU128, WriteU128Le, WriteU16, WriteU16Le, WriteU32, WriteU32Le, WriteU64, WriteU64Le, + WriteU8, +}; use crate::io::AsyncWrite; use bytes::Buf; @@ -180,7 +186,7 @@ cfg_io_util! { /// ``` fn write_buf<'a, B>(&'a mut self, src: &'a mut B) -> WriteBuf<'a, Self, B> where - Self: Sized, + Self: Sized + Unpin, B: Buf, { write_buf(self, src) @@ -608,6 +614,315 @@ cfg_io_util! { /// } /// ``` fn write_i128(&mut self, n: i128) -> WriteI128; + + + /// Writes an unsigned 16-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u16_le(&mut self, n: u16) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 16-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u16_le(517).await?; + /// writer.write_u16_le(768).await?; + /// + /// assert_eq!(writer, b"\x05\x02\x00\x03"); + /// Ok(()) + /// } + /// ``` + fn write_u16_le(&mut self, n: u16) -> WriteU16Le; + + /// Writes a signed 16-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i16_le(&mut self, n: i16) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 16-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i16_le(193).await?; + /// writer.write_i16_le(-132).await?; + /// + /// assert_eq!(writer, b"\xc1\x00\x7c\xff"); + /// Ok(()) + /// } + /// ``` + fn write_i16_le(&mut self, n: i16) -> WriteI16Le; + + /// Writes an unsigned 32-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u32_le(&mut self, n: u32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 32-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u32_le(267).await?; + /// writer.write_u32_le(1205419366).await?; + /// + /// assert_eq!(writer, b"\x0b\x01\x00\x00\x66\x3d\xd9\x47"); + /// Ok(()) + /// } + /// ``` + fn write_u32_le(&mut self, n: u32) -> WriteU32Le; + + /// Writes a signed 32-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i32_le(&mut self, n: i32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 32-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i32_le(267).await?; + /// writer.write_i32_le(1205419366).await?; + /// + /// assert_eq!(writer, b"\x0b\x01\x00\x00\x66\x3d\xd9\x47"); + /// Ok(()) + /// } + /// ``` + fn write_i32_le(&mut self, n: i32) -> WriteI32Le; + + /// Writes an unsigned 64-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u64_le(&mut self, n: u64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 64-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u64_le(918733457491587).await?; + /// writer.write_u64_le(143).await?; + /// + /// assert_eq!(writer, b"\x83\x86\x60\x4d\x95\x43\x03\x00\x8f\x00\x00\x00\x00\x00\x00\x00"); + /// Ok(()) + /// } + /// ``` + fn write_u64_le(&mut self, n: u64) -> WriteU64Le; + + /// Writes an signed 64-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i64_le(&mut self, n: i64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 64-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i64_le(i64::min_value()).await?; + /// writer.write_i64_le(i64::max_value()).await?; + /// + /// assert_eq!(writer, b"\x00\x00\x00\x00\x00\x00\x00\x80\xff\xff\xff\xff\xff\xff\xff\x7f"); + /// Ok(()) + /// } + /// ``` + fn write_i64_le(&mut self, n: i64) -> WriteI64Le; + + /// Writes an unsigned 128-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u128_le(&mut self, n: u128) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 128-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u128_le(16947640962301618749969007319746179).await?; + /// + /// assert_eq!(writer, vec![ + /// 0x83, 0x86, 0x60, 0x4d, 0x95, 0x43, 0x03, 0x00, + /// 0x83, 0x86, 0x60, 0x4d, 0x95, 0x43, 0x03, 0x00, + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_u128_le(&mut self, n: u128) -> WriteU128Le; + + /// Writes an signed 128-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i128_le(&mut self, n: i128) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 128-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i128_le(i128::min_value()).await?; + /// + /// assert_eq!(writer, vec![ + /// 0, 0, 0, 0, 0, 0, 0, + /// 0, 0, 0, 0, 0, 0, 0, 0, 0x80 + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_i128_le(&mut self, n: i128) -> WriteI128Le; } /// Flushes this output stream, ensuring that all intermediately buffered diff --git a/tokio/src/io/util/buf_reader.rs b/tokio/src/io/util/buf_reader.rs index 0177c0e344a..a1c5990a644 100644 --- a/tokio/src/io/util/buf_reader.rs +++ b/tokio/src/io/util/buf_reader.rs @@ -1,6 +1,7 @@ use crate::io::util::DEFAULT_BUF_SIZE; use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use bytes::Buf; use pin_project_lite::pin_project; use std::io::{self, Read}; use std::mem::MaybeUninit; @@ -82,7 +83,7 @@ impl BufReader { self.project().inner } - /// Consumes this `BufWriter`, returning the underlying reader. + /// Consumes this `BufReader`, returning the underlying reader. /// /// Note that any leftover data in the internal buffer is lost. pub fn into_inner(self) -> R { @@ -162,6 +163,14 @@ impl AsyncWrite for BufReader { self.get_pin_mut().poll_write(cx, buf) } + fn poll_write_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> { + self.get_pin_mut().poll_write_buf(cx, buf) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.get_pin_mut().poll_flush(cx) } diff --git a/tokio/src/io/util/chain.rs b/tokio/src/io/util/chain.rs index bc76af341da..8ba9194f5de 100644 --- a/tokio/src/io/util/chain.rs +++ b/tokio/src/io/util/chain.rs @@ -84,6 +84,15 @@ where T: AsyncRead, U: AsyncRead, { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { + if self.first.prepare_uninitialized_buffer(buf) { + return true; + } + if self.second.prepare_uninitialized_buffer(buf) { + return true; + } + false + } fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/tokio/src/io/util/copy.rs b/tokio/src/io/util/copy.rs index 8e0058c1c2d..7bfe296941e 100644 --- a/tokio/src/io/util/copy.rs +++ b/tokio/src/io/util/copy.rs @@ -70,7 +70,7 @@ cfg_io_util! { amt: 0, pos: 0, cap: 0, - buf: Box::new([0; 2048]), + buf: vec![0; 2048].into_boxed_slice(), } } } diff --git a/tokio/src/io/util/empty.rs b/tokio/src/io/util/empty.rs index 7075612f4b3..576058d52d1 100644 --- a/tokio/src/io/util/empty.rs +++ b/tokio/src/io/util/empty.rs @@ -13,7 +13,7 @@ cfg_io_util! { /// /// This is an asynchronous version of [`std::io::empty`][std]. /// - /// [`empty`]: fn.empty.html + /// [`empty`]: fn@empty /// [std]: std::io::empty pub struct Empty { _p: (), @@ -47,6 +47,9 @@ cfg_io_util! { } impl AsyncRead for Empty { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + false + } #[inline] fn poll_read( self: Pin<&mut Self>, diff --git a/tokio/src/io/util/flush.rs b/tokio/src/io/util/flush.rs index 1465f304486..534a5160c1a 100644 --- a/tokio/src/io/util/flush.rs +++ b/tokio/src/io/util/flush.rs @@ -8,7 +8,8 @@ use std::task::{Context, Poll}; cfg_io_util! { /// A future used to fully flush an I/O object. /// - /// Created by the [`AsyncWriteExt::flush`] function. + /// Created by the [`AsyncWriteExt::flush`][flush] function. + /// [flush]: crate::io::AsyncWriteExt::flush #[derive(Debug)] pub struct Flush<'a, A: ?Sized> { a: &'a mut A, diff --git a/tokio/src/io/util/lines.rs b/tokio/src/io/util/lines.rs index f0e75de4b18..ee27400c9de 100644 --- a/tokio/src/io/util/lines.rs +++ b/tokio/src/io/util/lines.rs @@ -59,6 +59,24 @@ where poll_fn(|cx| Pin::new(&mut *self).poll_next_line(cx)).await } + + /// Obtain a mutable reference to the underlying reader + pub fn get_mut(&mut self) -> &mut R { + &mut self.reader + } + + /// Obtain a reference to the underlying reader + pub fn get_ref(&mut self) -> &R { + &self.reader + } + + /// Unwraps this `Lines`, returning the underlying reader. + /// + /// Note that any leftover data in the internal buffer is lost. + /// Therefore, a following read from the underlying reader may lead to data loss. + pub fn into_inner(self) -> R { + self.reader + } } impl Lines @@ -73,6 +91,7 @@ where let me = self.project(); let n = ready!(read_line_internal(me.reader, cx, me.buf, me.bytes, me.read))?; + debug_assert_eq!(*me.read, 0); if n == 0 && me.buf.is_empty() { return Poll::Ready(Ok(None)); diff --git a/tokio/src/io/util/read_buf.rs b/tokio/src/io/util/read_buf.rs index 550499b9334..6ee3d249f82 100644 --- a/tokio/src/io/util/read_buf.rs +++ b/tokio/src/io/util/read_buf.rs @@ -8,14 +8,14 @@ use std::task::{Context, Poll}; pub(crate) fn read_buf<'a, R, B>(reader: &'a mut R, buf: &'a mut B) -> ReadBuf<'a, R, B> where - R: AsyncRead, + R: AsyncRead + Unpin, B: BufMut, { ReadBuf { reader, buf } } cfg_io_util! { - /// Future returned by [`read_buf`](AsyncReadExt::read_buf). + /// Future returned by [`read_buf`](crate::io::AsyncReadExt::read_buf). #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadBuf<'a, R, B> { @@ -26,16 +26,13 @@ cfg_io_util! { impl Future for ReadBuf<'_, R, B> where - R: AsyncRead, + R: AsyncRead + Unpin, B: BufMut, { type Output = io::Result; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // safety: no data is moved from self - unsafe { - let me = self.get_unchecked_mut(); - Pin::new_unchecked(&mut *me.reader).poll_read_buf(cx, &mut me.buf) - } + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = &mut *self; + Pin::new(&mut *me.reader).poll_read_buf(cx, me.buf) } } diff --git a/tokio/src/io/util/read_exact.rs b/tokio/src/io/util/read_exact.rs index d6983c99530..86b8412954b 100644 --- a/tokio/src/io/util/read_exact.rs +++ b/tokio/src/io/util/read_exact.rs @@ -9,7 +9,8 @@ use std::task::{Context, Poll}; /// A future which can be used to easily read exactly enough bytes to fill /// a buffer. /// -/// Created by the [`AsyncRead::read_exact`]. +/// Created by the [`AsyncReadExt::read_exact`][read_exact]. +/// [read_exact]: [crate::io::AsyncReadExt::read_exact] pub(crate) fn read_exact<'a, A>(reader: &'a mut A, buf: &'a mut [u8]) -> ReadExact<'a, A> where A: AsyncRead + Unpin + ?Sized, diff --git a/tokio/src/io/util/read_int.rs b/tokio/src/io/util/read_int.rs index 9dc4402f88f..9d37dc7a400 100644 --- a/tokio/src/io/util/read_int.rs +++ b/tokio/src/io/util/read_int.rs @@ -121,3 +121,13 @@ reader!(ReadI16, i16, get_i16); reader!(ReadI32, i32, get_i32); reader!(ReadI64, i64, get_i64); reader!(ReadI128, i128, get_i128); + +reader!(ReadU16Le, u16, get_u16_le); +reader!(ReadU32Le, u32, get_u32_le); +reader!(ReadU64Le, u64, get_u64_le); +reader!(ReadU128Le, u128, get_u128_le); + +reader!(ReadI16Le, i16, get_i16_le); +reader!(ReadI32Le, i32, get_i32_le); +reader!(ReadI64Le, i64, get_i64_le); +reader!(ReadI128Le, i128, get_i128_le); diff --git a/tokio/src/io/util/read_line.rs b/tokio/src/io/util/read_line.rs index c5ee597486f..d625a76b80a 100644 --- a/tokio/src/io/util/read_line.rs +++ b/tokio/src/io/util/read_line.rs @@ -5,7 +5,6 @@ use std::future::Future; use std::io; use std::mem; use std::pin::Pin; -use std::str; use std::task::{Context, Poll}; cfg_io_util! { @@ -14,45 +13,72 @@ cfg_io_util! { #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadLine<'a, R: ?Sized> { reader: &'a mut R, - buf: &'a mut String, - bytes: Vec, + /// This is the buffer we were provided. It will be replaced with an empty string + /// while reading to postpone utf-8 handling until after reading. + output: &'a mut String, + /// The actual allocation of the string is moved into a vector instead. + buf: Vec, + /// The number of bytes appended to buf. This can be less than buf.len() if + /// the buffer was not empty when the operation was started. read: usize, } } -pub(crate) fn read_line<'a, R>(reader: &'a mut R, buf: &'a mut String) -> ReadLine<'a, R> +pub(crate) fn read_line<'a, R>(reader: &'a mut R, string: &'a mut String) -> ReadLine<'a, R> where R: AsyncBufRead + ?Sized + Unpin, { ReadLine { reader, - bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) }, - buf, + buf: mem::replace(string, String::new()).into_bytes(), + output: string, read: 0, } } +fn put_back_original_data(output: &mut String, mut vector: Vec, num_bytes_read: usize) { + let original_len = vector.len() - num_bytes_read; + vector.truncate(original_len); + *output = String::from_utf8(vector).expect("The original data must be valid utf-8."); +} + pub(super) fn read_line_internal( reader: Pin<&mut R>, cx: &mut Context<'_>, - buf: &mut String, - bytes: &mut Vec, + output: &mut String, + buf: &mut Vec, read: &mut usize, ) -> Poll> { - let ret = ready!(read_until_internal(reader, cx, b'\n', bytes, read)); - if str::from_utf8(&bytes).is_err() { - Poll::Ready(ret.and_then(|_| { - Err(io::Error::new( + let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read)); + let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); + + // At this point both buf and output are empty. The allocation is in utf8_res. + + debug_assert!(buf.is_empty()); + match (io_res, utf8_res) { + (Ok(num_bytes), Ok(string)) => { + debug_assert_eq!(*read, 0); + *output = string; + Poll::Ready(Ok(num_bytes)) + } + (Err(io_err), Ok(string)) => { + *output = string; + Poll::Ready(Err(io_err)) + } + (Ok(num_bytes), Err(utf8_err)) => { + debug_assert_eq!(*read, 0); + put_back_original_data(output, utf8_err.into_bytes(), num_bytes); + + Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, "stream did not contain valid UTF-8", - )) - })) - } else { - debug_assert!(buf.is_empty()); - debug_assert_eq!(*read, 0); - // Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`. - mem::swap(unsafe { buf.as_mut_vec() }, bytes); - Poll::Ready(ret) + ))) + } + (Err(io_err), Err(utf8_err)) => { + put_back_original_data(output, utf8_err.into_bytes(), *read); + + Poll::Ready(Err(io_err)) + } } } @@ -62,11 +88,12 @@ impl Future for ReadLine<'_, R> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let Self { reader, + output, buf, - bytes, read, } = &mut *self; - read_line_internal(Pin::new(reader), cx, buf, bytes, read) + + read_line_internal(Pin::new(reader), cx, output, buf, read) } } diff --git a/tokio/src/io/util/read_to_string.rs b/tokio/src/io/util/read_to_string.rs index e77d836dee9..cab0505ab83 100644 --- a/tokio/src/io/util/read_to_string.rs +++ b/tokio/src/io/util/read_to_string.rs @@ -4,7 +4,7 @@ use crate::io::AsyncRead; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{io, mem, str}; +use std::{io, mem}; cfg_io_util! { /// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method. @@ -25,7 +25,7 @@ where let start_len = buf.len(); ReadToString { reader, - bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) }, + bytes: mem::replace(buf, String::new()).into_bytes(), buf, start_len, } @@ -38,19 +38,20 @@ fn read_to_string_internal( bytes: &mut Vec, start_len: usize, ) -> Poll> { - let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len)); - if str::from_utf8(&bytes).is_err() { - Poll::Ready(ret.and_then(|_| { - Err(io::Error::new( + let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?; + match String::from_utf8(mem::replace(bytes, Vec::new())) { + Ok(string) => { + debug_assert!(buf.is_empty()); + *buf = string; + Poll::Ready(Ok(ret)) + } + Err(e) => { + *bytes = e.into_bytes(); + Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, "stream did not contain valid UTF-8", - )) - })) - } else { - debug_assert!(buf.is_empty()); - // Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`. - mem::swap(unsafe { buf.as_mut_vec() }, bytes); - Poll::Ready(ret) + ))) + } } } @@ -67,7 +68,14 @@ where bytes, start_len, } = &mut *self; - read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len) + let ret = read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len); + if let Poll::Ready(Err(_)) = ret { + // Put back the original string. + bytes.truncate(*start_len); + **buf = String::from_utf8(mem::replace(bytes, Vec::new())) + .expect("original string no longer utf-8"); + } + ret } } diff --git a/tokio/src/io/util/read_until.rs b/tokio/src/io/util/read_until.rs index 1adeda66f05..78dac8c2a14 100644 --- a/tokio/src/io/util/read_until.rs +++ b/tokio/src/io/util/read_until.rs @@ -8,19 +8,22 @@ use std::task::{Context, Poll}; cfg_io_util! { /// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method. + /// The delimeter is included in the resulting vector. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadUntil<'a, R: ?Sized> { reader: &'a mut R, - byte: u8, + delimeter: u8, buf: &'a mut Vec, + /// The number of bytes appended to buf. This can be less than buf.len() if + /// the buffer was not empty when the operation was started. read: usize, } } pub(crate) fn read_until<'a, R>( reader: &'a mut R, - byte: u8, + delimeter: u8, buf: &'a mut Vec, ) -> ReadUntil<'a, R> where @@ -28,7 +31,7 @@ where { ReadUntil { reader, - byte, + delimeter, buf, read: 0, } @@ -37,14 +40,14 @@ where pub(super) fn read_until_internal( mut reader: Pin<&mut R>, cx: &mut Context<'_>, - byte: u8, + delimeter: u8, buf: &mut Vec, read: &mut usize, ) -> Poll> { loop { let (done, used) = { let available = ready!(reader.as_mut().poll_fill_buf(cx))?; - if let Some(i) = memchr::memchr(byte, available) { + if let Some(i) = memchr::memchr(delimeter, available) { buf.extend_from_slice(&available[..=i]); (true, i + 1) } else { @@ -66,11 +69,11 @@ impl Future for ReadUntil<'_, R> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let Self { reader, - byte, + delimeter, buf, read, } = &mut *self; - read_until_internal(Pin::new(reader), cx, *byte, buf, read) + read_until_internal(Pin::new(reader), cx, *delimeter, buf, read) } } diff --git a/tokio/src/io/util/repeat.rs b/tokio/src/io/util/repeat.rs index 064775ee56a..eeef7cc187b 100644 --- a/tokio/src/io/util/repeat.rs +++ b/tokio/src/io/util/repeat.rs @@ -13,7 +13,7 @@ cfg_io_util! { /// /// This is an asynchronous version of [`std::io::Repeat`][std]. /// - /// [repeat]: fn.repeat.html + /// [repeat]: fn@repeat /// [std]: std::io::Repeat #[derive(Debug)] pub struct Repeat { @@ -47,6 +47,9 @@ cfg_io_util! { } impl AsyncRead for Repeat { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + false + } #[inline] fn poll_read( self: Pin<&mut Self>, diff --git a/tokio/src/io/util/shutdown.rs b/tokio/src/io/util/shutdown.rs index f24e2885414..33ac0ac0db7 100644 --- a/tokio/src/io/util/shutdown.rs +++ b/tokio/src/io/util/shutdown.rs @@ -8,7 +8,8 @@ use std::task::{Context, Poll}; cfg_io_util! { /// A future used to shutdown an I/O object. /// - /// Created by the [`AsyncWriteExt::shutdown`] function. + /// Created by the [`AsyncWriteExt::shutdown`][shutdown] function. + /// [shutdown]: crate::io::AsyncWriteExt::shutdown #[derive(Debug)] pub struct Shutdown<'a, A: ?Sized> { a: &'a mut A, diff --git a/tokio/src/io/util/split.rs b/tokio/src/io/util/split.rs index f1ed2fd89d3..f552ed503d1 100644 --- a/tokio/src/io/util/split.rs +++ b/tokio/src/io/util/split.rs @@ -75,6 +75,8 @@ where let n = ready!(read_until_internal( me.reader, cx, *me.delim, me.buf, me.read, ))?; + // read_until_internal resets me.read to zero once it finds the delimeter + debug_assert_eq!(*me.read, 0); if n == 0 && me.buf.is_empty() { return Poll::Ready(Ok(None)); diff --git a/tokio/src/io/util/write_buf.rs b/tokio/src/io/util/write_buf.rs index e49282fe0c4..cedfde64e6e 100644 --- a/tokio/src/io/util/write_buf.rs +++ b/tokio/src/io/util/write_buf.rs @@ -20,7 +20,7 @@ cfg_io_util! { /// asynchronous manner, returning a future. pub(crate) fn write_buf<'a, W, B>(writer: &'a mut W, buf: &'a mut B) -> WriteBuf<'a, W, B> where - W: AsyncWrite, + W: AsyncWrite + Unpin, B: Buf, { WriteBuf { writer, buf } @@ -28,16 +28,13 @@ where impl Future for WriteBuf<'_, W, B> where - W: AsyncWrite, + W: AsyncWrite + Unpin, B: Buf, { type Output = io::Result; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // safety: no data is moved from self - unsafe { - let me = self.get_unchecked_mut(); - Pin::new_unchecked(&mut *me.writer).poll_write_buf(cx, &mut me.buf) - } + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = &mut *self; + Pin::new(&mut *me.writer).poll_write_buf(cx, me.buf) } } diff --git a/tokio/src/io/util/write_int.rs b/tokio/src/io/util/write_int.rs index 28add549971..ee992de1832 100644 --- a/tokio/src/io/util/write_int.rs +++ b/tokio/src/io/util/write_int.rs @@ -56,6 +56,9 @@ macro_rules! writer { { Poll::Pending => return Poll::Pending, Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), + Poll::Ready(Ok(0)) => { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } Poll::Ready(Ok(n)) => n as u8, }; } @@ -96,7 +99,7 @@ macro_rules! writer8 { match me.dst.poll_write(cx, &buf[..]) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Ready(Ok(0)) => Poll::Pending, + Poll::Ready(Ok(0)) => Poll::Ready(Err(io::ErrorKind::WriteZero.into())), Poll::Ready(Ok(1)) => Poll::Ready(Ok(())), Poll::Ready(Ok(_)) => unreachable!(), } @@ -117,3 +120,13 @@ writer!(WriteI16, i16, put_i16); writer!(WriteI32, i32, put_i32); writer!(WriteI64, i64, put_i64); writer!(WriteI128, i128, put_i128); + +writer!(WriteU16Le, u16, put_u16_le); +writer!(WriteU32Le, u32, put_u32_le); +writer!(WriteU64Le, u64, put_u64_le); +writer!(WriteU128Le, u128, put_u128_le); + +writer!(WriteI16Le, i16, put_i16_le); +writer!(WriteI32Le, i32, put_i32_le); +writer!(WriteI64Le, i64, put_i64_le); +writer!(WriteI128Le, i128, put_i128_le); diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs index 9b1d6cdb010..e473847e2ab 100644 --- a/tokio/src/lib.rs +++ b/tokio/src/lib.rs @@ -1,4 +1,4 @@ -#![doc(html_root_url = "https://docs.rs/tokio/0.2.17")] +#![doc(html_root_url = "https://docs.rs/tokio/0.2.21")] #![allow( clippy::cognitive_complexity, clippy::large_enum_variant, @@ -16,6 +16,7 @@ attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) ))] #![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(docsrs, feature(doc_alias))] //! A runtime for writing reliable, asynchronous, and slim applications. //! @@ -213,6 +214,51 @@ //! [rt-threaded]: runtime/index.html#threaded-scheduler //! [rt-features]: runtime/index.html#runtime-scheduler //! +//! ## CPU-bound tasks and blocking code +//! +//! Tokio is able to concurrently run many tasks on a few threads by repeatedly +//! swapping the currently running task on each thread. However, this kind of +//! swapping can only happen at `.await` points, so code that spends a long time +//! without reaching an `.await` will prevent other tasks from running. To +//! combat this, Tokio provides two kinds of threads: Core threads and blocking +//! threads. The core threads are where all asynchronous code runs, and Tokio +//! will by default spawn one for each CPU core. The blocking threads are +//! spawned on demand, and can be used to run blocking code that would otherwise +//! block other tasks from running. Since it is not possible for Tokio to swap +//! out blocking tasks, like it can do with asynchronous code, the upper limit +//! on the number of blocking threads is very large. These limits can be +//! configured on the [`Builder`]. +//! +//! To spawn a blocking task, you should use the [`spawn_blocking`] function. +//! +//! [`Builder`]: crate::runtime::Builder +//! [`spawn_blocking`]: crate::task::spawn_blocking() +//! +//! ``` +//! #[tokio::main] +//! async fn main() { +//! // This is running on a core thread. +//! +//! let blocking_task = tokio::task::spawn_blocking(|| { +//! // This is running on a blocking thread. +//! // Blocking here is ok. +//! }); +//! +//! // We can wait for the blocking task like this: +//! // If the blocking task panics, the unwrap below will propagate the +//! // panic. +//! blocking_task.await.unwrap(); +//! } +//! ``` +//! +//! If your code is CPU-bound and you wish to limit the number of threads used +//! to run it, you should run it on another thread pool such as [rayon]. You +//! can use an [`oneshot`] channel to send the result back to Tokio when the +//! rayon task finishes. +//! +//! [rayon]: https://docs.rs/rayon +//! [`oneshot`]: crate::sync::oneshot +//! //! ## Asynchronous IO //! //! As well as scheduling and running tasks, Tokio provides everything you need diff --git a/tokio/src/loom/std/atomic_ptr.rs b/tokio/src/loom/std/atomic_ptr.rs index eb8e47557a2..f7fd56cc69b 100644 --- a/tokio/src/loom/std/atomic_ptr.rs +++ b/tokio/src/loom/std/atomic_ptr.rs @@ -11,10 +11,6 @@ impl AtomicPtr { let inner = std::sync::atomic::AtomicPtr::new(ptr); AtomicPtr { inner } } - - pub(crate) fn with_mut(&mut self, f: impl FnOnce(&mut *mut T) -> R) -> R { - f(self.inner.get_mut()) - } } impl Deref for AtomicPtr { diff --git a/tokio/src/loom/std/mod.rs b/tokio/src/loom/std/mod.rs index 595bdf60ed7..60ee56ad202 100644 --- a/tokio/src/loom/std/mod.rs +++ b/tokio/src/loom/std/mod.rs @@ -6,6 +6,8 @@ mod atomic_u32; mod atomic_u64; mod atomic_u8; mod atomic_usize; +#[cfg(feature = "parking_lot")] +mod parking_lot; mod unsafe_cell; pub(crate) mod cell { @@ -41,24 +43,21 @@ pub(crate) mod rand { pub(crate) mod sync { pub(crate) use std::sync::Arc; - #[cfg(feature = "parking_lot")] - mod pl_wrappers; - // Below, make sure all the feature-influenced types are exported for // internal use. Note however that some are not _currently_ named by // consuming code. #[cfg(feature = "parking_lot")] #[allow(unused_imports)] - pub(crate) use pl_wrappers::{Condvar, Mutex}; - - #[cfg(feature = "parking_lot")] - #[allow(unused_imports)] - pub(crate) use parking_lot::{MutexGuard, WaitTimeoutResult}; + pub(crate) use crate::loom::std::parking_lot::{ + Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult, + }; #[cfg(not(feature = "parking_lot"))] #[allow(unused_imports)] - pub(crate) use std::sync::{Condvar, Mutex, MutexGuard, WaitTimeoutResult}; + pub(crate) use std::sync::{ + Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult, + }; pub(crate) mod atomic { pub(crate) use crate::loom::std::atomic_ptr::AtomicPtr; diff --git a/tokio/src/loom/std/sync/pl_wrappers.rs b/tokio/src/loom/std/parking_lot.rs similarity index 59% rename from tokio/src/loom/std/sync/pl_wrappers.rs rename to tokio/src/loom/std/parking_lot.rs index 3be8ba1c108..25d94af44f5 100644 --- a/tokio/src/loom/std/sync/pl_wrappers.rs +++ b/tokio/src/loom/std/parking_lot.rs @@ -6,25 +6,33 @@ use std::sync::{LockResult, TryLockError, TryLockResult}; use std::time::Duration; -use parking_lot as pl; +// Types that do not need wrapping +pub(crate) use parking_lot::{MutexGuard, RwLockReadGuard, RwLockWriteGuard, WaitTimeoutResult}; /// Adapter for `parking_lot::Mutex` to the `std::sync::Mutex` interface. #[derive(Debug)] -pub(crate) struct Mutex(pl::Mutex); +pub(crate) struct Mutex(parking_lot::Mutex); + +#[derive(Debug)] +pub(crate) struct RwLock(parking_lot::RwLock); + +/// Adapter for `parking_lot::Condvar` to the `std::sync::Condvar` interface. +#[derive(Debug)] +pub(crate) struct Condvar(parking_lot::Condvar); impl Mutex { #[inline] pub(crate) fn new(t: T) -> Mutex { - Mutex(pl::Mutex::new(t)) + Mutex(parking_lot::Mutex::new(t)) } #[inline] - pub(crate) fn lock(&self) -> LockResult> { + pub(crate) fn lock(&self) -> LockResult> { Ok(self.0.lock()) } #[inline] - pub(crate) fn try_lock(&self) -> TryLockResult> { + pub(crate) fn try_lock(&self) -> TryLockResult> { match self.0.try_lock() { Some(guard) => Ok(guard), None => Err(TryLockError::WouldBlock), @@ -35,14 +43,24 @@ impl Mutex { // provided here as needed. } -/// Adapter for `parking_lot::Condvar` to the `std::sync::Condvar` interface. -#[derive(Debug)] -pub(crate) struct Condvar(pl::Condvar); +impl RwLock { + pub(crate) fn new(t: T) -> RwLock { + RwLock(parking_lot::RwLock::new(t)) + } + + pub(crate) fn read(&self) -> LockResult> { + Ok(self.0.read()) + } + + pub(crate) fn write(&self) -> LockResult> { + Ok(self.0.write()) + } +} impl Condvar { #[inline] pub(crate) fn new() -> Condvar { - Condvar(pl::Condvar::new()) + Condvar(parking_lot::Condvar::new()) } #[inline] @@ -58,8 +76,8 @@ impl Condvar { #[inline] pub(crate) fn wait<'a, T>( &self, - mut guard: pl::MutexGuard<'a, T>, - ) -> LockResult> { + mut guard: MutexGuard<'a, T>, + ) -> LockResult> { self.0.wait(&mut guard); Ok(guard) } @@ -67,9 +85,9 @@ impl Condvar { #[inline] pub(crate) fn wait_timeout<'a, T>( &self, - mut guard: pl::MutexGuard<'a, T>, + mut guard: MutexGuard<'a, T>, timeout: Duration, - ) -> LockResult<(pl::MutexGuard<'a, T>, pl::WaitTimeoutResult)> { + ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)> { let wtr = self.0.wait_for(&mut guard, timeout); Ok((guard, wtr)) } diff --git a/tokio/src/macros/cfg.rs b/tokio/src/macros/cfg.rs index 18beb1bdbb8..4b77544eb5c 100644 --- a/tokio/src/macros/cfg.rs +++ b/tokio/src/macros/cfg.rs @@ -35,6 +35,39 @@ macro_rules! cfg_blocking_impl { } } +/// Enables blocking API internals +macro_rules! cfg_blocking_impl_or_task { + ($($item:item)*) => { + $( + #[cfg(any( + feature = "blocking", + feature = "fs", + feature = "dns", + feature = "io-std", + feature = "rt-threaded", + feature = "task", + ))] + $item + )* + } +} + +/// Enables enter::block_on +macro_rules! cfg_block_on { + ($($item:item)*) => { + $( + #[cfg(any( + feature = "blocking", + feature = "fs", + feature = "dns", + feature = "io-std", + feature = "rt-core", + ))] + $item + )* + } +} + /// Enables blocking API internals macro_rules! cfg_not_blocking_impl { ($($item:item)*) => { @@ -320,3 +353,52 @@ macro_rules! cfg_uds { )* } } + +macro_rules! cfg_unstable { + ($($item:item)*) => { + $( + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + $item + )* + } +} + +macro_rules! cfg_trace { + ($($item:item)*) => { + $( + #[cfg(feature = "tracing")] + #[cfg_attr(docsrs, doc(cfg(feature = "tracing")))] + $item + )* + } +} + +macro_rules! cfg_not_trace { + ($($item:item)*) => { + $( + #[cfg(not(feature = "tracing"))] + $item + )* + } +} + +macro_rules! cfg_coop { + ($($item:item)*) => { + $( + #[cfg(any( + feature = "blocking", + feature = "dns", + feature = "fs", + feature = "io-driver", + feature = "io-std", + feature = "process", + feature = "rt-core", + feature = "sync", + feature = "stream", + feature = "time" + ))] + $item + )* + } +} diff --git a/tokio/src/macros/pin.rs b/tokio/src/macros/pin.rs index 33d8499e10d..ed844ef7d11 100644 --- a/tokio/src/macros/pin.rs +++ b/tokio/src/macros/pin.rs @@ -43,7 +43,7 @@ /// Pinning is useful when using `select!` and stream operators that require `T: /// Stream + Unpin`. /// -/// [`Future`]: https://doc.rust-lang.org/std/future/trait.Future.html +/// [`Future`]: trait@std::future::Future /// [`Box::pin`]: # /// /// # Usage diff --git a/tokio/src/macros/scoped_tls.rs b/tokio/src/macros/scoped_tls.rs index 666f382b284..886f9d44b0e 100644 --- a/tokio/src/macros/scoped_tls.rs +++ b/tokio/src/macros/scoped_tls.rs @@ -4,7 +4,6 @@ use std::cell::Cell; use std::marker; /// Set a reference as a thread-local -#[macro_export] macro_rules! scoped_thread_local { ($(#[$attrs:meta])* $vis:vis static $name:ident: $ty:ty) => ( $(#[$attrs])* diff --git a/tokio/src/macros/select.rs b/tokio/src/macros/select.rs index 51b6fcd608c..52c8fdd3404 100644 --- a/tokio/src/macros/select.rs +++ b/tokio/src/macros/select.rs @@ -30,7 +30,7 @@ /// /// The complete lifecycle of a `select!` expression is as follows: /// -/// 1. Evaluate all provded `` expressions. If the precondition +/// 1. Evaluate all provided `` expressions. If the precondition /// returns `false`, disable the branch for the remainder of the current call /// to `select!`. Re-entering `select!` due to a loop clears the "disabled" /// state. @@ -359,8 +359,11 @@ macro_rules! select { let start = $crate::macros::support::thread_rng_n(BRANCHES); for i in 0..BRANCHES { - let branch = (start + i) % BRANCHES; - + let branch; + #[allow(clippy::modulo_one)] + { + branch = (start + i) % BRANCHES; + } match branch { $( $crate::count!( $($skip)* ) => { diff --git a/tokio/src/net/addr.rs b/tokio/src/net/addr.rs index 343d4e21ff2..5ba898a15a7 100644 --- a/tokio/src/net/addr.rs +++ b/tokio/src/net/addr.rs @@ -18,7 +18,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV /// conversion directly, use [`lookup_host()`](super::lookup_host()). /// /// This trait is sealed and is intended to be opaque. The details of the trait -/// will change. Stabilization is pending enhancements to the Rust langague. +/// will change. Stabilization is pending enhancements to the Rust language. pub trait ToSocketAddrs: sealed::ToSocketAddrsPriv {} type ReadyFuture = future::Ready>; @@ -121,6 +121,20 @@ impl sealed::ToSocketAddrsPriv for (Ipv6Addr, u16) { } } +// ===== impl &[SocketAddr] ===== + +impl ToSocketAddrs for &[SocketAddr] {} + +impl sealed::ToSocketAddrsPriv for &[SocketAddr] { + type Iter = std::vec::IntoIter; + type Future = ReadyFuture; + + fn to_socket_addrs(&self) -> Self::Future { + let iter = self.to_vec().into_iter(); + future::ok(iter) + } +} + cfg_dns! { // ===== impl str ===== diff --git a/tokio/src/net/tcp/listener.rs b/tokio/src/net/tcp/listener.rs index cde22cb636f..fd79b259b92 100644 --- a/tokio/src/net/tcp/listener.rs +++ b/tokio/src/net/tcp/listener.rs @@ -72,7 +72,7 @@ cfg_tcp! { } impl TcpListener { - /// Creates a new TcpListener which will be bound to the specified address. + /// Creates a new TcpListener, which will be bound to the specified address. /// /// The returned listener is ready for accepting connections. /// @@ -80,13 +80,19 @@ impl TcpListener { /// to this listener. The port allocated can be queried via the `local_addr` /// method. /// - /// The address type can be any implementor of `ToSocketAddrs` trait. + /// The address type can be any implementor of the [`ToSocketAddrs`] trait. + /// Note that strings only implement this trait when the **`dns`** feature + /// is enabled, as strings may contain domain names that need to be resolved. /// /// If `addr` yields multiple addresses, bind will be attempted with each of /// the addresses until one succeeds and returns the listener. If none of /// the addresses succeed in creating a listener, the error returned from /// the last attempt (the last address) is returned. /// + /// This function sets the `SO_REUSEADDR` option on the socket. + /// + /// [`ToSocketAddrs`]: trait@crate::net::ToSocketAddrs + /// /// # Examples /// /// ```no_run @@ -96,7 +102,26 @@ impl TcpListener { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let listener = TcpListener::bind("127.0.0.1:0").await?; + /// let listener = TcpListener::bind("127.0.0.1:2345").await?; + /// + /// // use the listener + /// + /// # let _ = listener; + /// Ok(()) + /// } + /// ``` + /// + /// Without the `dns` feature: + /// + /// ```no_run + /// use tokio::net::TcpListener; + /// use std::net::Ipv4Addr; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let listener = TcpListener::bind((Ipv4Addr::new(127, 0, 0, 1), 2345)).await?; /// /// // use the listener /// @@ -135,7 +160,7 @@ impl TcpListener { /// established, the corresponding [`TcpStream`] and the remote peer's /// address will be returned. /// - /// [`TcpStream`]: ../struct.TcpStream.html + /// [`TcpStream`]: struct@crate::net::TcpStream /// /// # Examples /// @@ -320,7 +345,7 @@ impl TcpListener { /// /// For more information about this option, see [`set_ttl`]. /// - /// [`set_ttl`]: #method.set_ttl + /// [`set_ttl`]: method@Self::set_ttl /// /// # Examples /// diff --git a/tokio/src/net/tcp/mod.rs b/tokio/src/net/tcp/mod.rs index d5354b38d25..7ad36eb0b11 100644 --- a/tokio/src/net/tcp/mod.rs +++ b/tokio/src/net/tcp/mod.rs @@ -9,5 +9,8 @@ pub use incoming::Incoming; mod split; pub use split::{ReadHalf, WriteHalf}; +mod split_owned; +pub use split_owned::{OwnedReadHalf, OwnedWriteHalf, ReuniteError}; + pub(crate) mod stream; pub(crate) use stream::TcpStream; diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs index cce50f6ab36..469056acc5b 100644 --- a/tokio/src/net/tcp/split.rs +++ b/tokio/src/net/tcp/split.rs @@ -19,14 +19,32 @@ use std::net::Shutdown; use std::pin::Pin; use std::task::{Context, Poll}; -/// Read half of a `TcpStream`. +/// Read half of a [`TcpStream`], created by [`split`]. +/// +/// Reading from a `ReadHalf` is usually done using the convenience methods found on the +/// [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. +/// +/// [`TcpStream`]: TcpStream +/// [`split`]: TcpStream::split() +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +/// [the prelude]: crate::prelude #[derive(Debug)] pub struct ReadHalf<'a>(&'a TcpStream); -/// Write half of a `TcpStream`. +/// Write half of a [`TcpStream`], created by [`split`]. +/// +/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will +/// shut down the TCP stream in the write direction. +/// +/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found +/// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude]. /// -/// Note that in the `AsyncWrite` implemenation of `TcpStreamWriteHalf`, -/// `poll_shutdown` actually shuts down the TCP stream in the write direction. +/// [`TcpStream`]: TcpStream +/// [`split`]: TcpStream::split() +/// [`AsyncWrite`]: trait@crate::io::AsyncWrite +/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +/// [the prelude]: crate::prelude #[derive(Debug)] pub struct WriteHalf<'a>(&'a TcpStream); @@ -74,6 +92,8 @@ impl ReadHalf<'_> { /// /// See the [`TcpStream::peek`] level documenation for more details. /// + /// [`TcpStream::peek`]: TcpStream::peek + /// /// # Examples /// /// ```no_run @@ -101,7 +121,10 @@ impl ReadHalf<'_> { /// } /// ``` /// - /// [`TcpStream::peek`]: TcpStream::peek + /// The [`read`] method is defined on the [`AsyncReadExt`] trait. + /// + /// [`read`]: fn@crate::io::AsyncReadExt::read + /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result { poll_fn(|cx| self.poll_peek(cx, buf)).await } diff --git a/tokio/src/net/tcp/split_owned.rs b/tokio/src/net/tcp/split_owned.rs new file mode 100644 index 00000000000..3f6ee33f360 --- /dev/null +++ b/tokio/src/net/tcp/split_owned.rs @@ -0,0 +1,267 @@ +//! `TcpStream` owned split support. +//! +//! A `TcpStream` can be split into an `OwnedReadHalf` and a `OwnedWriteHalf` +//! with the `TcpStream::into_split` method. `OwnedReadHalf` implements +//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`. +//! +//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized +//! split has no associated overhead and enforces all invariants at the type +//! level. + +use crate::future::poll_fn; +use crate::io::{AsyncRead, AsyncWrite}; +use crate::net::TcpStream; + +use bytes::Buf; +use std::error::Error; +use std::mem::MaybeUninit; +use std::net::Shutdown; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{fmt, io}; + +/// Owned read half of a [`TcpStream`], created by [`into_split`]. +/// +/// Reading from an `OwnedReadHalf` is usually done using the convenience methods found +/// on the [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. +/// +/// [`TcpStream`]: TcpStream +/// [`into_split`]: TcpStream::into_split() +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +/// [the prelude]: crate::prelude +#[derive(Debug)] +pub struct OwnedReadHalf { + inner: Arc, +} + +/// Owned write half of a [`TcpStream`], created by [`into_split`]. +/// +/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will +/// shut down the TCP stream in the write direction. +/// +/// Dropping the write half will shutdown the write half of the TCP stream. +/// +/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found +/// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude]. +/// +/// [`TcpStream`]: TcpStream +/// [`into_split`]: TcpStream::into_split() +/// [`AsyncWrite`]: trait@crate::io::AsyncWrite +/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +/// [the prelude]: crate::prelude +#[derive(Debug)] +pub struct OwnedWriteHalf { + inner: Arc, + shutdown_on_drop: bool, +} + +pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) { + let arc = Arc::new(stream); + let read = OwnedReadHalf { + inner: Arc::clone(&arc), + }; + let write = OwnedWriteHalf { + inner: arc, + shutdown_on_drop: true, + }; + (read, write) +} + +pub(crate) fn reunite( + read: OwnedReadHalf, + write: OwnedWriteHalf, +) -> Result { + if Arc::ptr_eq(&read.inner, &write.inner) { + write.forget(); + // This unwrap cannot fail as the api does not allow creating more than two Arcs, + // and we just dropped the other half. + Ok(Arc::try_unwrap(read.inner).expect("Too many handles to Arc")) + } else { + Err(ReuniteError(read, write)) + } +} + +/// Error indicating two halves were not from the same socket, and thus could +/// not be reunited. +#[derive(Debug)] +pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf); + +impl fmt::Display for ReuniteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "tried to reunite halves that are not from the same socket" + ) + } +} + +impl Error for ReuniteError {} + +impl OwnedReadHalf { + /// Attempts to put the two halves of a `TcpStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: TcpStream::into_split() + pub fn reunite(self, other: OwnedWriteHalf) -> Result { + reunite(self, other) + } + + /// Attempt to receive data on the socket, without removing that data from + /// the queue, registering the current task for wakeup if data is not yet + /// available. + /// + /// See the [`TcpStream::poll_peek`] level documenation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io; + /// use tokio::net::TcpStream; + /// + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let stream = TcpStream::connect("127.0.0.1:8000").await?; + /// let (mut read_half, _) = stream.into_split(); + /// let mut buf = [0; 10]; + /// + /// poll_fn(|cx| { + /// read_half.poll_peek(cx, &mut buf) + /// }).await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`TcpStream::poll_peek`]: TcpStream::poll_peek + pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + self.inner.poll_peek2(cx, buf) + } + + /// Receives data on the socket from the remote address to which it is + /// connected, without removing that data from the queue. On success, + /// returns the number of bytes peeked. + /// + /// See the [`TcpStream::peek`] level documenation for more details. + /// + /// [`TcpStream::peek`]: TcpStream::peek + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use tokio::prelude::*; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// let (mut read_half, _) = stream.into_split(); + /// + /// let mut b1 = [0; 10]; + /// let mut b2 = [0; 10]; + /// + /// // Peek at the data + /// let n = read_half.peek(&mut b1).await?; + /// + /// // Read the data + /// assert_eq!(n, read_half.read(&mut b2[..n]).await?); + /// assert_eq!(&b1[..n], &b2[..n]); + /// + /// Ok(()) + /// } + /// ``` + /// + /// The [`read`] method is defined on the [`AsyncReadExt`] trait. + /// + /// [`read`]: fn@crate::io::AsyncReadExt::read + /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt + pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result { + poll_fn(|cx| self.poll_peek(cx, buf)).await + } +} + +impl AsyncRead for OwnedReadHalf { + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { + false + } + + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.inner.poll_read_priv(cx, buf) + } +} + +impl OwnedWriteHalf { + /// Attempts to put the two halves of a `TcpStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: TcpStream::into_split() + pub fn reunite(self, other: OwnedReadHalf) -> Result { + reunite(other, self) + } + + /// Drop the write half, but don't issue a TCP shutdown. + pub fn forget(mut self) { + self.shutdown_on_drop = false; + drop(self); + } +} + +impl Drop for OwnedWriteHalf { + fn drop(&mut self) { + if self.shutdown_on_drop { + let _ = self.inner.shutdown(Shutdown::Write); + } + } +} + +impl AsyncWrite for OwnedWriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.inner.poll_write_priv(cx, buf) + } + + fn poll_write_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> { + self.inner.poll_write_buf_priv(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + // tcp flush is a no-op + Poll::Ready(Ok(())) + } + + // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + self.inner.shutdown(Shutdown::Write).into() + } +} + +impl AsRef for OwnedReadHalf { + fn as_ref(&self) -> &TcpStream { + &*self.inner + } +} + +impl AsRef for OwnedWriteHalf { + fn as_ref(&self) -> &TcpStream { + &*self.inner + } +} diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index f9c2e98fb74..cc81e116069 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -1,6 +1,7 @@ use crate::future::poll_fn; use crate::io::{AsyncRead, AsyncWrite, PollEvented}; use crate::net::tcp::split::{split, ReadHalf, WriteHalf}; +use crate::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; use crate::net::ToSocketAddrs; use bytes::Buf; @@ -20,9 +21,16 @@ cfg_tcp! { /// A TCP stream can either be created by connecting to an endpoint, via the /// [`connect`] method, or by [accepting] a connection from a [listener]. /// - /// [`connect`]: struct.TcpStream.html#method.connect - /// [accepting]: struct.TcpListener.html#method.accept - /// [listener]: struct.TcpListener.html + /// Reading and writing to a `TcpStream` is usually done using the + /// convenience methods found on the [`AsyncReadExt`] and [`AsyncWriteExt`] + /// traits. Examples import these traits through [the prelude]. + /// + /// [`connect`]: method@TcpStream::connect + /// [accepting]: method@super::TcpListener::accept + /// [listener]: struct@super::TcpListener + /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt + /// [the prelude]: crate::prelude /// /// # Examples /// @@ -42,6 +50,11 @@ cfg_tcp! { /// Ok(()) /// } /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt pub struct TcpStream { io: PollEvented, } @@ -50,14 +63,18 @@ cfg_tcp! { impl TcpStream { /// Opens a TCP connection to a remote host. /// - /// `addr` is an address of the remote host. Anything which implements - /// `ToSocketAddrs` trait can be supplied for the address. + /// `addr` is an address of the remote host. Anything which implements the + /// [`ToSocketAddrs`] trait can be supplied as the address. Note that + /// strings only implement this trait when the **`dns`** feature is enabled, + /// as strings may contain domain names that need to be resolved. /// /// If `addr` yields multiple addresses, connect will be attempted with each /// of the addresses until a connection is successful. If none of the /// addresses result in a successful connection, the error returned from the /// last connection attempt (the last address) is returned. /// + /// [`ToSocketAddrs`]: trait@crate::net::ToSocketAddrs + /// /// # Examples /// /// ```no_run @@ -76,6 +93,31 @@ impl TcpStream { /// Ok(()) /// } /// ``` + /// + /// Without the `dns` feature: + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use tokio::prelude::*; + /// use std::error::Error; + /// use std::net::Ipv4Addr; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// // Connect to a peer + /// let mut stream = TcpStream::connect((Ipv4Addr::new(127, 0, 0, 1), 8080)).await?; + /// + /// // Write some data. + /// stream.write_all(b"hello world!").await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt pub async fn connect(addr: A) -> io::Result { let addrs = addr.to_socket_addrs().await?; @@ -302,6 +344,11 @@ impl TcpStream { /// Ok(()) /// } /// ``` + /// + /// The [`read`] method is defined on the [`AsyncReadExt`] trait. + /// + /// [`read`]: fn@crate::io::AsyncReadExt::read + /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result { poll_fn(|cx| self.poll_peek(cx, buf)).await } @@ -615,12 +662,28 @@ impl TcpStream { /// Splits a `TcpStream` into a read half and a write half, which can be used /// to read and write the stream concurrently. /// - /// See the module level documenation of [`split`](super::split) for more - /// details. + /// This method is more efficient than [`into_split`], but the halves cannot be + /// moved into independently spawned tasks. + /// + /// [`into_split`]: TcpStream::into_split() pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) { split(self) } + /// Splits a `TcpStream` into a read half and a write half, which can be used + /// to read and write the stream concurrently. + /// + /// Unlike [`split`], the owned halves can be moved to separate tasks, however + /// this comes at the cost of a heap allocation. + /// + /// **Note::** Dropping the write half will shutdown the write half of the TCP + /// stream. This is equivalent to calling `shutdown(Write)` on the `TcpStream`. + /// + /// [`split`]: TcpStream::split() + pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { + split_owned(self) + } + // == Poll IO functions that takes `&self` == // // They are not public because (taken from the doc of `PollEvented`): diff --git a/tokio/src/net/udp/socket.rs b/tokio/src/net/udp/socket.rs index 604da98bd5d..97090a206d3 100644 --- a/tokio/src/net/udp/socket.rs +++ b/tokio/src/net/udp/socket.rs @@ -74,9 +74,6 @@ impl UdpSocket { /// Splits the `UdpSocket` into a receive half and a send half. The two parts /// can be used to receive and send datagrams concurrently, even from two /// different tasks. - /// - /// See the module level documenation of [`split`](super::split) for more - /// details. pub fn split(self) -> (RecvHalf, SendHalf) { split(self) } @@ -114,7 +111,7 @@ impl UdpSocket { /// The [`connect`] method will connect this socket to a remote address. The future /// will resolve to an error if the socket is not connected. /// - /// [`connect`]: #method.connect + /// [`connect`]: method@Self::connect pub async fn send(&mut self, buf: &[u8]) -> io::Result { poll_fn(|cx| self.poll_send(cx, buf)).await } @@ -153,7 +150,7 @@ impl UdpSocket { /// The [`connect`] method will connect this socket to a remote address. The future /// will fail if the socket is not connected. /// - /// [`connect`]: #method.connect + /// [`connect`]: method@Self::connect pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { poll_fn(|cx| self.poll_recv(cx, buf)).await } @@ -238,7 +235,7 @@ impl UdpSocket { /// /// For more information about this option, see [`set_broadcast`]. /// - /// [`set_broadcast`]: #method.set_broadcast + /// [`set_broadcast`]: method@Self::set_broadcast pub fn broadcast(&self) -> io::Result { self.io.get_ref().broadcast() } @@ -255,7 +252,7 @@ impl UdpSocket { /// /// For more information about this option, see [`set_multicast_loop_v4`]. /// - /// [`set_multicast_loop_v4`]: #method.set_multicast_loop_v4 + /// [`set_multicast_loop_v4`]: method@Self::set_multicast_loop_v4 pub fn multicast_loop_v4(&self) -> io::Result { self.io.get_ref().multicast_loop_v4() } @@ -275,7 +272,7 @@ impl UdpSocket { /// /// For more information about this option, see [`set_multicast_ttl_v4`]. /// - /// [`set_multicast_ttl_v4`]: #method.set_multicast_ttl_v4 + /// [`set_multicast_ttl_v4`]: method@Self::set_multicast_ttl_v4 pub fn multicast_ttl_v4(&self) -> io::Result { self.io.get_ref().multicast_ttl_v4() } @@ -297,7 +294,7 @@ impl UdpSocket { /// /// For more information about this option, see [`set_multicast_loop_v6`]. /// - /// [`set_multicast_loop_v6`]: #method.set_multicast_loop_v6 + /// [`set_multicast_loop_v6`]: method@Self::set_multicast_loop_v6 pub fn multicast_loop_v6(&self) -> io::Result { self.io.get_ref().multicast_loop_v6() } @@ -317,7 +314,7 @@ impl UdpSocket { /// /// For more information about this option, see [`set_ttl`]. /// - /// [`set_ttl`]: #method.set_ttl + /// [`set_ttl`]: method@Self::set_ttl pub fn ttl(&self) -> io::Result { self.io.get_ref().ttl() } @@ -354,7 +351,7 @@ impl UdpSocket { /// /// For more information about this option, see [`join_multicast_v4`]. /// - /// [`join_multicast_v4`]: #method.join_multicast_v4 + /// [`join_multicast_v4`]: method@Self::join_multicast_v4 pub fn leave_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> { self.io.get_ref().leave_multicast_v4(&multiaddr, &interface) } @@ -363,7 +360,7 @@ impl UdpSocket { /// /// For more information about this option, see [`join_multicast_v6`]. /// - /// [`join_multicast_v6`]: #method.join_multicast_v6 + /// [`join_multicast_v6`]: method@Self::join_multicast_v6 pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { self.io.get_ref().leave_multicast_v6(multiaddr, interface) } diff --git a/tokio/src/net/udp/split.rs b/tokio/src/net/udp/split.rs index 55542cb6316..e8d434aa1ef 100644 --- a/tokio/src/net/udp/split.rs +++ b/tokio/src/net/udp/split.rs @@ -1,6 +1,6 @@ -//! [`UdpSocket`](../struct.UdpSocket.html) split support. +//! [`UdpSocket`](crate::net::UdpSocket) split support. //! -//! The [`split`](../struct.UdpSocket.html#method.split) method splits a +//! The [`split`](method@crate::net::UdpSocket::split) method splits a //! `UdpSocket` into a receive half and a send half, which can be used to //! receive and send datagrams concurrently, even from two different tasks. //! @@ -23,14 +23,14 @@ use std::sync::Arc; /// The send half after [`split`](super::UdpSocket::split). /// -/// Use [`send_to`](#method.send_to) or [`send`](#method.send) to send +/// Use [`send_to`](method@Self::send_to) or [`send`](method@Self::send) to send /// datagrams. #[derive(Debug)] pub struct SendHalf(Arc); /// The recv half after [`split`](super::UdpSocket::split). /// -/// Use [`recv_from`](#method.recv_from) or [`recv`](#method.recv) to receive +/// Use [`recv_from`](method@Self::recv_from) or [`recv`](method@Self::recv) to receive /// datagrams. #[derive(Debug)] pub struct RecvHalf(Arc); diff --git a/tokio/src/net/unix/listener.rs b/tokio/src/net/unix/listener.rs index 5acc1b7e822..9b76cb01fd7 100644 --- a/tokio/src/net/unix/listener.rs +++ b/tokio/src/net/unix/listener.rs @@ -3,7 +3,6 @@ use crate::io::PollEvented; use crate::net::unix::{Incoming, UnixStream}; use mio::Ready; -use mio_uds; use std::convert::TryFrom; use std::fmt; use std::io; diff --git a/tokio/src/net/unix/stream.rs b/tokio/src/net/unix/stream.rs index 1ce9c863f63..beae699962d 100644 --- a/tokio/src/net/unix/stream.rs +++ b/tokio/src/net/unix/stream.rs @@ -111,9 +111,6 @@ impl UnixStream { /// Split a `UnixStream` into a read half and a write half, which can be used /// to read and write the stream concurrently. - /// - /// See the module level documenation of [`split`](super::split) for more - /// details. pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) { split(self) } diff --git a/tokio/src/net/unix/ucred.rs b/tokio/src/net/unix/ucred.rs index cdd77ea4140..466aedc21fe 100644 --- a/tokio/src/net/unix/ucred.rs +++ b/tokio/src/net/unix/ucred.rs @@ -22,7 +22,7 @@ pub(crate) use self::impl_linux::get_peer_cred; ))] pub(crate) use self::impl_macos::get_peer_cred; -#[cfg(any(target_os = "solaris"))] +#[cfg(any(target_os = "solaris", target_os = "illumos"))] pub(crate) use self::impl_solaris::get_peer_cred; #[cfg(any(target_os = "linux", target_os = "android"))] @@ -110,7 +110,7 @@ pub(crate) mod impl_macos { } } -#[cfg(any(target_os = "solaris"))] +#[cfg(any(target_os = "solaris", target_os = "illumos"))] pub(crate) mod impl_solaris { use crate::net::unix::UnixStream; use std::io; diff --git a/tokio/src/park/mod.rs b/tokio/src/park/mod.rs index a3e49bbedec..04d3051d807 100644 --- a/tokio/src/park/mod.rs +++ b/tokio/src/park/mod.rs @@ -42,7 +42,7 @@ cfg_resource_drivers! { mod thread; pub(crate) use self::thread::ParkThread; -cfg_blocking_impl! { +cfg_block_on! { pub(crate) use self::thread::{CachedParkThread, ParkError}; } diff --git a/tokio/src/park/thread.rs b/tokio/src/park/thread.rs index a8cdf1432ba..2e2397c7255 100644 --- a/tokio/src/park/thread.rs +++ b/tokio/src/park/thread.rs @@ -204,7 +204,7 @@ impl Unpark for UnparkThread { } } -cfg_blocking_impl! { +cfg_block_on! { use std::marker::PhantomData; use std::rc::Rc; diff --git a/tokio/src/process/mod.rs b/tokio/src/process/mod.rs index 7231511235e..e04a43510a9 100644 --- a/tokio/src/process/mod.rs +++ b/tokio/src/process/mod.rs @@ -6,6 +6,8 @@ //! variants) return "future aware" types that interoperate with Tokio. The asynchronous process //! support is provided through signal handling on Unix and system APIs on Windows. //! +//! [`std::process::Command`]: std::process::Command +//! //! # Examples //! //! Here's an example program which will spawn `echo hello world` and then wait @@ -140,6 +142,9 @@ use std::task::Poll; /// [output](Command::output). /// /// `Command` uses asynchronous versions of some `std` types (for example [`Child`]). +/// +/// [`std::process::Command`]: std::process::Command +/// [`Child`]: struct@Child #[derive(Debug)] pub struct Command { std: StdCommand, @@ -171,7 +176,7 @@ impl Command { /// The search path to be used may be controlled by setting the /// `PATH` environment variable on the Command, /// but this has some implementation limitations on Windows - /// (see issue rust-lang/rust#37519). + /// (see issue [rust-lang/rust#37519]). /// /// # Examples /// @@ -181,6 +186,8 @@ impl Command { /// use tokio::process::Command; /// let command = Command::new("sh"); /// ``` + /// + /// [rust-lang/rust#37519]: https://github.com/rust-lang/rust/issues/37519 pub fn new>(program: S) -> Command { Self::from(StdCommand::new(program)) } @@ -204,7 +211,7 @@ impl Command { /// /// To pass multiple arguments see [`args`]. /// - /// [`args`]: #method.args + /// [`args`]: method@Self::args /// /// # Examples /// @@ -226,7 +233,7 @@ impl Command { /// /// To pass a single argument see [`arg`]. /// - /// [`arg`]: #method.arg + /// [`arg`]: method@Self::arg /// /// # Examples /// @@ -701,7 +708,7 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); let ret = Pin::new(&mut self.inner).poll(cx); @@ -710,6 +717,10 @@ where self.kill_on_drop = false; } + if ret.is_ready() { + coop.made_progress(); + } + ret } } @@ -872,6 +883,11 @@ impl AsyncWrite for ChildStdin { } impl AsyncRead for ChildStdout { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/process.rs#L314 + false + } + fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -882,6 +898,11 @@ impl AsyncRead for ChildStdout { } impl AsyncRead for ChildStderr { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/process.rs#L375 + false + } + fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/tokio/src/runtime/basic_scheduler.rs b/tokio/src/runtime/basic_scheduler.rs index 301554280f9..7e1c257cc86 100644 --- a/tokio/src/runtime/basic_scheduler.rs +++ b/tokio/src/runtime/basic_scheduler.rs @@ -121,7 +121,7 @@ where F: Future, { enter(self, |scheduler, context| { - let _enter = runtime::enter(); + let _enter = runtime::enter(false); let waker = waker_ref(&scheduler.spawner.shared); let mut cx = std::task::Context::from_waker(&waker); diff --git a/tokio/src/runtime/blocking/mod.rs b/tokio/src/runtime/blocking/mod.rs index 5c808335cc0..0b36a75f655 100644 --- a/tokio/src/runtime/blocking/mod.rs +++ b/tokio/src/runtime/blocking/mod.rs @@ -9,7 +9,7 @@ cfg_blocking_impl! { mod schedule; mod shutdown; - mod task; + pub(crate) mod task; use crate::runtime::Builder; diff --git a/tokio/src/runtime/blocking/pool.rs b/tokio/src/runtime/blocking/pool.rs index 60554ff6d84..819eadc36b3 100644 --- a/tokio/src/runtime/blocking/pool.rs +++ b/tokio/src/runtime/blocking/pool.rs @@ -149,7 +149,7 @@ impl fmt::Debug for BlockingPool { // ===== impl Spawner ===== impl Spawner { - fn spawn(&self, task: Task, rt: &Handle) -> Result<(), ()> { + pub(crate) fn spawn(&self, task: Task, rt: &Handle) -> Result<(), ()> { let shutdown_tx = { let mut shared = self.inner.shared.lock().unwrap(); diff --git a/tokio/src/runtime/blocking/schedule.rs b/tokio/src/runtime/blocking/schedule.rs index e10778d5304..4e044ab2987 100644 --- a/tokio/src/runtime/blocking/schedule.rs +++ b/tokio/src/runtime/blocking/schedule.rs @@ -6,7 +6,7 @@ use crate::runtime::task::{self, Task}; /// /// We avoid storing the task by forgetting it in `bind` and re-materializing it /// in `release. -pub(super) struct NoopSchedule; +pub(crate) struct NoopSchedule; impl task::Schedule for NoopSchedule { fn bind(_task: Task) -> NoopSchedule { diff --git a/tokio/src/runtime/blocking/shutdown.rs b/tokio/src/runtime/blocking/shutdown.rs index 5ee8af0fbc0..e76a7013552 100644 --- a/tokio/src/runtime/blocking/shutdown.rs +++ b/tokio/src/runtime/blocking/shutdown.rs @@ -33,15 +33,25 @@ impl Receiver { /// duration. If `timeout` is `None`, then the thread is blocked until the /// shutdown signal is received. pub(crate) fn wait(&mut self, timeout: Option) { - use crate::runtime::enter::{enter, try_enter}; + use crate::runtime::enter::try_enter; - let mut e = if std::thread::panicking() { - match try_enter() { - Some(enter) => enter, - _ => return, + if timeout == Some(Duration::from_nanos(0)) { + return; + } + + let mut e = match try_enter(false) { + Some(enter) => enter, + _ => { + if std::thread::panicking() { + // Don't panic in a panic + return; + } else { + panic!( + "Cannot drop a runtime in a context where blocking is not allowed. \ + This happens when a runtime is dropped from within an asynchronous context." + ); + } } - } else { - enter() }; // The oneshot completes with an Err diff --git a/tokio/src/runtime/blocking/task.rs b/tokio/src/runtime/blocking/task.rs index f98b85494cc..a521af4630c 100644 --- a/tokio/src/runtime/blocking/task.rs +++ b/tokio/src/runtime/blocking/task.rs @@ -3,25 +3,28 @@ use std::pin::Pin; use std::task::{Context, Poll}; /// Converts a function to a future that completes on poll -pub(super) struct BlockingTask { +pub(crate) struct BlockingTask { func: Option, } impl BlockingTask { /// Initializes a new blocking task from the given function - pub(super) fn new(func: T) -> BlockingTask { + pub(crate) fn new(func: T) -> BlockingTask { BlockingTask { func: Some(func) } } } +// The closure `F` is never pinned +impl Unpin for BlockingTask {} + impl Future for BlockingTask where T: FnOnce() -> R, { type Output = R; - fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { - let me = unsafe { self.get_unchecked_mut() }; + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + let me = &mut *self; let func = me .func .take() diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index f9b38151523..d149875003f 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -15,8 +15,8 @@ use std::sync::Arc; /// See function level documentation for details on the various configuration /// settings. /// -/// [`build`]: #method.build -/// [`Builder::new`]: #method.new +/// [`build`]: method@Self::build +/// [`Builder::new`]: method@Self::new /// /// # Examples /// @@ -429,12 +429,13 @@ cfg_rt_core! { /// Sets runtime to use a simpler scheduler that runs all tasks on the current-thread. /// /// The executor and all necessary drivers will all be run on the current - /// thread during `block_on` calls. + /// thread during [`block_on`] calls. /// /// See also [the module level documentation][1], which has a section on scheduler /// types. /// /// [1]: index.html#runtime-configurations + /// [`block_on`]: Runtime::block_on pub fn basic_scheduler(&mut self) -> &mut Self { self.kind = Kind::Basic; self @@ -490,10 +491,12 @@ cfg_rt_threaded! { } fn build_threaded_runtime(&mut self) -> io::Result { + use crate::loom::sys::num_cpus; use crate::runtime::{Kind, ThreadPool}; use crate::runtime::park::Parker; + use std::cmp; - let core_threads = self.core_threads.unwrap_or_else(crate::loom::sys::num_cpus); + let core_threads = self.core_threads.unwrap_or_else(|| cmp::min(self.max_threads, num_cpus())); assert!(core_threads <= self.max_threads, "Core threads number cannot be above max limit"); let clock = time::create_clock(); diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index cfc51def273..1b267f481e2 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -47,9 +47,9 @@ cfg_rt_core! { } } -/// Set this [`ThreadContext`] as the current active [`ThreadContext`]. +/// Set this [`Handle`] as the current active [`Handle`]. /// -/// [`ThreadContext`]: struct.ThreadContext.html +/// [`Handle`]: Handle pub(crate) fn enter(new: Handle, f: F) -> R where F: FnOnce() -> R, diff --git a/tokio/src/runtime/enter.rs b/tokio/src/runtime/enter.rs index afdb67a3b7c..56a7c57b6c6 100644 --- a/tokio/src/runtime/enter.rs +++ b/tokio/src/runtime/enter.rs @@ -2,7 +2,26 @@ use std::cell::{Cell, RefCell}; use std::fmt; use std::marker::PhantomData; -thread_local!(static ENTERED: Cell = Cell::new(false)); +#[derive(Debug, Clone, Copy)] +pub(crate) enum EnterContext { + Entered { + #[allow(dead_code)] + allow_blocking: bool, + }, + NotEntered, +} + +impl EnterContext { + pub(crate) fn is_entered(self) -> bool { + if let EnterContext::Entered { .. } = self { + true + } else { + false + } + } +} + +thread_local!(static ENTERED: Cell = Cell::new(EnterContext::NotEntered)); /// Represents an executor context. pub(crate) struct Enter { @@ -11,8 +30,8 @@ pub(crate) struct Enter { /// Marks the current thread as being within the dynamic extent of an /// executor. -pub(crate) fn enter() -> Enter { - if let Some(enter) = try_enter() { +pub(crate) fn enter(allow_blocking: bool) -> Enter { + if let Some(enter) = try_enter(allow_blocking) { return enter; } @@ -26,12 +45,12 @@ pub(crate) fn enter() -> Enter { /// Tries to enter a runtime context, returns `None` if already in a runtime /// context. -pub(crate) fn try_enter() -> Option { +pub(crate) fn try_enter(allow_blocking: bool) -> Option { ENTERED.with(|c| { - if c.get() { + if c.get().is_entered() { None } else { - c.set(true); + c.set(EnterContext::Entered { allow_blocking }); Some(Enter { _p: PhantomData }) } }) @@ -47,45 +66,87 @@ pub(crate) fn try_enter() -> Option { #[cfg(all(feature = "rt-threaded", feature = "blocking"))] pub(crate) fn exit R, R>(f: F) -> R { // Reset in case the closure panics - struct Reset; + struct Reset(EnterContext); impl Drop for Reset { fn drop(&mut self) { ENTERED.with(|c| { - c.set(true); + assert!(!c.get().is_entered(), "closure claimed permanent executor"); + c.set(self.0); }); } } - ENTERED.with(|c| { - debug_assert!(c.get()); - c.set(false); + let was = ENTERED.with(|c| { + let e = c.get(); + assert!(e.is_entered(), "asked to exit when not entered"); + c.set(EnterContext::NotEntered); + e }); - let reset = Reset; - let ret = f(); - std::mem::forget(reset); + let _reset = Reset(was); + // dropping _reset after f() will reset ENTERED + f() +} - ENTERED.with(|c| { - assert!(!c.get(), "closure claimed permanent executor"); - c.set(true); - }); +cfg_rt_core! { + cfg_rt_util! { + /// Disallow blocking in the current runtime context until the guard is dropped. + pub(crate) fn disallow_blocking() -> DisallowBlockingGuard { + let reset = ENTERED.with(|c| { + if let EnterContext::Entered { + allow_blocking: true, + } = c.get() + { + c.set(EnterContext::Entered { + allow_blocking: false, + }); + true + } else { + false + } + }); + DisallowBlockingGuard(reset) + } - ret + pub(crate) struct DisallowBlockingGuard(bool); + impl Drop for DisallowBlockingGuard { + fn drop(&mut self) { + if self.0 { + // XXX: Do we want some kind of assertion here, or is "best effort" okay? + ENTERED.with(|c| { + if let EnterContext::Entered { + allow_blocking: false, + } = c.get() + { + c.set(EnterContext::Entered { + allow_blocking: true, + }); + } + }) + } + } + } + } } -cfg_blocking_impl! { - use crate::park::ParkError; - use std::time::Duration; +cfg_rt_threaded! { + cfg_blocking! { + /// Returns true if in a runtime context. + pub(crate) fn context() -> EnterContext { + ENTERED.with(|c| c.get()) + } + } +} +cfg_block_on! { impl Enter { /// Blocks the thread on the specified future, returning the value with /// which that future completes. - pub(crate) fn block_on(&mut self, mut f: F) -> Result + pub(crate) fn block_on(&mut self, f: F) -> Result where F: std::future::Future, { use crate::park::{CachedParkThread, Park}; - use std::pin::Pin; use std::task::Context; use std::task::Poll::Ready; @@ -93,9 +154,7 @@ cfg_blocking_impl! { let waker = park.get_unpark()?.into_waker(); let mut cx = Context::from_waker(&waker); - // `block_on` takes ownership of `f`. Once it is pinned here, the original `f` binding can - // no longer be accessed, making the pinning safe. - let mut f = unsafe { Pin::new_unchecked(&mut f) }; + pin!(f); loop { if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { @@ -105,17 +164,23 @@ cfg_blocking_impl! { park.park()?; } } + } +} +cfg_blocking_impl! { + use crate::park::ParkError; + use std::time::Duration; + + impl Enter { /// Blocks the thread on the specified future for **at most** `timeout` /// /// If the future completes before `timeout`, the result is returned. If /// `timeout` elapses, then `Err` is returned. - pub(crate) fn block_on_timeout(&mut self, mut f: F, timeout: Duration) -> Result + pub(crate) fn block_on_timeout(&mut self, f: F, timeout: Duration) -> Result where F: std::future::Future, { use crate::park::{CachedParkThread, Park}; - use std::pin::Pin; use std::task::Context; use std::task::Poll::Ready; use std::time::Instant; @@ -124,9 +189,7 @@ cfg_blocking_impl! { let waker = park.get_unpark()?.into_waker(); let mut cx = Context::from_waker(&waker); - // `block_on` takes ownership of `f`. Once it is pinned here, the original `f` binding can - // no longer be accessed, making the pinning safe. - let mut f = unsafe { Pin::new_unchecked(&mut f) }; + pin!(f); let when = Instant::now() + timeout; loop { @@ -155,8 +218,8 @@ impl fmt::Debug for Enter { impl Drop for Enter { fn drop(&mut self) { ENTERED.with(|c| { - assert!(c.get()); - c.set(false); + assert!(c.get().is_entered()); + c.set(EnterContext::NotEntered); }); } } diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index db53543e852..0716a7fadca 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -1,6 +1,11 @@ use crate::runtime::{blocking, context, io, time, Spawner}; use std::{error, fmt}; +cfg_blocking! { + use crate::runtime::task; + use crate::runtime::blocking::task::BlockingTask; +} + cfg_rt_core! { use crate::task::JoinHandle; @@ -31,7 +36,39 @@ pub struct Handle { } impl Handle { - /// Enter the runtime context. + /// Enter the runtime context. This allows you to construct types that must + /// have an executor available on creation such as [`Delay`] or [`TcpStream`]. + /// It will also allow you to call methods such as [`tokio::spawn`]. + /// + /// This function is also available as [`Runtime::enter`]. + /// + /// [`Delay`]: struct@crate::time::Delay + /// [`TcpStream`]: struct@crate::net::TcpStream + /// [`Runtime::enter`]: fn@crate::runtime::Runtime::enter + /// [`tokio::spawn`]: fn@crate::spawn + /// + /// # Example + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// fn function_that_spawns(msg: String) { + /// // Had we not used `handle.enter` below, this would panic. + /// tokio::spawn(async move { + /// println!("{}", msg); + /// }); + /// } + /// + /// fn main() { + /// let rt = Runtime::new().unwrap(); + /// let handle = rt.handle().clone(); + /// + /// let s = "Hello World!".to_string(); + /// + /// // By entering the context, we tie `tokio::spawn` to this executor. + /// handle.enter(|| function_that_spawns(s)); + /// } + /// ``` pub fn enter(&self, f: F) -> R where F: FnOnce() -> R, @@ -39,11 +76,13 @@ impl Handle { context::enter(self.clone(), f) } - /// Returns a Handle view over the currently running Runtime + /// Returns a `Handle` view over the currently running `Runtime` /// /// # Panic /// - /// This will panic if called outside the context of a Tokio runtime. + /// This will panic if called outside the context of a Tokio runtime. That means that you must + /// call this on one of the threads **being run by the runtime**. Calling this from within a + /// thread created by `std::thread::spawn` (for example) will cause a panic. /// /// # Examples /// @@ -51,6 +90,7 @@ impl Handle { /// block or function running on that runtime. /// /// ``` + /// # use std::thread; /// # use tokio::runtime::Runtime; /// # fn dox() { /// # let rt = Runtime::new().unwrap(); @@ -61,7 +101,16 @@ impl Handle { /// let handle = Handle::current(); /// handle.spawn(async { /// println!("now running in the existing Runtime"); - /// }) + /// }); + /// + /// # let handle = + /// thread::spawn(move || { + /// // Notice that the handle is created outside of this thread and then moved in + /// handle.block_on(async { /* ... */ }) + /// // This next line would cause a panic + /// // let handle2 = Handle::current(); + /// }); + /// # handle.join().unwrap(); /// # }); /// # } /// ``` @@ -110,8 +159,14 @@ cfg_rt_core! { /// /// # Panics /// - /// This function panics if the spawn fails. Failure occurs if the executor - /// is currently at capacity and is unable to spawn a new future. + /// This function will not panic unless task execution is disabled on the + /// executor. This can only happen if the runtime was built using + /// [`Builder`] without picking either [`basic_scheduler`] or + /// [`threaded_scheduler`]. + /// + /// [`Builder`]: struct@crate::runtime::Builder + /// [`threaded_scheduler`]: fn@crate::runtime::Builder::threaded_scheduler + /// [`basic_scheduler`]: fn@crate::runtime::Builder::basic_scheduler pub fn spawn(&self, future: F) -> JoinHandle where F: Future + Send + 'static, @@ -119,6 +174,182 @@ cfg_rt_core! { { self.spawner.spawn(future) } + + /// Run a future to completion on the Tokio runtime from a synchronous + /// context. + /// + /// This runs the given future on the runtime, blocking until it is + /// complete, and yielding its resolved result. Any tasks or timers which + /// the future spawns internally will be executed on the runtime. + /// + /// If the provided executor currently has no active core thread, this + /// function might hang until a core thread is added. This is not a + /// concern when using the [threaded scheduler], as it always has active + /// core threads, but if you use the [basic scheduler], some other + /// thread must currently be inside a call to [`Runtime::block_on`]. + /// See also [the module level documentation][1], which has a section on + /// scheduler types. + /// + /// This method may not be called from an asynchronous context. + /// + /// [threaded scheduler]: fn@crate::runtime::Builder::threaded_scheduler + /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler + /// [`Runtime::block_on`]: fn@crate::runtime::Runtime::block_on + /// [1]: index.html#runtime-configurations + /// + /// # Panics + /// + /// This function panics if the provided future panics, or if called + /// within an asynchronous execution context. + /// + /// # Examples + /// + /// Using `block_on` with the [threaded scheduler]. + /// + /// ``` + /// use tokio::runtime::Runtime; + /// use std::thread; + /// + /// // Create the runtime. + /// // + /// // If the rt-threaded feature is enabled, this creates a threaded + /// // scheduler by default. + /// let rt = Runtime::new().unwrap(); + /// let handle = rt.handle().clone(); + /// + /// // Use the runtime from another thread. + /// let th = thread::spawn(move || { + /// // Execute the future, blocking the current thread until completion. + /// // + /// // This example uses the threaded scheduler, so no concurrent call to + /// // `rt.block_on` is required. + /// handle.block_on(async { + /// println!("hello"); + /// }); + /// }); + /// + /// th.join().unwrap(); + /// ``` + /// + /// Using the [basic scheduler] requires a concurrent call to + /// [`Runtime::block_on`]: + /// + /// [threaded scheduler]: fn@crate::runtime::Builder::threaded_scheduler + /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler + /// [`Runtime::block_on`]: fn@crate::runtime::Runtime::block_on + /// + /// ``` + /// use tokio::runtime::Builder; + /// use tokio::sync::oneshot; + /// use std::thread; + /// + /// // Create the runtime. + /// let mut rt = Builder::new() + /// .enable_all() + /// .basic_scheduler() + /// .build() + /// .unwrap(); + /// + /// let handle = rt.handle().clone(); + /// + /// // Signal main thread when task has finished. + /// let (send, recv) = oneshot::channel(); + /// + /// // Use the runtime from another thread. + /// let th = thread::spawn(move || { + /// // Execute the future, blocking the current thread until completion. + /// handle.block_on(async { + /// send.send("done").unwrap(); + /// }); + /// }); + /// + /// // The basic scheduler is used, so the thread above might hang if we + /// // didn't call block_on on the rt too. + /// rt.block_on(async { + /// assert_eq!(recv.await.unwrap(), "done"); + /// }); + /// # th.join().unwrap(); + /// ``` + /// + pub fn block_on(&self, future: F) -> F::Output { + self.enter(|| { + let mut enter = crate::runtime::enter(true); + enter.block_on(future).expect("failed to park thread") + }) + } + } +} + +cfg_blocking! { + impl Handle { + /// Runs the provided closure on a thread where blocking is acceptable. + /// + /// In general, issuing a blocking call or performing a lot of compute in a + /// future without yielding is not okay, as it may prevent the executor from + /// driving other futures forward. This function runs the provided closure + /// on a thread dedicated to blocking operations. See the [CPU-bound tasks + /// and blocking code][blocking] section for more information. + /// + /// Tokio will spawn more blocking threads when they are requested through + /// this function until the upper limit configured on the [`Builder`] is + /// reached. This limit is very large by default, because `spawn_blocking` is + /// often used for various kinds of IO operations that cannot be performed + /// asynchronously. When you run CPU-bound code using `spawn_blocking`, you + /// should keep this large upper limit in mind; to run your CPU-bound + /// computations on only a few threads, you should use a separate thread + /// pool such as [rayon] rather than configuring the number of blocking + /// threads. + /// + /// This function is intended for non-async operations that eventually + /// finish on their own. If you want to spawn an ordinary thread, you should + /// use [`thread::spawn`] instead. + /// + /// Closures spawned using `spawn_blocking` cannot be cancelled. When you + /// shut down the executor, it will wait indefinitely for all blocking + /// operations to finish. You can use [`shutdown_timeout`] to stop waiting + /// for them after a certain timeout. Be aware that this will still not + /// cancel the tasks — they are simply allowed to keep running after the + /// method returns. + /// + /// Note that if you are using the [basic scheduler], this function will + /// still spawn additional threads for blocking operations. The basic + /// scheduler's single thread is only used for asynchronous code. + /// + /// [`Builder`]: struct@crate::runtime::Builder + /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code + /// [rayon]: https://docs.rs/rayon + /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler + /// [`thread::spawn`]: fn@std::thread::spawn + /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # async fn docs() -> Result<(), Box>{ + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// let handle = rt.handle(); + /// + /// let res = handle.spawn_blocking(move || { + /// // do some compute-heavy work or call synchronous code + /// "done computing" + /// }).await?; + /// + /// assert_eq!(res, "done computing"); + /// # Ok(()) + /// # } + /// ``` + pub fn spawn_blocking(&self, f: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let (task, handle) = task::joinable(BlockingTask::new(f)); + let _ = self.blocking_spawner.spawn(task, self); + handle + } } } diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index aedc328050e..300a14657bf 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -245,7 +245,7 @@ use std::time::Duration; /// The Tokio runtime. /// -/// The runtime provides an I/O [driver], task scheduler, [timer], and blocking +/// The runtime provides an I/O driver, task scheduler, [timer], and blocking /// pool, necessary for running asynchronous tasks. /// /// Instances of `Runtime` can be created using [`new`] or [`Builder`]. However, @@ -266,12 +266,11 @@ use std::time::Duration; /// that reactor will no longer function. Calling any method on them will /// result in an error. /// -/// [driver]: crate::io::driver /// [timer]: crate::time /// [mod]: index.html -/// [`new`]: #method.new -/// [`Builder`]: struct.Builder.html -/// [`tokio::run`]: fn.run.html +/// [`new`]: method@Self::new +/// [`Builder`]: struct@Builder +/// [`tokio::run`]: fn@run #[derive(Debug)] pub struct Runtime { /// Task executor @@ -335,7 +334,7 @@ impl Runtime { /// ``` /// /// [mod]: index.html - /// [main]: ../../tokio_macros/attr.main.html + /// [main]: ../attr.main.html /// [threaded scheduler]: index.html#threaded-scheduler /// [basic scheduler]: index.html#basic-scheduler /// [runtime builder]: crate::runtime::Builder @@ -380,8 +379,14 @@ impl Runtime { /// /// # Panics /// - /// This function panics if the spawn fails. Failure occurs if the executor - /// is currently at capacity and is unable to spawn a new future. + /// This function will not panic unless task execution is disabled on the + /// executor. This can only happen if the runtime was built using + /// [`Builder`] without picking either [`basic_scheduler`] or + /// [`threaded_scheduler`]. + /// + /// [`Builder`]: struct@Builder + /// [`threaded_scheduler`]: fn@Builder::threaded_scheduler + /// [`basic_scheduler`]: fn@Builder::basic_scheduler #[cfg(feature = "rt-core")] pub fn spawn(&self, future: F) -> JoinHandle where @@ -403,12 +408,33 @@ impl Runtime { /// complete, and yielding its resolved result. Any tasks or timers which /// the future spawns internally will be executed on the runtime. /// - /// This method should not be called from an asynchronous context. + /// `&mut` is required as calling `block_on` **may** result in advancing the + /// state of the runtime. The details depend on how the runtime is + /// configured. [`runtime::Handle::block_on`][handle] provides a version + /// that takes `&self`. + /// + /// This method may not be called from an asynchronous context. /// /// # Panics /// - /// This function panics if the executor is at capacity, if the provided - /// future panics, or if called within an asynchronous execution context. + /// This function panics if the provided future panics, or if called within an + /// asynchronous execution context. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::runtime::Runtime; + /// + /// // Create the runtime + /// let mut rt = Runtime::new().unwrap(); + /// + /// // Execute the future, blocking the current thread until completion + /// rt.block_on(async { + /// println!("hello"); + /// }); + /// ``` + /// + /// [handle]: fn@Handle::block_on pub fn block_on(&mut self, future: F) -> F::Output { let kind = &mut self.kind; @@ -421,7 +447,38 @@ impl Runtime { }) } - /// Enter the runtime context. + /// Enter the runtime context. This allows you to construct types that must + /// have an executor available on creation such as [`Delay`] or [`TcpStream`]. + /// It will also allow you to call methods such as [`tokio::spawn`]. + /// + /// This function is also available as [`Handle::enter`]. + /// + /// [`Delay`]: struct@crate::time::Delay + /// [`TcpStream`]: struct@crate::net::TcpStream + /// [`Handle::enter`]: fn@crate::runtime::Handle::enter + /// [`tokio::spawn`]: fn@crate::spawn + /// + /// # Example + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// fn function_that_spawns(msg: String) { + /// // Had we not used `rt.enter` below, this would panic. + /// tokio::spawn(async move { + /// println!("{}", msg); + /// }); + /// } + /// + /// fn main() { + /// let rt = Runtime::new().unwrap(); + /// + /// let s = "Hello World!".to_string(); + /// + /// // By entering the context, we tie `tokio::spawn` to this executor. + /// rt.enter(|| function_that_spawns(s)); + /// } + /// ``` pub fn enter(&self, f: F) -> R where F: FnOnce() -> R, @@ -456,7 +513,7 @@ impl Runtime { /// Usually, dropping a `Runtime` handle is sufficient as tasks are able to /// shutdown in a timely fashion. However, dropping a `Runtime` will wait /// indefinitely for all tasks to terminate, and there are cases where a long - /// blocking task has been spawned which can block dropping `Runtime`. + /// blocking task has been spawned, which can block dropping `Runtime`. /// /// In this case, calling `shutdown_timeout` with an explicit wait timeout /// can work. The `shutdown_timeout` will signal all tasks to shutdown and @@ -491,4 +548,34 @@ impl Runtime { } = self; blocking_pool.shutdown(Some(duration)); } + + /// Shutdown the runtime, without waiting for any spawned tasks to shutdown. + /// + /// This can be useful if you want to drop a runtime from within another runtime. + /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks + /// to complete, which would normally not be permitted within an asynchronous context. + /// By calling `shutdown_background()`, you can drop the runtime from such a context. + /// + /// Note however, that because we do not wait for any blocking tasks to complete, this + /// may result in a resource leak (in that any blocking tasks are still running until they + /// return. + /// + /// This function is equivalent to calling `shutdown_timeout(Duration::of_nanos(0))`. + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// fn main() { + /// let mut runtime = Runtime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// let inner_runtime = Runtime::new().unwrap(); + /// // ... + /// inner_runtime.shutdown_background(); + /// }); + /// } + /// ``` + pub fn shutdown_background(self) { + self.shutdown_timeout(Duration::from_nanos(0)) + } } diff --git a/tokio/src/runtime/shell.rs b/tokio/src/runtime/shell.rs index 294f2a16d8a..a65869d0de2 100644 --- a/tokio/src/runtime/shell.rs +++ b/tokio/src/runtime/shell.rs @@ -32,7 +32,7 @@ impl Shell { where F: Future, { - let _e = enter(); + let _e = enter(true); pin!(f); diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index 573b9f3c9cb..f4756c238ef 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -1,3 +1,14 @@ +//! Core task module. +//! +//! # Safety +//! +//! The functions in this module are private to the `task` module. All of them +//! should be considered `unsafe` to use, but are not marked as such since it +//! would be too noisy. +//! +//! Make sure to consult the relevant safety section of each function before +//! use. + use crate::loom::cell::UnsafeCell; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; @@ -95,15 +106,16 @@ impl Cell { } impl Core { - /// If needed, bind a scheduler to the task. + /// Bind a scheduler to the task. + /// + /// This only happens on the first poll and must be preceeded by a call to + /// `is_bound` to determine if binding is appropriate or not. /// - /// This only happens on the first poll. + /// # Safety + /// + /// Binding must not be done concurrently since it will mutate the task + /// core through a shared reference. pub(super) fn bind_scheduler(&self, task: Task) { - use std::mem::ManuallyDrop; - - // TODO: it would be nice to not have to wrap with a ManuallyDrop - let task = ManuallyDrop::new(task); - // This function may be called concurrently, but the __first__ time it // is called, the caller has unique access to this field. All subsequent // concurrent calls will be via the `Waker`, which will "happens after" @@ -111,12 +123,10 @@ impl Core { // // In other words, it is always safe to read the field and it is safe to // write to the field when it is `None`. - if self.is_bound() { - return; - } + debug_assert!(!self.is_bound()); // Bind the task to the scheduler - let scheduler = S::bind(ManuallyDrop::into_inner(task)); + let scheduler = S::bind(task); // Safety: As `scheduler` is not set, this is the first poll self.scheduler.with_mut(|ptr| unsafe { diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 29b231ea885..e86b29e699e 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -51,13 +51,13 @@ where // If this is the first time the task is polled, the task will be bound // to the scheduler, in which case the task ref count must be // incremented. - let ref_inc = !self.core().is_bound(); + let is_not_bound = !self.core().is_bound(); // Transition the task to the running state. // // A failure to transition here indicates the task has been cancelled // while in the run queue pending execution. - let snapshot = match self.header().state.transition_to_running(ref_inc) { + let snapshot = match self.header().state.transition_to_running(is_not_bound) { Ok(snapshot) => snapshot, Err(_) => { // The task was shutdown while in the run queue. At this point, @@ -67,15 +67,20 @@ where } }; - // Ensure the task is bound to a scheduler instance. If this is the - // first time polling the task, a scheduler instance is pulled from the - // local context and assigned to the task. - // - // The scheduler maintains ownership of the task and responds to `wake` - // calls. - // - // The task reference count has been incremented. - self.core().bind_scheduler(self.to_task()); + if is_not_bound { + // Ensure the task is bound to a scheduler instance. Since this is + // the first time polling the task, a scheduler instance is pulled + // from the local context and assigned to the task. + // + // The scheduler maintains ownership of the task and responds to + // `wake` calls. + // + // The task reference count has been incremented. + // + // Safety: Since we have unique access to the task so that we can + // safely call `bind_scheduler`. + self.core().bind_scheduler(self.to_task()); + } // The transition to `Running` done above ensures that a lock on the // future has been obtained. This also ensures the `*mut T` pointer @@ -84,21 +89,15 @@ where let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { struct Guard<'a, T: Future, S: Schedule> { core: &'a Core, - polled: bool, } impl Drop for Guard<'_, T, S> { fn drop(&mut self) { - if !self.polled { - self.core.drop_future_or_output(); - } + self.core.drop_future_or_output(); } } - let mut guard = Guard { - core: self.core(), - polled: false, - }; + let guard = Guard { core: self.core() }; // If the task is cancelled, avoid polling it, instead signalling it // is complete. @@ -108,7 +107,7 @@ where let res = guard.core.poll(self.header()); // prevent the guard from dropping the future - guard.polled = true; + mem::forget(guard); res.map(Ok) } diff --git a/tokio/src/runtime/task/join.rs b/tokio/src/runtime/task/join.rs index fdcc346e5c1..3c4aabb2e84 100644 --- a/tokio/src/runtime/task/join.rs +++ b/tokio/src/runtime/task/join.rs @@ -102,7 +102,7 @@ impl Future for JoinHandle { let mut ret = Poll::Pending; // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); // Raw should always be set. If it is not, this is due to polling after // completion @@ -126,6 +126,10 @@ impl Future for JoinHandle { raw.try_read_output(&mut ret as *mut _ as *mut (), cx.waker()); } + if ret.is_ready() { + coop.made_progress(); + } + ret } } diff --git a/tokio/src/runtime/thread_pool/mod.rs b/tokio/src/runtime/thread_pool/mod.rs index 87a75e3fbde..ced9712d9ef 100644 --- a/tokio/src/runtime/thread_pool/mod.rs +++ b/tokio/src/runtime/thread_pool/mod.rs @@ -36,7 +36,7 @@ pub(crate) struct ThreadPool { /// /// `Spawner` instances are obtained by calling [`ThreadPool::spawner`]. /// -/// [`ThreadPool::spawner`]: struct.ThreadPool.html#method.spawner +/// [`ThreadPool::spawner`]: method@ThreadPool::spawner #[derive(Clone)] pub(crate) struct Spawner { shared: Arc, @@ -78,7 +78,7 @@ impl ThreadPool { where F: Future, { - let mut enter = crate::runtime::enter(); + let mut enter = crate::runtime::enter(true); enter.block_on(future).expect("failed to park thread") } } diff --git a/tokio/src/runtime/thread_pool/worker.rs b/tokio/src/runtime/thread_pool/worker.rs index 400e2a938ca..abe20da59cf 100644 --- a/tokio/src/runtime/thread_pool/worker.rs +++ b/tokio/src/runtime/thread_pool/worker.rs @@ -4,6 +4,7 @@ //! "core" is handed off to a new thread allowing the scheduler to continue to //! make progress while the originating thread blocks. +use crate::coop; use crate::loom::rand::seed; use crate::loom::sync::{Arc, Mutex}; use crate::park::{Park, Unpark}; @@ -172,35 +173,70 @@ pub(super) fn create(size: usize, park: Parker) -> (Arc, Launch) { } cfg_blocking! { + use crate::runtime::enter::EnterContext; + pub(crate) fn block_in_place(f: F) -> R where F: FnOnce() -> R, { // Try to steal the worker core back - struct Reset; + struct Reset(coop::Budget); impl Drop for Reset { fn drop(&mut self) { CURRENT.with(|maybe_cx| { if let Some(cx) = maybe_cx { let core = cx.worker.core.take(); - *cx.core.borrow_mut() = core; + let mut cx_core = cx.core.borrow_mut(); + assert!(cx_core.is_none()); + *cx_core = core; + + // Reset the task budget as we are re-entering the + // runtime. + coop::set(self.0); } }); } } + let mut had_core = false; + let mut had_entered = false; + CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("can call blocking only when running in a spawned task"); + match (crate::runtime::enter::context(), maybe_cx.is_some()) { + (EnterContext::Entered { .. }, true) => { + // We are on a thread pool runtime thread, so we just need to set up blocking. + had_entered = true; + } + (EnterContext::Entered { allow_blocking }, false) => { + // We are on an executor, but _not_ on the thread pool. + // That is _only_ okay if we are in a thread pool runtime's block_on method: + if allow_blocking { + had_entered = true; + return; + } else { + // This probably means we are on the basic_scheduler or in a LocalSet, + // where it is _not_ okay to block. + panic!("can call blocking only when running on the multi-threaded runtime"); + } + } + (EnterContext::NotEntered, true) => { + // This is a nested call to block_in_place (we already exited). + // All the necessary setup has already been done. + return; + } + (EnterContext::NotEntered, false) => { + // We are outside of the tokio runtime, so blocking is fine. + // We can also skip all of the thread pool blocking setup steps. + return; + } + } + + let cx = maybe_cx.expect("no .is_some() == false cases above should lead here"); // Get the worker core. If none is set, then blocking is fine! let core = match cx.core.borrow_mut().take() { - Some(core) => { - // We are effectively leaving the executor, so we need to - // forcibly end budgeting. - crate::coop::stop(); - core - }, + Some(core) => core, None => return, }; @@ -212,6 +248,7 @@ cfg_blocking! { // // First, move the core back into the worker's shared core slot. cx.worker.core.set(core); + had_core = true; // Next, clone the worker handle and send it to a new thread for // processing. @@ -222,9 +259,17 @@ cfg_blocking! { runtime::spawn_blocking(move || run(worker)); }); - let _reset = Reset; + if had_core { + // Unset the current task's budget. Blocking sections are not + // constrained by task budgets. + let _reset = Reset(coop::stop()); - f() + crate::runtime::enter::exit(f) + } else if had_entered { + crate::runtime::enter::exit(f) + } else { + f() + } } } @@ -256,7 +301,7 @@ fn run(worker: Arc) { core: RefCell::new(None), }; - let _enter = crate::runtime::enter(); + let _enter = crate::runtime::enter(true); CURRENT.set(&cx, || { // This should always be an error. It only returns a `Result` to support @@ -304,7 +349,7 @@ impl Context { *self.core.borrow_mut() = Some(core); // Run the task - crate::coop::budget(|| { + coop::budget(|| { task.run(); // As long as there is budget remaining and a task exists in the @@ -323,7 +368,7 @@ impl Context { None => return Ok(core), }; - if crate::coop::has_budget_remaining() { + if coop::has_budget_remaining() { // Run the LIFO task, then loop *self.core.borrow_mut() = Some(core); task.run(); @@ -530,7 +575,7 @@ impl Core { } // Drain the queue - while let Some(_) = self.next_local_task() {} + while self.next_local_task().is_some() {} } fn drain_pending_drop(&mut self, worker: &Worker) { @@ -752,7 +797,7 @@ impl Shared { } // Drain the injection queue - while let Some(_) = self.inject.pop() {} + while self.inject.pop().is_some() {} } fn ptr_eq(&self, other: &Shared) -> bool { diff --git a/tokio/src/signal/ctrl_c.rs b/tokio/src/signal/ctrl_c.rs index 2240052ce9f..1eeeb85aa17 100644 --- a/tokio/src/signal/ctrl_c.rs +++ b/tokio/src/signal/ctrl_c.rs @@ -11,9 +11,9 @@ use std::io; /// platforms support receiving a signal on "ctrl-c". This function provides a /// portable API for receiving this notification. /// -/// Once the returned future is polled, a listener a listener is registered. The -/// future will complete on the first received `ctrl-c` **after** the initial -/// call to either `Future::poll` or `.await`. +/// Once the returned future is polled, a listener is registered. The future +/// will complete on the first received `ctrl-c` **after** the initial call to +/// either `Future::poll` or `.await`. /// /// # Caveats /// diff --git a/tokio/src/signal/unix.rs b/tokio/src/signal/unix.rs index 06f5cf4eba7..b46b15c99a6 100644 --- a/tokio/src/signal/unix.rs +++ b/tokio/src/signal/unix.rs @@ -401,7 +401,7 @@ pub struct Signal { /// * If the lower-level C functions fail for some reason. /// * If the previous initialization of this specific signal failed. /// * If the signal is one of -/// [`signal_hook::FORBIDDEN`](https://docs.rs/signal-hook/*/signal_hook/fn.register.html#panics) +/// [`signal_hook::FORBIDDEN`](fn@signal_hook_registry::register#panics) pub fn signal(kind: SignalKind) -> io::Result { let signal = kind.0; diff --git a/tokio/src/stream/chain.rs b/tokio/src/stream/chain.rs index 5f0324a4b56..6124c91e44f 100644 --- a/tokio/src/stream/chain.rs +++ b/tokio/src/stream/chain.rs @@ -44,14 +44,6 @@ where } fn size_hint(&self) -> (usize, Option) { - let (a_lower, a_upper) = self.a.size_hint(); - let (b_lower, b_upper) = self.b.size_hint(); - - let upper = match (a_upper, b_upper) { - (Some(a_upper), Some(b_upper)) => Some(a_upper + b_upper), - _ => None, - }; - - (a_lower + b_lower, upper) + super::merge_size_hints(self.a.size_hint(), self.b.size_hint()) } } diff --git a/tokio/src/stream/collect.rs b/tokio/src/stream/collect.rs index f44c72b7b36..46494287cd8 100644 --- a/tokio/src/stream/collect.rs +++ b/tokio/src/stream/collect.rs @@ -32,7 +32,7 @@ pin_project! { /// /// Currently, this trait may not be implemented by third parties. The trait is /// sealed in order to make changes in the future. Stabilization is pending -/// enhancements to the Rust langague. +/// enhancements to the Rust language. pub trait FromStream: sealed::FromStreamPriv {} impl Collect diff --git a/tokio/src/stream/empty.rs b/tokio/src/stream/empty.rs index 6118673e504..2f56ac6cad3 100644 --- a/tokio/src/stream/empty.rs +++ b/tokio/src/stream/empty.rs @@ -4,7 +4,7 @@ use core::marker::PhantomData; use core::pin::Pin; use core::task::{Context, Poll}; -/// Stream for the [`empty`] function. +/// Stream for the [`empty`](fn@empty) function. #[derive(Debug)] #[must_use = "streams do nothing unless polled"] pub struct Empty(PhantomData); diff --git a/tokio/src/stream/iter.rs b/tokio/src/stream/iter.rs index 36eeb5612f7..bc0388a1442 100644 --- a/tokio/src/stream/iter.rs +++ b/tokio/src/stream/iter.rs @@ -3,7 +3,7 @@ use crate::stream::Stream; use core::pin::Pin; use core::task::{Context, Poll}; -/// Stream for the [`iter`] function. +/// Stream for the [`iter`](fn@iter) function. #[derive(Debug)] #[must_use = "streams do nothing unless polled"] pub struct Iter { @@ -45,7 +45,8 @@ where type Item = I::Item; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); + coop.made_progress(); Poll::Ready(self.iter.next()) } diff --git a/tokio/src/stream/merge.rs b/tokio/src/stream/merge.rs index 4850cd40c72..50ba518ce39 100644 --- a/tokio/src/stream/merge.rs +++ b/tokio/src/stream/merge.rs @@ -52,15 +52,7 @@ where } fn size_hint(&self) -> (usize, Option) { - let (a_lower, a_upper) = self.a.size_hint(); - let (b_lower, b_upper) = self.b.size_hint(); - - let upper = match (a_upper, b_upper) { - (Some(a_upper), Some(b_upper)) => Some(a_upper + b_upper), - _ => None, - }; - - (a_lower + b_lower, upper) + super::merge_size_hints(self.a.size_hint(), self.b.size_hint()) } } diff --git a/tokio/src/stream/mod.rs b/tokio/src/stream/mod.rs index 307ead5fba6..7b061efeb69 100644 --- a/tokio/src/stream/mod.rs +++ b/tokio/src/stream/mod.rs @@ -192,12 +192,17 @@ pub trait StreamExt: Stream { /// Values are produced from the merged stream in the order they arrive from /// the two source streams. If both source streams provide values /// simultaneously, the merge stream alternates between them. This provides - /// some level of fairness. + /// some level of fairness. You should not chain calls to `merge`, as this + /// will break the fairness of the merging. /// /// The merged stream completes once **both** source streams complete. When /// one source stream completes before the other, the merge stream /// exclusively polls the remaining stream. /// + /// For merging multiple streams, consider using [`StreamMap`] instead. + /// + /// [`StreamMap`]: crate::stream::StreamMap + /// /// # Examples /// /// ``` @@ -303,7 +308,7 @@ pub trait StreamExt: Stream { /// As values of this stream are made available, the provided function will /// be run on them. If the predicate `f` resolves to /// [`Some(item)`](Some) then the stream will yield the value `item`, but if - /// it resolves to [`None`] then the next value will be produced. + /// it resolves to [`None`], then the value will be skipped. /// /// Note that this function consumes the stream passed into it and returns a /// wrapped version of it, similar to [`Iterator::filter_map`] method in the @@ -697,7 +702,7 @@ pub trait StreamExt: Stream { /// # Notes /// /// `FromStream` is currently a sealed trait. Stabilization is pending - /// enhancements to the Rust langague. + /// enhancements to the Rust language. /// /// # Examples /// @@ -817,3 +822,16 @@ pub trait StreamExt: Stream { } impl StreamExt for St where St: Stream {} + +/// Merge the size hints from two streams. +fn merge_size_hints( + (left_low, left_high): (usize, Option), + (right_low, right_hign): (usize, Option), +) -> (usize, Option) { + let low = left_low.saturating_add(right_low); + let high = match (left_high, right_hign) { + (Some(h1), Some(h2)) => h1.checked_add(h2), + _ => None, + }; + (low, high) +} diff --git a/tokio/src/stream/once.rs b/tokio/src/stream/once.rs index 04a642f309a..7fe204cc127 100644 --- a/tokio/src/stream/once.rs +++ b/tokio/src/stream/once.rs @@ -4,7 +4,7 @@ use core::option; use core::pin::Pin; use core::task::{Context, Poll}; -/// Stream for the [`once`] function. +/// Stream for the [`once`](fn@once) function. #[derive(Debug)] #[must_use = "streams do nothing unless polled"] pub struct Once { diff --git a/tokio/src/stream/pending.rs b/tokio/src/stream/pending.rs index 2e06a1c2612..21224c38596 100644 --- a/tokio/src/stream/pending.rs +++ b/tokio/src/stream/pending.rs @@ -4,7 +4,7 @@ use core::marker::PhantomData; use core::pin::Pin; use core::task::{Context, Poll}; -/// Stream for the [`pending`] function. +/// Stream for the [`pending`](fn@pending) function. #[derive(Debug)] #[must_use = "streams do nothing unless polled"] pub struct Pending(PhantomData); diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs index 436737a6709..8cd1cdd90ba 100644 --- a/tokio/src/sync/batch_semaphore.rs +++ b/tokio/src/sync/batch_semaphore.rs @@ -123,7 +123,9 @@ impl Semaphore { self.permits.load(Acquire) >> Self::PERMIT_SHIFT } - /// Adds `n` new permits to the semaphore. + /// Adds `added` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. pub(crate) fn release(&self, added: usize) { if added == 0 { return; @@ -186,7 +188,7 @@ impl Semaphore { /// Release `rem` permits to the semaphore's wait list, starting from the /// end of the queue. - /// + /// /// If `rem` exceeds the number of permits needed by the wait list, the /// remainder are assigned back to the semaphore. fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) { @@ -386,13 +388,18 @@ impl Future for Acquire<'_> { type Output = Result<(), AcquireError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // First, ensure the current task has enough budget to proceed. + let coop = ready!(crate::coop::poll_proceed(cx)); + let (node, semaphore, needed, queued) = self.project(); + match semaphore.poll_acquire(cx, needed, node, *queued) { Pending => { *queued = true; Pending } Ready(r) => { + coop.made_progress(); r?; *queued = false; Ready(Ok(())) @@ -512,8 +519,8 @@ impl TryAcquireError { impl fmt::Display for TryAcquireError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - TryAcquireError::Closed => write!(fmt, "{}", "semaphore closed"), - TryAcquireError::NoPermits => write!(fmt, "{}", "no permits available"), + TryAcquireError::Closed => write!(fmt, "semaphore closed"), + TryAcquireError::NoPermits => write!(fmt, "no permits available"), } } } diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 05a58070ee5..0c8716f7795 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -109,13 +109,15 @@ //! } use crate::loom::cell::UnsafeCell; -use crate::loom::future::AtomicWaker; -use crate::loom::sync::atomic::{spin_loop_hint, AtomicBool, AtomicPtr, AtomicUsize}; -use crate::loom::sync::{Arc, Condvar, Mutex}; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; +use crate::util::linked_list::{self, LinkedList}; use std::fmt; -use std::mem; -use std::ptr; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; use std::sync::atomic::Ordering::SeqCst; use std::task::{Context, Poll, Waker}; use std::usize; @@ -193,8 +195,8 @@ pub struct Receiver { /// Next position to read from next: u64, - /// Waiter state - wait: Arc, + /// Used to support the deprecated `poll_recv` fn + waiter: Option>>>, } /// Error returned by [`Sender::send`][Sender::send]. @@ -247,20 +249,14 @@ pub enum TryRecvError { /// Data shared between senders and receivers struct Shared { /// slots in the channel - buffer: Box<[Slot]>, + buffer: Box<[RwLock>]>, /// Mask a position -> index mask: usize, - /// Tail of the queue + /// Tail of the queue. Includes the rx wait list. tail: Mutex, - /// Notifies a sender that the slot is unlocked - condvar: Condvar, - - /// Stack of pending waiters - wait_stack: AtomicPtr, - /// Number of outstanding Sender handles num_tx: AtomicUsize, } @@ -272,6 +268,12 @@ struct Tail { /// Number of active receivers rx_cnt: usize, + + /// True if the channel is closed + closed: bool, + + /// Receivers waiting for a value + waiters: LinkedList, } /// Slot in the buffer @@ -279,47 +281,79 @@ struct Slot { /// Remaining number of receivers that are expected to see this value. /// /// When this goes to zero, the value is released. + /// + /// An atomic is used as it is mutated concurrently with the slot read lock + /// acquired. rem: AtomicUsize, - /// Used to lock the `write` field. - lock: AtomicUsize, + /// Uniquely identifies the `send` stored in the slot + pos: u64, + + /// True signals the channel is closed. + closed: bool, - /// The value being broadcast + /// The value being broadcast. /// - /// Synchronized by `state` - write: Write, + /// The value is set by `send` when the write lock is held. When a reader + /// drops, `rem` is decremented. When it hits zero, the value is dropped. + val: UnsafeCell>, } -/// A write in the buffer -struct Write { - /// Uniquely identifies this write - pos: UnsafeCell, +/// An entry in the wait queue +struct Waiter { + /// True if queued + queued: bool, - /// The written value - val: UnsafeCell>, + /// Task waiting on the broadcast channel. + waker: Option, + + /// Intrusive linked-list pointers. + pointers: linked_list::Pointers, + + /// Should not be `Unpin`. + _p: PhantomPinned, } -/// Tracks a waiting receiver -#[derive(Debug)] -struct WaitNode { - /// `true` if queued - queued: AtomicBool, +struct RecvGuard<'a, T> { + slot: RwLockReadGuard<'a, Slot>, +} - /// Task to wake when a permit is made available. - waker: AtomicWaker, +/// Receive a value future +struct Recv +where + R: AsMut>, +{ + /// Receiver being waited on + receiver: R, - /// Next pointer in the stack of waiting senders. - next: UnsafeCell<*const WaitNode>, + /// Entry in the waiter `LinkedList` + waiter: UnsafeCell, + + _p: std::marker::PhantomData, } -struct RecvGuard<'a, T> { - slot: &'a Slot, - tail: &'a Mutex, - condvar: &'a Condvar, +/// `AsMut` is not implemented for `T` (coherence). Explicitly implementing +/// `AsMut` for `Receiver` would be included in the public API of the receiver +/// type. Instead, `Borrow` is used internally to bridge the gap. +struct Borrow(T); + +impl AsMut> for Borrow> { + fn as_mut(&mut self) -> &mut Receiver { + &mut self.0 + } +} + +impl<'a, T> AsMut> for Borrow<&'a mut Receiver> { + fn as_mut(&mut self) -> &mut Receiver { + &mut *self.0 + } } +unsafe impl> + Send, T: Send> Send for Recv {} +unsafe impl> + Sync, T: Send> Sync for Recv {} + /// Max number of receivers. Reserve space to lock. -const MAX_RECEIVERS: usize = usize::MAX >> 1; +const MAX_RECEIVERS: usize = usize::MAX >> 2; /// Create a bounded, multi-producer, multi-consumer channel where each sent /// value is broadcasted to all active receivers. @@ -376,33 +410,30 @@ pub fn channel(mut capacity: usize) -> (Sender, Receiver) { let mut buffer = Vec::with_capacity(capacity); for i in 0..capacity { - buffer.push(Slot { + buffer.push(RwLock::new(Slot { rem: AtomicUsize::new(0), - lock: AtomicUsize::new(0), - write: Write { - pos: UnsafeCell::new((i as u64).wrapping_sub(capacity as u64)), - val: UnsafeCell::new(None), - }, - }); + pos: (i as u64).wrapping_sub(capacity as u64), + closed: false, + val: UnsafeCell::new(None), + })); } let shared = Arc::new(Shared { buffer: buffer.into_boxed_slice(), mask: capacity - 1, - tail: Mutex::new(Tail { pos: 0, rx_cnt: 1 }), - condvar: Condvar::new(), - wait_stack: AtomicPtr::new(ptr::null_mut()), + tail: Mutex::new(Tail { + pos: 0, + rx_cnt: 1, + closed: false, + waiters: LinkedList::new(), + }), num_tx: AtomicUsize::new(1), }); let rx = Receiver { shared: shared.clone(), next: 0, - wait: Arc::new(WaitNode { - queued: AtomicBool::new(false), - waker: AtomicWaker::new(), - next: UnsafeCell::new(ptr::null()), - }), + waiter: None, }; let tx = Sender { shared }; @@ -512,11 +543,7 @@ impl Sender { Receiver { shared, next, - wait: Arc::new(WaitNode { - queued: AtomicBool::new(false), - waker: AtomicWaker::new(), - next: UnsafeCell::new(ptr::null()), - }), + waiter: None, } } @@ -577,66 +604,47 @@ impl Sender { tail.pos = tail.pos.wrapping_add(1); // Get the slot - let slot = &self.shared.buffer[idx]; - - // Acquire the write lock - let mut prev = slot.lock.fetch_or(1, SeqCst); + let mut slot = self.shared.buffer[idx].write().unwrap(); - while prev & !1 != 0 { - // Concurrent readers, we must go to sleep - tail = self.shared.condvar.wait(tail).unwrap(); + // Track the position + slot.pos = pos; - prev = slot.lock.load(SeqCst); - - if prev & 1 == 0 { - // The writer lock bit was cleared while this thread was - // sleeping. This can only happen if a newer write happened on - // this slot by another thread. Bail early as an optimization, - // there is nothing left to do. - return Ok(rem); - } - } + // Set remaining receivers + slot.rem.with_mut(|v| *v = rem); - if tail.pos.wrapping_sub(pos) > self.shared.buffer.len() as u64 { - // There is a newer pending write to the same slot. - return Ok(rem); + // Set the closed bit if the value is `None`; otherwise write the value + if value.is_none() { + tail.closed = true; + slot.closed = true; + } else { + slot.val.with_mut(|ptr| unsafe { *ptr = value }); } - // Slot lock acquired - slot.write.pos.with_mut(|ptr| unsafe { *ptr = pos }); - slot.write.val.with_mut(|ptr| unsafe { *ptr = value }); + // Release the slot lock before notifying the receivers. + drop(slot); - // Set remaining receivers - slot.rem.store(rem, SeqCst); - - // Release the slot lock - slot.lock.store(0, SeqCst); + tail.notify_rx(); // Release the mutex. This must happen after the slot lock is released, // otherwise the writer lock bit could be cleared while another thread // is in the critical section. drop(tail); - // Notify waiting receivers - self.notify_rx(); - Ok(rem) } +} - fn notify_rx(&self) { - let mut curr = self.shared.wait_stack.swap(ptr::null_mut(), SeqCst) as *const WaitNode; - - while !curr.is_null() { - let waiter = unsafe { Arc::from_raw(curr) }; - - // Update `curr` before toggling `queued` and waking - curr = waiter.next.with(|ptr| unsafe { *ptr }); +impl Tail { + fn notify_rx(&mut self) { + while let Some(mut waiter) = self.waiters.pop_back() { + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; - // Unset queued - waiter.queued.store(false, SeqCst); + assert!(waiter.queued); + waiter.queued = false; - // Wake - waiter.waker.wake(); + let waker = waiter.waker.take().unwrap(); + waker.wake(); } } } @@ -660,48 +668,106 @@ impl Drop for Sender { impl Receiver { /// Locks the next value if there is one. - /// - /// The caller is responsible for unlocking - fn recv_ref(&mut self, spin: bool) -> Result, TryRecvError> { + fn recv_ref( + &mut self, + waiter: Option<(&UnsafeCell, &Waker)>, + ) -> Result, TryRecvError> { let idx = (self.next & self.shared.mask as u64) as usize; // The slot holding the next value to read - let slot = &self.shared.buffer[idx]; + let mut slot = self.shared.buffer[idx].read().unwrap(); - // Lock the slot - if !slot.try_rx_lock() { - if spin { - while !slot.try_rx_lock() { - spin_loop_hint(); - } - } else { + if slot.pos != self.next { + let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); + + // The receiver has read all current values in the channel and there + // is no waiter to register + if waiter.is_none() && next_pos == self.next { return Err(TryRecvError::Empty); } - } - let guard = RecvGuard { - slot, - tail: &self.shared.tail, - condvar: &self.shared.condvar, - }; - - if guard.pos() != self.next { - let pos = guard.pos(); + // Release the `slot` lock before attempting to acquire the `tail` + // lock. This is required because `send2` acquires the tail lock + // first followed by the slot lock. Acquiring the locks in reverse + // order here would result in a potential deadlock: `recv_ref` + // acquires the `slot` lock and attempts to acquire the `tail` lock + // while `send2` acquired the `tail` lock and attempts to acquire + // the slot lock. + drop(slot); + + let mut tail = self.shared.tail.lock().unwrap(); + + // Acquire slot lock again + slot = self.shared.buffer[idx].read().unwrap(); + + // Make sure the position did not change. This could happen in the + // unlikely event that the buffer is wrapped between dropping the + // read lock and acquiring the tail lock. + if slot.pos != self.next { + let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); + + if next_pos == self.next { + // Store the waker + if let Some((waiter, waker)) = waiter { + // Safety: called while locked. + unsafe { + // Only queue if not already queued + waiter.with_mut(|ptr| { + // If there is no waker **or** if the currently + // stored waker references a **different** task, + // track the tasks' waker to be notified on + // receipt of a new value. + match (*ptr).waker { + Some(ref w) if w.will_wake(waker) => {} + _ => { + (*ptr).waker = Some(waker.clone()); + } + } + + if !(*ptr).queued { + (*ptr).queued = true; + tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr)); + } + }); + } + } + + return Err(TryRecvError::Empty); + } - guard.drop_no_rem_dec(); + // At this point, the receiver has lagged behind the sender by + // more than the channel capacity. The receiver will attempt to + // catch up by skipping dropped messages and setting the + // internal cursor to the **oldest** message stored by the + // channel. + // + // However, finding the oldest position is a bit more + // complicated than `tail-position - buffer-size`. When + // the channel is closed, the tail position is incremented to + // signal a new `None` message, but `None` is not stored in the + // channel itself (see issue #2425 for why). + // + // To account for this, if the channel is closed, the tail + // position is decremented by `buffer-size + 1`. + let mut adjust = 0; + if tail.closed { + adjust = 1 + } + let next = tail + .pos + .wrapping_sub(self.shared.buffer.len() as u64 + adjust); - if pos.wrapping_add(self.shared.buffer.len() as u64) == self.next { - return Err(TryRecvError::Empty); - } else { - let tail = self.shared.tail.lock().unwrap(); - - // `tail.pos` points to the slot the **next** send writes to. - // Because a receiver is lagging, this slot also holds the - // oldest value. To make the positions match, we subtract the - // capacity. - let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64); let missed = next.wrapping_sub(self.next); + drop(tail); + + // The receiver is slow but no values have been missed + if missed == 0 { + self.next = self.next.wrapping_add(1); + + return Ok(RecvGuard { slot }); + } + self.next = next; return Err(TryRecvError::Lagged(missed)); @@ -710,7 +776,11 @@ impl Receiver { self.next = self.next.wrapping_add(1); - Ok(guard) + if slot.closed { + return Err(TryRecvError::Closed); + } + + Ok(RecvGuard { slot }) } } @@ -758,22 +828,59 @@ where /// } /// ``` pub fn try_recv(&mut self) -> Result { - let guard = self.recv_ref(false)?; + let guard = self.recv_ref(None)?; guard.clone_value().ok_or(TryRecvError::Closed) } - #[doc(hidden)] // TODO: document + #[doc(hidden)] + #[deprecated(since = "0.2.21", note = "use async fn recv()")] pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Some(value) = ok_empty(self.try_recv())? { - return Poll::Ready(Ok(value)); + use Poll::{Pending, Ready}; + + // The borrow checker prohibits calling `self.poll_ref` while passing in + // a mutable ref to a field (as it should). To work around this, + // `waiter` is first *removed* from `self` then `poll_recv` is called. + // + // However, for safety, we must ensure that `waiter` is **not** dropped. + // It could be contained in the intrusive linked list. The `Receiver` + // drop implementation handles cleanup. + // + // The guard pattern is used to ensure that, on return, even due to + // panic, the waiter node is replaced on `self`. + + struct Guard<'a, T> { + waiter: Option>>>, + receiver: &'a mut Receiver, } - self.register_waker(cx.waker()); + impl<'a, T> Drop for Guard<'a, T> { + fn drop(&mut self) { + self.receiver.waiter = self.waiter.take(); + } + } - if let Some(value) = ok_empty(self.try_recv())? { - Poll::Ready(Ok(value)) - } else { - Poll::Pending + let waiter = self.waiter.take().or_else(|| { + Some(Box::pin(UnsafeCell::new(Waiter { + queued: false, + waker: None, + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + }))) + }); + + let guard = Guard { + waiter, + receiver: self, + }; + let res = guard + .receiver + .recv_ref(Some((&guard.waiter.as_ref().unwrap(), cx.waker()))); + + match res { + Ok(guard) => Ready(guard.clone_value().ok_or(RecvError::Closed)), + Err(TryRecvError::Closed) => Ready(Err(RecvError::Closed)), + Err(TryRecvError::Lagged(n)) => Ready(Err(RecvError::Lagged(n))), + Err(TryRecvError::Empty) => Pending, } } @@ -842,44 +949,14 @@ where /// assert_eq!(30, rx.recv().await.unwrap()); /// } pub async fn recv(&mut self) -> Result { - use crate::future::poll_fn; - - poll_fn(|cx| self.poll_recv(cx)).await - } - - fn register_waker(&self, cx: &Waker) { - self.wait.waker.register_by_ref(cx); - - if !self.wait.queued.load(SeqCst) { - // Set `queued` before queuing. - self.wait.queued.store(true, SeqCst); - - let mut curr = self.shared.wait_stack.load(SeqCst); - - // The ref count is decremented in `notify_rx` when all nodes are - // removed from the waiter stack. - let node = Arc::into_raw(self.wait.clone()) as *mut _; - - loop { - // Safety: `queued == false` means the caller has exclusive - // access to `self.wait.next`. - self.wait.next.with_mut(|ptr| unsafe { *ptr = curr }); - - let res = self - .shared - .wait_stack - .compare_exchange(curr, node, SeqCst, SeqCst); - - match res { - Ok(_) => return, - Err(actual) => curr = actual, - } - } - } + let fut = Recv::<_, T>::new(Borrow(self)); + fut.await } } #[cfg(feature = "stream")] +#[doc(hidden)] +#[deprecated(since = "0.2.21", note = "use `into_stream()`")] impl crate::stream::Stream for Receiver where T: Clone, @@ -890,6 +967,7 @@ where mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { + #[allow(deprecated)] self.poll_recv(cx).map(|v| match v { Ok(v) => Some(Ok(v)), lag @ Err(RecvError::Lagged(_)) => Some(lag), @@ -902,14 +980,30 @@ impl Drop for Receiver { fn drop(&mut self) { let mut tail = self.shared.tail.lock().unwrap(); + if let Some(waiter) = &self.waiter { + // safety: tail lock is held + let queued = waiter.with(|ptr| unsafe { (*ptr).queued }); + + if queued { + // Remove the node + // + // safety: tail lock is held and the wait node is verified to be in + // the list. + unsafe { + waiter.with_mut(|ptr| { + tail.waiters.remove((&mut *ptr).into()); + }); + } + } + } + tail.rx_cnt -= 1; let until = tail.pos; drop(tail); while self.next != until { - match self.recv_ref(true) { - // Ignore the value + match self.recv_ref(None) { Ok(_) => {} // The channel is closed Err(TryRecvError::Closed) => break, @@ -922,103 +1016,198 @@ impl Drop for Receiver { } } -impl Drop for Shared { - fn drop(&mut self) { - // Clear the wait stack - let mut curr = self.wait_stack.with_mut(|ptr| *ptr as *const WaitNode); +impl Recv +where + R: AsMut>, +{ + fn new(receiver: R) -> Recv { + Recv { + receiver, + waiter: UnsafeCell::new(Waiter { + queued: false, + waker: None, + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + }), + _p: std::marker::PhantomData, + } + } - while !curr.is_null() { - let waiter = unsafe { Arc::from_raw(curr) }; - curr = waiter.next.with(|ptr| unsafe { *ptr }); + /// A custom `project` implementation is used in place of `pin-project-lite` + /// as a custom drop implementation is needed. + fn project(self: Pin<&mut Self>) -> (&mut Receiver, &UnsafeCell) { + unsafe { + // Safety: Receiver is Unpin + is_unpin::<&mut Receiver>(); + + let me = self.get_unchecked_mut(); + (me.receiver.as_mut(), &me.waiter) } } } -impl fmt::Debug for Sender { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "broadcast::Sender") +impl Future for Recv +where + R: AsMut>, + T: Clone, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let (receiver, waiter) = self.project(); + + let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { + Ok(value) => value, + Err(TryRecvError::Empty) => return Poll::Pending, + Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))), + Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)), + }; + + Poll::Ready(guard.clone_value().ok_or(RecvError::Closed)) } } -impl fmt::Debug for Receiver { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "broadcast::Receiver") +cfg_stream! { + use futures_core::Stream; + + impl Receiver { + /// Convert the receiver into a `Stream`. + /// + /// The conversion allows using `Receiver` with APIs that require stream + /// values. + /// + /// # Examples + /// + /// ``` + /// use tokio::stream::StreamExt; + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = broadcast::channel(128); + /// + /// tokio::spawn(async move { + /// for i in 0..10_i32 { + /// tx.send(i).unwrap(); + /// } + /// }); + /// + /// // Streams must be pinned to iterate. + /// tokio::pin! { + /// let stream = rx + /// .into_stream() + /// .filter(Result::is_ok) + /// .map(Result::unwrap) + /// .filter(|v| v % 2 == 0) + /// .map(|v| v + 1); + /// } + /// + /// while let Some(i) = stream.next().await { + /// println!("{}", i); + /// } + /// } + /// ``` + pub fn into_stream(self) -> impl Stream> { + Recv::new(Borrow(self)) + } } -} -impl Slot { - /// Tries to lock the slot for a receiver. If `false`, then a sender holds the - /// lock and the calling task will be notified once the sender has released - /// the lock. - fn try_rx_lock(&self) -> bool { - let mut curr = self.lock.load(SeqCst); - - loop { - if curr & 1 == 1 { - // Locked by sender - return false; - } + impl Stream for Recv + where + R: AsMut>, + T: Clone, + { + type Item = Result; - // Only increment (by 2) if the LSB "lock" bit is not set. - let res = self.lock.compare_exchange(curr, curr + 2, SeqCst, SeqCst); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let (receiver, waiter) = self.project(); - match res { - Ok(_) => return true, - Err(actual) => curr = actual, - } + let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { + Ok(value) => value, + Err(TryRecvError::Empty) => return Poll::Pending, + Err(TryRecvError::Lagged(n)) => return Poll::Ready(Some(Err(RecvError::Lagged(n)))), + Err(TryRecvError::Closed) => return Poll::Ready(None), + }; + + Poll::Ready(guard.clone_value().map(Ok)) } } +} - fn rx_unlock(&self, tail: &Mutex, condvar: &Condvar, rem_dec: bool) { - if rem_dec { - // Decrement the remaining counter - if 1 == self.rem.fetch_sub(1, SeqCst) { - // Last receiver, drop the value - self.write.val.with_mut(|ptr| unsafe { *ptr = None }); +impl Drop for Recv +where + R: AsMut>, +{ + fn drop(&mut self) { + // Acquire the tail lock. This is required for safety before accessing + // the waiter node. + let mut tail = self.receiver.as_mut().shared.tail.lock().unwrap(); + + // safety: tail lock is held + let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued }); + + if queued { + // Remove the node + // + // safety: tail lock is held and the wait node is verified to be in + // the list. + unsafe { + self.waiter.with_mut(|ptr| { + tail.waiters.remove((&mut *ptr).into()); + }); } } + } +} - if 1 == self.lock.fetch_sub(2, SeqCst) - 2 { - // First acquire the lock to make sure our sender is waiting on the - // condition variable, otherwise the notification could be lost. - mem::drop(tail.lock().unwrap()); - // Wake up senders - condvar.notify_all(); - } +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + type Handle = NonNull; + type Target = Waiter; + + fn as_raw(handle: &NonNull) -> NonNull { + *handle + } + + unsafe fn from_raw(ptr: NonNull) -> NonNull { + ptr + } + + unsafe fn pointers(mut target: NonNull) -> NonNull> { + NonNull::from(&mut target.as_mut().pointers) } } -impl<'a, T> RecvGuard<'a, T> { - fn pos(&self) -> u64 { - self.slot.write.pos.with(|ptr| unsafe { *ptr }) +impl fmt::Debug for Sender { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "broadcast::Sender") } +} +impl fmt::Debug for Receiver { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "broadcast::Receiver") + } +} + +impl<'a, T> RecvGuard<'a, T> { fn clone_value(&self) -> Option where T: Clone, { - self.slot.write.val.with(|ptr| unsafe { (*ptr).clone() }) - } - - fn drop_no_rem_dec(self) { - self.slot.rx_unlock(self.tail, self.condvar, false); - - mem::forget(self); + self.slot.val.with(|ptr| unsafe { (*ptr).clone() }) } } impl<'a, T> Drop for RecvGuard<'a, T> { fn drop(&mut self) { - self.slot.rx_unlock(self.tail, self.condvar, true) - } -} - -fn ok_empty(res: Result) -> Result, RecvError> { - match res { - Ok(value) => Ok(Some(value)), - Err(TryRecvError::Empty) => Ok(None), - Err(TryRecvError::Lagged(n)) => Err(RecvError::Lagged(n)), - Err(TryRecvError::Closed) => Err(RecvError::Closed), + // Decrement the remaining counter + if 1 == self.slot.rem.fetch_sub(1, SeqCst) { + // Safety: Last receiver, drop the value + self.slot.val.with_mut(|ptr| unsafe { *ptr = None }); + } } } @@ -1044,3 +1233,5 @@ impl fmt::Display for TryRecvError { } impl std::error::Error for TryRecvError {} + +fn is_unpin() {} diff --git a/tokio/src/sync/cancellation_token.rs b/tokio/src/sync/cancellation_token.rs new file mode 100644 index 00000000000..d60d8e0202c --- /dev/null +++ b/tokio/src/sync/cancellation_token.rs @@ -0,0 +1,861 @@ +//! An asynchronously awaitable `CancellationToken`. +//! The token allows to signal a cancellation request to one or more tasks. + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; +use crate::util::intrusive_double_linked_list::{LinkedList, ListNode}; + +use core::future::Future; +use core::pin::Pin; +use core::ptr::NonNull; +use core::sync::atomic::Ordering; +use core::task::{Context, Poll, Waker}; + +/// A token which can be used to signal a cancellation request to one or more +/// tasks. +/// +/// Tasks can call [`CancellationToken::cancelled()`] in order to +/// obtain a Future which will be resolved when cancellation is requested. +/// +/// Cancellation can be requested through the [`CancellationToken::cancel`] method. +/// +/// # Examples +/// +/// ```ignore +/// use tokio::select; +/// use tokio::scope::CancellationToken; +/// +/// #[tokio::main] +/// async fn main() { +/// let token = CancellationToken::new(); +/// let cloned_token = token.clone(); +/// +/// let join_handle = tokio::spawn(async move { +/// // Wait for either cancellation or a very long time +/// select! { +/// _ = cloned_token.cancelled() => { +/// // The token was cancelled +/// 5 +/// } +/// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => { +/// 99 +/// } +/// } +/// }); +/// +/// tokio::spawn(async move { +/// tokio::time::delay_for(std::time::Duration::from_millis(10)).await; +/// token.cancel(); +/// }); +/// +/// assert_eq!(5, join_handle.await.unwrap()); +/// } +/// ``` +pub struct CancellationToken { + inner: NonNull, +} + +// Safety: The CancellationToken is thread-safe and can be moved between threads, +// since all methods are internally synchronized. +unsafe impl Send for CancellationToken {} +unsafe impl Sync for CancellationToken {} + +/// A Future that is resolved once the corresponding [`CancellationToken`] +/// was cancelled +#[must_use = "futures do nothing unless polled"] +pub struct WaitForCancellationFuture<'a> { + /// The CancellationToken that is associated with this WaitForCancellationFuture + cancellation_token: Option<&'a CancellationToken>, + /// Node for waiting at the cancellation_token + wait_node: ListNode, + /// Whether this future was registered at the token yet as a waiter + is_registered: bool, +} + +// Safety: Futures can be sent between threads as long as the underlying +// cancellation_token is thread-safe (Sync), +// which allows to poll/register/unregister from a different thread. +unsafe impl<'a> Send for WaitForCancellationFuture<'a> {} + +// ===== impl CancellationToken ===== + +impl core::fmt::Debug for CancellationToken { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("CancellationToken") + .field("is_cancelled", &self.is_cancelled()) + .finish() + } +} + +impl Clone for CancellationToken { + fn clone(&self) -> Self { + // Safety: The state inside a `CancellationToken` is always valid, since + // is reference counted + let inner = self.state(); + + // Tokens are cloned by increasing their refcount + let current_state = inner.snapshot(); + inner.increment_refcount(current_state); + + CancellationToken { inner: self.inner } + } +} + +impl Drop for CancellationToken { + fn drop(&mut self) { + let token_state_pointer = self.inner; + + // Safety: The state inside a `CancellationToken` is always valid, since + // is reference counted + let inner = unsafe { &mut *self.inner.as_ptr() }; + + let mut current_state = inner.snapshot(); + + // We need to safe the parent, since the state might be released by the + // next call + let parent = inner.parent; + + // Drop our own refcount + current_state = inner.decrement_refcount(current_state); + + // If this was the last reference, unregister from the parent + if current_state.refcount == 0 { + if let Some(mut parent) = parent { + // Safety: Since we still retain a reference on the parent, it must be valid. + let parent = unsafe { parent.as_mut() }; + parent.unregister_child(token_state_pointer, current_state); + } + } + } +} + +impl CancellationToken { + /// Creates a new CancellationToken in the non-cancelled state. + pub fn new() -> CancellationToken { + let state = Box::new(CancellationTokenState::new( + None, + StateSnapshot { + cancel_state: CancellationState::NotCancelled, + has_parent_ref: false, + refcount: 1, + }, + )); + + // Safety: We just created the Box. The pointer is guaranteed to be + // not null + CancellationToken { + inner: unsafe { NonNull::new_unchecked(Box::into_raw(state)) }, + } + } + + /// Returns a reference to the utilized `CancellationTokenState`. + fn state(&self) -> &CancellationTokenState { + // Safety: The state inside a `CancellationToken` is always valid, since + // is reference counted + unsafe { &*self.inner.as_ptr() } + } + + /// Creates a `CancellationToken` which will get cancelled whenever the + /// current token gets cancelled. + /// + /// If the current token is already cancelled, the child token will get + /// returned in cancelled state. + /// + /// # Examples + /// + /// ```ignore + /// use tokio::select; + /// use tokio::scope::CancellationToken; + /// + /// #[tokio::main] + /// async fn main() { + /// let token = CancellationToken::new(); + /// let child_token = token.child_token(); + /// + /// let join_handle = tokio::spawn(async move { + /// // Wait for either cancellation or a very long time + /// select! { + /// _ = child_token.cancelled() => { + /// // The token was cancelled + /// 5 + /// } + /// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => { + /// 99 + /// } + /// } + /// }); + /// + /// tokio::spawn(async move { + /// tokio::time::delay_for(std::time::Duration::from_millis(10)).await; + /// token.cancel(); + /// }); + /// + /// assert_eq!(5, join_handle.await.unwrap()); + /// } + /// ``` + pub fn child_token(&self) -> CancellationToken { + let inner = self.state(); + + // Increment the refcount of this token. It will be referenced by the + // child, independent of whether the child is immediately cancelled or + // not. + let _current_state = inner.increment_refcount(inner.snapshot()); + + let mut unpacked_child_state = StateSnapshot { + has_parent_ref: true, + refcount: 1, + cancel_state: CancellationState::NotCancelled, + }; + let mut child_token_state = Box::new(CancellationTokenState::new( + Some(self.inner), + unpacked_child_state, + )); + + { + let mut guard = inner.synchronized.lock().unwrap(); + if guard.is_cancelled { + // This task was already cancelled. In this case we should not + // insert the child into the list, since it would never get removed + // from the list. + (*child_token_state.synchronized.lock().unwrap()).is_cancelled = true; + unpacked_child_state.cancel_state = CancellationState::Cancelled; + // Since it's not in the list, the parent doesn't need to retain + // a reference to it. + unpacked_child_state.has_parent_ref = false; + child_token_state + .state + .store(unpacked_child_state.pack(), Ordering::SeqCst); + } else { + if let Some(mut first_child) = guard.first_child { + child_token_state.from_parent.next_peer = Some(first_child); + // Safety: We manipulate other child task inside the Mutex + // and retain a parent reference on it. The child token can't + // get invalidated while the Mutex is held. + unsafe { + first_child.as_mut().from_parent.prev_peer = + Some((&mut *child_token_state).into()) + }; + } + guard.first_child = Some((&mut *child_token_state).into()); + } + }; + + let child_token_ptr = Box::into_raw(child_token_state); + // Safety: We just created the pointer from a `Box` + CancellationToken { + inner: unsafe { NonNull::new_unchecked(child_token_ptr) }, + } + } + + /// Cancel the [`CancellationToken`] and all child tokens which had been + /// derived from it. + /// + /// This will wake up all tasks which are waiting for cancellation. + pub fn cancel(&self) { + self.state().cancel(); + } + + /// Returns `true` if the `CancellationToken` had been cancelled + pub fn is_cancelled(&self) -> bool { + self.state().is_cancelled() + } + + /// Returns a `Future` that gets fulfilled when cancellation is requested. + pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { + WaitForCancellationFuture { + cancellation_token: Some(self), + wait_node: ListNode::new(WaitQueueEntry::new()), + is_registered: false, + } + } + + unsafe fn register( + &self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + self.state().register(wait_node, cx) + } + + fn check_for_cancellation( + &self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + self.state().check_for_cancellation(wait_node, cx) + } + + fn unregister(&self, wait_node: &mut ListNode) { + self.state().unregister(wait_node) + } +} + +// ===== impl WaitForCancellationFuture ===== + +impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitForCancellationFuture").finish() + } +} + +impl<'a> Future for WaitForCancellationFuture<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + // Safety: We do not move anything out of `WaitForCancellationFuture` + let mut_self: &mut WaitForCancellationFuture<'_> = unsafe { Pin::get_unchecked_mut(self) }; + + let cancellation_token = mut_self + .cancellation_token + .expect("polled WaitForCancellationFuture after completion"); + + let poll_res = if !mut_self.is_registered { + // Safety: The `ListNode` is pinned through the Future, + // and we will unregister it in `WaitForCancellationFuture::drop` + // before the Future is dropped and the memory reference is invalidated. + unsafe { cancellation_token.register(&mut mut_self.wait_node, cx) } + } else { + cancellation_token.check_for_cancellation(&mut mut_self.wait_node, cx) + }; + + if let Poll::Ready(()) = poll_res { + // The cancellation_token was signalled + mut_self.cancellation_token = None; + // A signalled Token means the Waker won't be enqueued anymore + mut_self.is_registered = false; + mut_self.wait_node.task = None; + } else { + // This `Future` and its stored `Waker` stay registered at the + // `CancellationToken` + mut_self.is_registered = true; + } + + poll_res + } +} + +impl<'a> Drop for WaitForCancellationFuture<'a> { + fn drop(&mut self) { + // If this WaitForCancellationFuture has been polled and it was added to the + // wait queue at the cancellation_token, it must be removed before dropping. + // Otherwise the cancellation_token would access invalid memory. + if let Some(token) = self.cancellation_token { + if self.is_registered { + token.unregister(&mut self.wait_node); + } + } + } +} + +/// Tracks how the future had interacted with the [`CancellationToken`] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum PollState { + /// The task has never interacted with the [`CancellationToken`]. + New, + /// The task was added to the wait queue at the [`CancellationToken`]. + Waiting, + /// The task has been polled to completion. + Done, +} + +/// Tracks the WaitForCancellationFuture waiting state. +/// Access to this struct is synchronized through the mutex in the CancellationToken. +struct WaitQueueEntry { + /// The task handle of the waiting task + task: Option, + // Current polling state. This state is only updated inside the Mutex of + // the CancellationToken. + state: PollState, +} + +impl WaitQueueEntry { + /// Creates a new WaitQueueEntry + fn new() -> WaitQueueEntry { + WaitQueueEntry { + task: None, + state: PollState::New, + } + } +} + +struct SynchronizedState { + waiters: LinkedList, + first_child: Option>, + is_cancelled: bool, +} + +impl SynchronizedState { + fn new() -> Self { + Self { + waiters: LinkedList::new(), + first_child: None, + is_cancelled: false, + } + } +} + +/// Information embedded in child tokens which is synchronized through the Mutex +/// in their parent. +struct SynchronizedThroughParent { + next_peer: Option>, + prev_peer: Option>, +} + +/// Possible states of a `CancellationToken` +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum CancellationState { + NotCancelled = 0, + Cancelling = 1, + Cancelled = 2, +} + +impl CancellationState { + fn pack(self) -> usize { + self as usize + } + + fn unpack(value: usize) -> Self { + match value { + 0 => CancellationState::NotCancelled, + 1 => CancellationState::Cancelling, + 2 => CancellationState::Cancelled, + _ => unreachable!("Invalid value"), + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +struct StateSnapshot { + /// The amount of references to this particular CancellationToken. + /// `CancellationToken` structs hold these references to a `CancellationTokenState`. + /// Also the state is referenced by the state of each child. + refcount: usize, + /// Whether the state is still referenced by it's parent and can therefore + /// not be freed. + has_parent_ref: bool, + /// Whether the token is cancelled + cancel_state: CancellationState, +} + +impl StateSnapshot { + /// Packs the snapshot into a `usize` + fn pack(self) -> usize { + self.refcount << 3 | if self.has_parent_ref { 4 } else { 0 } | self.cancel_state.pack() + } + + /// Unpacks the snapshot from a `usize` + fn unpack(value: usize) -> Self { + let refcount = value >> 3; + let has_parent_ref = value & 4 != 0; + let cancel_state = CancellationState::unpack(value & 0x03); + + StateSnapshot { + refcount, + has_parent_ref, + cancel_state, + } + } + + /// Whether this `CancellationTokenState` is still referenced by any + /// `CancellationToken`. + fn has_refs(&self) -> bool { + self.refcount != 0 || self.has_parent_ref + } +} + +/// The maximum permitted amount of references to a CancellationToken. This +/// is derived from the intent to never use more than 32bit in the `Snapshot`. +const MAX_REFS: u32 = (std::u32::MAX - 7) >> 3; + +/// Internal state of the `CancellationToken` pair above +struct CancellationTokenState { + state: AtomicUsize, + parent: Option>, + from_parent: SynchronizedThroughParent, + synchronized: Mutex, +} + +impl CancellationTokenState { + fn new( + parent: Option>, + state: StateSnapshot, + ) -> CancellationTokenState { + CancellationTokenState { + parent, + from_parent: SynchronizedThroughParent { + prev_peer: None, + next_peer: None, + }, + state: AtomicUsize::new(state.pack()), + synchronized: Mutex::new(SynchronizedState::new()), + } + } + + /// Returns a snapshot of the current atomic state of the token + fn snapshot(&self) -> StateSnapshot { + StateSnapshot::unpack(self.state.load(Ordering::SeqCst)) + } + + fn atomic_update_state(&self, mut current_state: StateSnapshot, func: F) -> StateSnapshot + where + F: Fn(StateSnapshot) -> StateSnapshot, + { + let mut current_packed_state = current_state.pack(); + loop { + let next_state = func(current_state); + match self.state.compare_exchange( + current_packed_state, + next_state.pack(), + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => { + return next_state; + } + Err(actual) => { + current_packed_state = actual; + current_state = StateSnapshot::unpack(actual); + } + } + } + } + + fn increment_refcount(&self, current_state: StateSnapshot) -> StateSnapshot { + self.atomic_update_state(current_state, |mut state: StateSnapshot| { + if state.refcount >= MAX_REFS as usize { + eprintln!("[ERROR] Maximum reference count for CancellationToken was exceeded"); + std::process::abort(); + } + state.refcount += 1; + state + }) + } + + fn decrement_refcount(&self, current_state: StateSnapshot) -> StateSnapshot { + let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| { + state.refcount -= 1; + state + }); + + // Drop the State if it is not referenced anymore + if !current_state.has_refs() { + // Safety: `CancellationTokenState` is always stored in refcounted + // Boxes + let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) }; + } + + current_state + } + + fn remove_parent_ref(&self, current_state: StateSnapshot) -> StateSnapshot { + let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| { + state.has_parent_ref = false; + state + }); + + // Drop the State if it is not referenced anymore + if !current_state.has_refs() { + // Safety: `CancellationTokenState` is always stored in refcounted + // Boxes + let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) }; + } + + current_state + } + + /// Unregisters a child from the parent token. + /// The child tokens state is not exactly known at this point in time. + /// If the parent token is cancelled, the child token gets removed from the + /// parents list, and might therefore already have been freed. If the parent + /// token is not cancelled, the child token is still valid. + fn unregister_child( + &mut self, + mut child_state: NonNull, + current_child_state: StateSnapshot, + ) { + let removed_child = { + // Remove the child toke from the parents linked list + let mut guard = self.synchronized.lock().unwrap(); + if !guard.is_cancelled { + // Safety: Since the token was not cancelled, the child must + // still be in the list and valid. + let mut child_state = unsafe { child_state.as_mut() }; + debug_assert!(child_state.snapshot().has_parent_ref); + + if guard.first_child == Some(child_state.into()) { + guard.first_child = child_state.from_parent.next_peer; + } + // Safety: If peers wouldn't be valid anymore, they would try + // to remove themselves from the list. This would require locking + // the Mutex that we currently own. + unsafe { + if let Some(mut prev_peer) = child_state.from_parent.prev_peer { + prev_peer.as_mut().from_parent.next_peer = + child_state.from_parent.next_peer; + } + if let Some(mut next_peer) = child_state.from_parent.next_peer { + next_peer.as_mut().from_parent.prev_peer = + child_state.from_parent.prev_peer; + } + } + child_state.from_parent.prev_peer = None; + child_state.from_parent.next_peer = None; + + // The child is no longer referenced by the parent, since we were able + // to remove its reference from the parents list. + true + } else { + // Do not touch the linked list anymore. If the parent is cancelled + // it will move all childs outside of the Mutex and manipulate + // the pointers there. Manipulating the pointers here too could + // lead to races. Therefore leave them just as as and let the + // parent deal with it. The parent will make sure to retain a + // reference to this state as long as it manipulates the list + // pointers. Therefore the pointers are not dangling. + false + } + }; + + if removed_child { + // If the token removed itself from the parents list, it can reset + // the the parent ref status. If it is isn't able to do so, because the + // parent removed it from the list, there is no need to do this. + // The parent ref acts as as another reference count. Therefore + // removing this reference can free the object. + // Safety: The token was in the list. This means the parent wasn't + // cancelled before, and the token must still be alive. + unsafe { child_state.as_mut().remove_parent_ref(current_child_state) }; + } + + // Decrement the refcount on the parent and free it if necessary + self.decrement_refcount(self.snapshot()); + } + + fn cancel(&self) { + // Move the state of the CancellationToken from `NotCancelled` to `Cancelling` + let mut current_state = self.snapshot(); + + let state_after_cancellation = loop { + if current_state.cancel_state != CancellationState::NotCancelled { + // Another task already initiated the cancellation + return; + } + + let mut next_state = current_state; + next_state.cancel_state = CancellationState::Cancelling; + match self.state.compare_exchange( + current_state.pack(), + next_state.pack(), + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => break next_state, + Err(actual) => current_state = StateSnapshot::unpack(actual), + } + }; + + // This task cancelled the token + + // Take the task list out of the Token + // We do not want to cancel child token inside this lock. If one of the + // child tasks would have additional child tokens, we would recursively + // take locks. + + // Doing this action has an impact if the child token is dropped concurrently: + // It will try to deregister itself from the parent task, but can not find + // itself in the task list anymore. Therefore it needs to assume the parent + // has extracted the list and will process it. It may not modify the list. + // This is OK from a memory safety perspective, since the parent still + // retains a reference to the child task until it finished iterating over + // it. + + let mut first_child = { + let mut guard = self.synchronized.lock().unwrap(); + // Save the cancellation also inside the Mutex + // This allows child tokens which want to detach themselves to detect + // that this is no longer required since the parent cleared the list. + guard.is_cancelled = true; + + // Wakeup all waiters + // This happens inside the lock to make cancellation reliable + // If we would access waiters outside of the lock, the pointers + // may no longer be valid. + // Typically this shouldn't be an issue, since waking a task should + // only move it from the blocked into the ready state and not have + // further side effects. + + // Use a reverse iterator, so that the oldest waiter gets + // scheduled first + guard.waiters.reverse_drain(|waiter| { + // We are not allowed to move the `Waker` out of the list node. + // The `Future` relies on the fact that the old `Waker` stays there + // as long as the `Future` has not completed in order to perform + // the `will_wake()` check. + // Therefore `wake_by_ref` is used instead of `wake()` + if let Some(handle) = &mut waiter.task { + handle.wake_by_ref(); + } + // Mark the waiter to have been removed from the list. + waiter.state = PollState::Done; + }); + + guard.first_child.take() + }; + + while let Some(mut child) = first_child { + // Safety: We know this is a valid pointer since it is in our child pointer + // list. It can't have been freed in between, since we retain a a reference + // to each child. + let mut_child = unsafe { child.as_mut() }; + + // Get the next child and clean up list pointers + first_child = mut_child.from_parent.next_peer; + mut_child.from_parent.prev_peer = None; + mut_child.from_parent.next_peer = None; + + // Cancel the child task + mut_child.cancel(); + + // Drop the parent reference. This `CancellationToken` is not interested + // in interacting with the child anymore. + // This is ONLY allowed once we promised not to touch the state anymore + // after this interaction. + mut_child.remove_parent_ref(mut_child.snapshot()); + } + + // The cancellation has completed + // At this point in time tasks which registered a wait node can be sure + // that this wait node already had been dequeued from the list without + // needing to inspect the list. + self.atomic_update_state(state_after_cancellation, |mut state| { + state.cancel_state = CancellationState::Cancelled; + state + }); + } + + /// Returns `true` if the `CancellationToken` had been cancelled + fn is_cancelled(&self) -> bool { + let current_state = self.snapshot(); + current_state.cancel_state != CancellationState::NotCancelled + } + + /// Registers a waiting task at the `CancellationToken`. + /// Safety: This method is only safe as long as the waiting waiting task + /// will properly unregister the wait node before it gets moved. + unsafe fn register( + &self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + debug_assert_eq!(PollState::New, wait_node.state); + let current_state = self.snapshot(); + + // Perform an optimistic cancellation check before. This is not strictly + // necessary since we also check for cancellation in the Mutex, but + // reduces the necessary work to be performed for tasks which already + // had been cancelled. + if current_state.cancel_state != CancellationState::NotCancelled { + return Poll::Ready(()); + } + + // So far the token is not cancelled. However it could be cancelld before + // we get the chance to store the `Waker`. Therfore we need to check + // for cancellation again inside the mutex. + let mut guard = self.synchronized.lock().unwrap(); + if guard.is_cancelled { + // Cancellation was signalled + wait_node.state = PollState::Done; + Poll::Ready(()) + } else { + // Added the task to the wait queue + wait_node.task = Some(cx.waker().clone()); + wait_node.state = PollState::Waiting; + guard.waiters.add_front(wait_node); + Poll::Pending + } + } + + fn check_for_cancellation( + &self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + debug_assert!( + wait_node.task.is_some(), + "Method can only be called after task had been registered" + ); + + let current_state = self.snapshot(); + + if current_state.cancel_state != CancellationState::NotCancelled { + // If the cancellation had been fully completed we know that our `Waker` + // is no longer registered at the `CancellationToken`. + // Otherwise the cancel call may or may not yet have iterated + // through the waiters list and removed the wait nodes. + // If it hasn't yet, we need to remove it. Otherwise an attempt to + // reuse the `wait_node´ might get freed due to the `WaitForCancellationFuture` + // getting dropped before the cancellation had interacted with it. + if current_state.cancel_state != CancellationState::Cancelled { + self.unregister(wait_node); + } + Poll::Ready(()) + } else { + // Check if we need to swap the `Waker`. This will make the check more + // expensive, since the `Waker` is synchronized through the Mutex. + // If we don't need to perform a `Waker` update, an atomic check for + // cancellation is sufficient. + let need_waker_update = wait_node + .task + .as_ref() + .map(|waker| waker.will_wake(cx.waker())) + .unwrap_or(true); + + if need_waker_update { + let guard = self.synchronized.lock().unwrap(); + if guard.is_cancelled { + // Cancellation was signalled. Since this cancellation signal + // is set inside the Mutex, the old waiter must already have + // been removed from the waiting list + debug_assert_eq!(PollState::Done, wait_node.state); + wait_node.task = None; + Poll::Ready(()) + } else { + // The WaitForCancellationFuture is already in the queue. + // The CancellationToken can't have been cancelled, + // since this would change the is_cancelled flag inside the mutex. + // Therefore we just have to update the Waker. A follow-up + // cancellation will always use the new waker. + wait_node.task = Some(cx.waker().clone()); + Poll::Pending + } + } else { + // Do nothing. If the token gets cancelled, this task will get + // woken again and can fetch the cancellation. + Poll::Pending + } + } + } + + fn unregister(&self, wait_node: &mut ListNode) { + debug_assert!( + wait_node.task.is_some(), + "waiter can not be active without task" + ); + + let mut guard = self.synchronized.lock().unwrap(); + // WaitForCancellationFuture only needs to get removed if it has been added to + // the wait queue of the CancellationToken. + // This has happened in the PollState::Waiting case. + if let PollState::Waiting = wait_node.state { + // Safety: Due to the state, we know that the node must be part + // of the waiter list + if !unsafe { guard.waiters.remove(wait_node) } { + // Panic if the address isn't found. This can only happen if the contract was + // violated, e.g. the WaitQueueEntry got moved after the initial poll. + panic!("Future could not be removed from wait queue"); + } + wait_node.state = PollState::Done; + } + wait_node.task = None; + } +} diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 0607f78ad42..3d96106d2df 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -32,7 +32,7 @@ //! single producer to a single consumer. This channel is usually used to send //! the result of a computation to a waiter. //! -//! **Example:** using a `oneshot` channel to receive the result of a +//! **Example:** using a [`oneshot` channel][oneshot] to receive the result of a //! computation. //! //! ``` @@ -58,11 +58,12 @@ //! } //! ``` //! -//! Note, if the task produces the the computation result as its final action -//! before terminating, the [`JoinHandle`] can be used to receive the -//! computation result instead of allocating resources for the `oneshot` -//! channel. Awaiting on [`JoinHandle`] returns `Result`. If the task panics, -//! the `Joinhandle` yields `Err` with the panic cause. +//! Note, if the task produces a computation result as its final +//! action before terminating, the [`JoinHandle`] can be used to +//! receive that value instead of allocating resources for the +//! `oneshot` channel. Awaiting on [`JoinHandle`] returns `Result`. If +//! the task panics, the `Joinhandle` yields `Err` with the panic +//! cause. //! //! **Example:** //! @@ -84,6 +85,7 @@ //! } //! ``` //! +//! [oneshot]: oneshot //! [`JoinHandle`]: crate::task::JoinHandle //! //! ## `mpsc` channel @@ -126,7 +128,7 @@ //! pressure. //! //! A common concurrency pattern for resource management is to spawn a task -//! dedicated to managing that resource and using message passing betwen other +//! dedicated to managing that resource and using message passing between other //! tasks to interact with the resource. The resource may be anything that may //! not be concurrently used. Some examples include a socket and program state. //! For example, if multiple tasks need to send data over a single socket, spawn @@ -230,9 +232,11 @@ //! } //! ``` //! +//! [mpsc]: mpsc +//! //! ## `broadcast` channel //! -//! The [`broadcast` channel][broadcast] supports sending **many** values from +//! The [`broadcast` channel] supports sending **many** values from //! **many** producers to **many** consumers. Each consumer will receive //! **each** value. This channel can be used to implement "fan out" style //! patterns common with pub / sub or "chat" systems. @@ -265,12 +269,14 @@ //! } //! ``` //! +//! [`broadcast` channel]: crate::sync::broadcast +//! //! ## `watch` channel //! -//! The [`watch` channel][watch] supports sending **many** values from a -//! **single** producer to **many** consumers. However, only the **most recent** -//! value is stored in the channel. Consumers are notified when a new value is -//! sent, but there is no guarantee that consumers will see **all** values. +//! The [`watch` channel] supports sending **many** values from a **single** +//! producer to **many** consumers. However, only the **most recent** value is +//! stored in the channel. Consumers are notified when a new value is sent, but +//! there is no guarantee that consumers will see **all** values. //! //! The [`watch` channel] is similar to a [`broadcast` channel] with capacity 1. //! @@ -278,9 +284,9 @@ //! changes or signalling program state changes, such as transitioning to //! shutdown. //! -//! **Example:** use a `watch` channel to notify tasks of configuration changes. -//! In this example, a configuration file is checked periodically. When the file -//! changes, the configuration changes are signalled to consumers. +//! **Example:** use a [`watch` channel] to notify tasks of configuration +//! changes. In this example, a configuration file is checked periodically. When +//! the file changes, the configuration changes are signalled to consumers. //! //! ``` //! use tokio::sync::watch; @@ -393,6 +399,9 @@ //! } //! ``` //! +//! [`watch` channel]: crate::sync::watch +//! [`broadcast` channel]: crate::sync::broadcast +//! //! # State synchronization //! //! The remaining synchronization primitives focus on synchronizing state. @@ -400,23 +409,23 @@ //! operate in a similar way as their `std` counterparts parts but will wait //! asynchronously instead of blocking the thread. //! -//! * [`Barrier`][Barrier] Ensures multiple tasks will wait for each other to +//! * [`Barrier`](Barrier) Ensures multiple tasks will wait for each other to //! reach a point in the program, before continuing execution all together. //! -//! * [`Mutex`][Mutex] Mutual Exclusion mechanism, which ensures that at most +//! * [`Mutex`](Mutex) Mutual Exclusion mechanism, which ensures that at most //! one thread at a time is able to access some data. //! -//! * [`Notify`][Notify] Basic task notification. `Notify` supports notifying a +//! * [`Notify`](Notify) Basic task notification. `Notify` supports notifying a //! receiving task without sending data. In this case, the task wakes up and //! resumes processing. //! -//! * [`RwLock`][RwLock] Provides a mutual exclusion mechanism which allows +//! * [`RwLock`](RwLock) Provides a mutual exclusion mechanism which allows //! multiple readers at the same time, while allowing only one writer at a //! time. In some cases, this can be more efficient than a mutex. //! -//! * [`Semaphore`][Semaphore] Limits the amount of concurrency. A semaphore +//! * [`Semaphore`](Semaphore) Limits the amount of concurrency. A semaphore //! holds a number of permits, which tasks may request in order to enter a -//! critical section. Semaphores are useful for implementing limiting of +//! critical section. Semaphores are useful for implementing limiting or //! bounding of any kind. cfg_sync! { @@ -425,10 +434,15 @@ cfg_sync! { pub mod broadcast; + cfg_unstable! { + mod cancellation_token; + pub use cancellation_token::{CancellationToken, WaitForCancellationFuture}; + } + pub mod mpsc; mod mutex; - pub use mutex::{Mutex, MutexGuard}; + pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard}; mod notify; pub use notify::Notify; @@ -438,7 +452,7 @@ cfg_sync! { pub(crate) mod batch_semaphore; pub(crate) mod semaphore_ll; mod semaphore; - pub use semaphore::{Semaphore, SemaphorePermit}; + pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit}; mod rwlock; pub use rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 34663957883..148ee3ad766 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -277,7 +277,7 @@ where use super::block::Read::*; // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); self.inner.rx_fields.with_mut(|rx_fields_ptr| { let rx_fields = unsafe { &mut *rx_fields_ptr }; @@ -287,6 +287,7 @@ where match rx_fields.list.pop(&self.inner.tx) { Some(Value(value)) => { self.inner.semaphore.add_permit(); + coop.made_progress(); return Ready(Some(value)); } Some(Closed) => { @@ -297,6 +298,7 @@ where // which ensures that if dropping the tx handle is // visible, then all messages sent are also visible. assert!(self.inner.semaphore.is_idle()); + coop.made_progress(); return Ready(None); } None => {} // fall through @@ -314,6 +316,7 @@ where try_recv!(); if rx_fields.rx_closed && self.inner.semaphore.is_idle() { + coop.made_progress(); Ready(None) } else { Pending @@ -439,11 +442,15 @@ impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) { permit: &mut Permit, ) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); permit .poll_acquire(cx, 1, &self.0) .map_err(|_| ClosedError::new()) + .map(move |r| { + coop.made_progress(); + r + }) } fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> { diff --git a/tokio/src/sync/mpsc/error.rs b/tokio/src/sync/mpsc/error.rs index 9919356314f..72c42aa53e7 100644 --- a/tokio/src/sync/mpsc/error.rs +++ b/tokio/src/sync/mpsc/error.rs @@ -96,7 +96,7 @@ impl Error for TryRecvError {} // ===== ClosedError ===== -/// Error returned by [`Sender::poll_ready`](super::Sender::poll_ready)]. +/// Error returned by [`Sender::poll_ready`](super::Sender::poll_ready). #[derive(Debug)] pub struct ClosedError(()); diff --git a/tokio/src/sync/mpsc/mod.rs b/tokio/src/sync/mpsc/mod.rs index 4cfd6150f30..c489c9f99ff 100644 --- a/tokio/src/sync/mpsc/mod.rs +++ b/tokio/src/sync/mpsc/mod.rs @@ -6,13 +6,17 @@ //! Similar to `std`, channel creation provides [`Receiver`] and [`Sender`] //! handles. [`Receiver`] implements `Stream` and allows a task to read values //! out of the channel. If there is no message to read, the current task will be -//! notified when a new value is sent. [`Sender`] implements the `Sink` trait -//! and allows sending messages into the channel. If the channel is at capacity, -//! the send is rejected and the task will be notified when additional capacity -//! is available. In other words, the channel provides backpressure. +//! notified when a new value is sent. If the channel is at capacity, the send +//! is rejected and the task will be notified when additional capacity is +//! available. In other words, the channel provides backpressure. //! -//! Unbounded channels are also available using the `unbounded_channel` -//! constructor. +//! This module provides two variants of the channel: bounded and unbounded. The +//! bounded variant has a limit on the number of messages that the channel can +//! store, and if this limit is reached, trying to send another message will +//! wait until a message is received from the channel. An unbounded channel has +//! an infinite capacity, so the `send` method never does any kind of sleeping. +//! This makes the [`UnboundedSender`] usable from both synchronous and +//! asynchronous code. //! //! # Disconnection //! @@ -33,8 +37,32 @@ //! consumes the channel to completion, at which point the receiver can be //! dropped. //! +//! # Communicating between sync and async code +//! +//! When you want to communicate between synchronous and asynchronous code, there +//! are two situations to consider: +//! +//! **Bounded channel**: If you need a bounded channel, you should use a bounded +//! Tokio `mpsc` channel for both directions of communication. To call the async +//! [`send`][bounded-send] or [`recv`][bounded-recv] methods in sync code, you +//! will need to use [`Handle::block_on`], which allow you to execute an async +//! method in synchronous code. This is necessary because a bounded channel may +//! need to wait for additional capacity to become available. +//! +//! **Unbounded channel**: You should use the kind of channel that matches where +//! the receiver is. So for sending a message _from async to sync_, you should +//! use [the standard library unbounded channel][std-unbounded] or +//! [crossbeam][crossbeam-unbounded]. Similarly, for sending a message _from sync +//! to async_, you should use an unbounded Tokio `mpsc` channel. +//! //! [`Sender`]: crate::sync::mpsc::Sender //! [`Receiver`]: crate::sync::mpsc::Receiver +//! [bounded-send]: crate::sync::mpsc::Sender::send() +//! [bounded-recv]: crate::sync::mpsc::Receiver::recv() +//! [`UnboundedSender`]: crate::sync::mpsc::UnboundedSender +//! [`Handle::block_on`]: crate::runtime::Handle::block_on() +//! [std-unbounded]: std::sync::mpsc::channel +//! [crossbeam-unbounded]: https://docs.rs/crossbeam/*/crossbeam/channel/fn.unbounded.html pub(super) mod block; diff --git a/tokio/src/sync/mpsc/unbounded.rs b/tokio/src/sync/mpsc/unbounded.rs index ba543fe4c87..1b2288ab08c 100644 --- a/tokio/src/sync/mpsc/unbounded.rs +++ b/tokio/src/sync/mpsc/unbounded.rs @@ -163,9 +163,13 @@ impl UnboundedSender { /// Attempts to send a message on this `UnboundedSender` without blocking. /// + /// This method is not marked async because sending a message to an unbounded channel + /// never requires any form of waiting. Because of this, the `send` method can be + /// used in both synchronous and asynchronous code without problems. + /// /// If the receive half of the channel is closed, either due to [`close`] - /// being called or the [`UnboundedReceiver`] having been dropped, - /// the function returns an error. The error includes the value passed to `send`. + /// being called or the [`UnboundedReceiver`] having been dropped, this + /// function returns an error. The error includes the value passed to `send`. /// /// [`close`]: UnboundedReceiver::close /// [`UnboundedReceiver`]: UnboundedReceiver diff --git a/tokio/src/sync/mutex.rs b/tokio/src/sync/mutex.rs index 6b51b405681..642058be626 100644 --- a/tokio/src/sync/mutex.rs +++ b/tokio/src/sync/mutex.rs @@ -1,120 +1,164 @@ -//! An asynchronous `Mutex`-like type. -//! -//! This module provides [`Mutex`], a type that acts similarly to an asynchronous `Mutex`, with one -//! major difference: the [`MutexGuard`] returned by `lock` is not tied to the lifetime of the -//! `Mutex`. This enables you to acquire a lock, and then pass that guard into a future, and then -//! release it at some later point in time. -//! -//! This allows you to do something along the lines of: -//! -//! ```rust,no_run -//! use tokio::sync::Mutex; -//! use std::sync::Arc; -//! -//! #[tokio::main] -//! async fn main() { -//! let data1 = Arc::new(Mutex::new(0)); -//! let data2 = Arc::clone(&data1); -//! -//! tokio::spawn(async move { -//! let mut lock = data2.lock().await; -//! *lock += 1; -//! }); -//! -//! let mut lock = data1.lock().await; -//! *lock += 1; -//! } -//! ``` -//! -//! Another example -//! ```rust,no_run -//! #![warn(rust_2018_idioms)] -//! -//! use tokio::sync::Mutex; -//! use std::sync::Arc; -//! -//! -//! #[tokio::main] -//! async fn main() { -//! let count = Arc::new(Mutex::new(0)); -//! -//! for _ in 0..5 { -//! let my_count = Arc::clone(&count); -//! tokio::spawn(async move { -//! for _ in 0..10 { -//! let mut lock = my_count.lock().await; -//! *lock += 1; -//! println!("{}", lock); -//! } -//! }); -//! } -//! -//! loop { -//! if *count.lock().await >= 50 { -//! break; -//! } -//! } -//! println!("Count hit 50."); -//! } -//! ``` -//! There are a few things of note here to pay attention to in this example. -//! 1. The mutex is wrapped in an [`std::sync::Arc`] to allow it to be shared across threads. -//! 2. Each spawned task obtains a lock and releases it on every iteration. -//! 3. Mutation of the data the Mutex is protecting is done by de-referencing the the obtained lock -//! as seen on lines 23 and 30. -//! -//! Tokio's Mutex works in a simple FIFO (first in, first out) style where as requests for a lock are -//! made Tokio will queue them up and provide a lock when it is that requester's turn. In that way -//! the Mutex is "fair" and predictable in how it distributes the locks to inner data. This is why -//! the output of this program is an in-order count to 50. Locks are released and reacquired -//! after every iteration, so basically, each thread goes to the back of the line after it increments -//! the value once. Also, since there is only a single valid lock at any given time there is no -//! possibility of a race condition when mutating the inner value. -//! -//! Note that in contrast to `std::sync::Mutex`, this implementation does not -//! poison the mutex when a thread holding the `MutexGuard` panics. In such a -//! case, the mutex will be unlocked. If the panic is caught, this might leave -//! the data protected by the mutex in an inconsistent state. -//! -//! [`Mutex`]: struct.Mutex.html -//! [`MutexGuard`]: struct.MutexGuard.html -use crate::coop::CoopFutureExt; use crate::sync::batch_semaphore as semaphore; use std::cell::UnsafeCell; use std::error::Error; use std::fmt; use std::ops::{Deref, DerefMut}; +use std::sync::Arc; -/// An asynchronous mutual exclusion primitive useful for protecting shared data +/// An asynchronous `Mutex`-like type. /// -/// Each mutex has a type parameter (`T`) which represents the data that it is protecting. The data -/// can only be accessed through the RAII guards returned from `lock`, which -/// guarantees that the data is only ever accessed when the mutex is locked. +/// This type acts similarly to an asynchronous [`std::sync::Mutex`], with one +/// major difference: [`lock`] does not block and the lock guard can be held +/// across await points. +/// +/// # Which kind of mutex should you use? +/// +/// Contrary to popular belief, it is ok and often preferred to use the ordinary +/// [`Mutex`][std] from the standard library in asynchronous code. This section +/// will help you decide on which kind of mutex you should use. +/// +/// The primary use case of the async mutex is to provide shared mutable access +/// to IO resources such as a database connection. If the data stored behind the +/// mutex is just data, it is often better to use a blocking mutex such as the +/// one in the standard library or [`parking_lot`]. This is because the feature +/// that the async mutex offers over the blocking mutex is that it is possible +/// to keep the mutex locked across an `.await` point, which is rarely necessary +/// for data. +/// +/// A common pattern is to wrap the `Arc>` in a struct that provides +/// non-async methods for performing operations on the data within, and only +/// lock the mutex inside these methods. The [mini-redis] example provides an +/// illustration of this pattern. +/// +/// Additionally, when you _do_ want shared access to an IO resource, it is +/// often better to spawn a task to manage the IO resource, and to use message +/// passing to communicate with that task. +/// +/// [std]: std::sync::Mutex +/// [`parking_lot`]: https://docs.rs/parking_lot +/// [mini-redis]: https://github.com/tokio-rs/mini-redis/blob/master/src/db.rs +/// +/// # Examples: +/// +/// ```rust,no_run +/// use tokio::sync::Mutex; +/// use std::sync::Arc; +/// +/// #[tokio::main] +/// async fn main() { +/// let data1 = Arc::new(Mutex::new(0)); +/// let data2 = Arc::clone(&data1); +/// +/// tokio::spawn(async move { +/// let mut lock = data2.lock().await; +/// *lock += 1; +/// }); +/// +/// let mut lock = data1.lock().await; +/// *lock += 1; +/// } +/// ``` +/// +/// +/// ```rust,no_run +/// use tokio::sync::Mutex; +/// use std::sync::Arc; +/// +/// #[tokio::main] +/// async fn main() { +/// let count = Arc::new(Mutex::new(0)); +/// +/// for _ in 0..5 { +/// let my_count = Arc::clone(&count); +/// tokio::spawn(async move { +/// for _ in 0..10 { +/// let mut lock = my_count.lock().await; +/// *lock += 1; +/// println!("{}", lock); +/// } +/// }); +/// } +/// +/// loop { +/// if *count.lock().await >= 50 { +/// break; +/// } +/// } +/// println!("Count hit 50."); +/// } +/// ``` +/// There are a few things of note here to pay attention to in this example. +/// 1. The mutex is wrapped in an [`Arc`] to allow it to be shared across +/// threads. +/// 2. Each spawned task obtains a lock and releases it on every iteration. +/// 3. Mutation of the data protected by the Mutex is done by de-referencing +/// the obtained lock as seen on lines 12 and 19. +/// +/// Tokio's Mutex works in a simple FIFO (first in, first out) style where all +/// calls to [`lock`] complete in the order they were performed. In that way the +/// Mutex is "fair" and predictable in how it distributes the locks to inner +/// data. This is why the output of the program above is an in-order count to +/// 50. Locks are released and reacquired after every iteration, so basically, +/// each thread goes to the back of the line after it increments the value once. +/// Finally, since there is only a single valid lock at any given time, there is +/// no possibility of a race condition when mutating the inner value. +/// +/// Note that in contrast to [`std::sync::Mutex`], this implementation does not +/// poison the mutex when a thread holding the [`MutexGuard`] panics. In such a +/// case, the mutex will be unlocked. If the panic is caught, this might leave +/// the data protected by the mutex in an inconsistent state. +/// +/// [`Mutex`]: struct@Mutex +/// [`MutexGuard`]: struct@MutexGuard +/// [`Arc`]: struct@std::sync::Arc +/// [`std::sync::Mutex`]: struct@std::sync::Mutex +/// [`Send`]: trait@std::marker::Send +/// [`lock`]: method@Mutex::lock #[derive(Debug)] -pub struct Mutex { - c: UnsafeCell, +pub struct Mutex { s: semaphore::Semaphore, + c: UnsafeCell, } /// A handle to a held `Mutex`. /// -/// As long as you have this guard, you have exclusive access to the underlying `T`. The guard -/// internally keeps a reference-couned pointer to the original `Mutex`, so even if the lock goes -/// away, the guard remains valid. +/// As long as you have this guard, you have exclusive access to the underlying +/// `T`. The guard internally borrows the `Mutex`, so the mutex will not be +/// dropped while a guard exists. /// -/// The lock is automatically released whenever the guard is dropped, at which point `lock` -/// will succeed yet again. -pub struct MutexGuard<'a, T> { +/// The lock is automatically released whenever the guard is dropped, at which +/// point `lock` will succeed yet again. +pub struct MutexGuard<'a, T: ?Sized> { lock: &'a Mutex, } +/// An owned handle to a held `Mutex`. +/// +/// This guard is only available from a `Mutex` that is wrapped in an [`Arc`]. It +/// is identical to `MutexGuard`, except that rather than borrowing the `Mutex`, +/// it clones the `Arc`, incrementing the reference count. This means that +/// unlike `MutexGuard`, it will have the `'static` lifetime. +/// +/// As long as you have this guard, you have exclusive access to the underlying +/// `T`. The guard internally keeps a reference-couned pointer to the original +/// `Mutex`, so even if the lock goes away, the guard remains valid. +/// +/// The lock is automatically released whenever the guard is dropped, at which +/// point `lock` will succeed yet again. +/// +/// [`Arc`]: std::sync::Arc +pub struct OwnedMutexGuard { + lock: Arc>, +} + // As long as T: Send, it's fine to send and share Mutex between threads. -// If T was not Send, sending and sharing a Mutex would be bad, since you can access T through -// Mutex. -unsafe impl Send for Mutex where T: Send {} -unsafe impl Sync for Mutex where T: Send {} -unsafe impl<'a, T> Sync for MutexGuard<'a, T> where T: Send + Sync {} +// If T was not Send, sending and sharing a Mutex would be bad, since you can +// access T through Mutex. +unsafe impl Send for Mutex where T: ?Sized + Send {} +unsafe impl Sync for Mutex where T: ?Sized + Send {} +unsafe impl Sync for MutexGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl Sync for OwnedMutexGuard where T: ?Sized + Send + Sync {} /// Error returned from the [`Mutex::try_lock`] function. /// @@ -126,7 +170,7 @@ pub struct TryLockError(()); impl fmt::Display for TryLockError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "{}", "operation would block") + write!(fmt, "operation would block") } } @@ -140,34 +184,120 @@ fn bounds() { // This has to take a value, since the async fn's return type is unnameable. fn check_send_sync_val(_t: T) {} fn check_send_sync() {} + fn check_static() {} + fn check_static_val(_t: T) {} + check_send::>(); + check_send::>(); check_unpin::>(); check_send_sync::>(); + check_static::>(); let mutex = Mutex::new(1); check_send_sync_val(mutex.lock()); + let arc_mutex = Arc::new(Mutex::new(1)); + check_send_sync_val(arc_mutex.clone().lock_owned()); + check_static_val(arc_mutex.lock_owned()); } -impl Mutex { +impl Mutex { /// Creates a new lock in an unlocked state ready for use. - pub fn new(t: T) -> Self { + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// let lock = Mutex::new(5); + /// ``` + pub fn new(t: T) -> Self + where + T: Sized, + { Self { c: UnsafeCell::new(t), s: semaphore::Semaphore::new(1), } } - /// A future that resolves on acquiring the lock and returns the `MutexGuard`. + /// Locks this mutex, causing the current task + /// to yield until the lock has been acquired. + /// When the lock has been acquired, function returns a [`MutexGuard`]. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Mutex::new(1); + /// + /// let mut n = mutex.lock().await; + /// *n = 2; + /// } + /// ``` pub async fn lock(&self) -> MutexGuard<'_, T> { - self.s.acquire(1).cooperate().await.unwrap_or_else(|_| { - // The semaphore was closed. but, we never explicitly close it, and we have a - // handle to it through the Arc, which means that this can never happen. + self.acquire().await; + MutexGuard { lock: self } + } + + /// Locks this mutex, causing the current task to yield until the lock has + /// been acquired. When the lock has been acquired, this returns an + /// [`OwnedMutexGuard`]. + /// + /// This method is identical to [`Mutex::lock`], except that the returned + /// guard references the `Mutex` with an [`Arc`] rather than by borrowing + /// it. Therefore, the `Mutex` must be wrapped in an `Arc` to call this + /// method, and the guard will live for the `'static` lifetime, as it keeps + /// the `Mutex` alive by holding an `Arc`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Arc::new(Mutex::new(1)); + /// + /// let mut n = mutex.clone().lock_owned().await; + /// *n = 2; + /// } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + pub async fn lock_owned(self: Arc) -> OwnedMutexGuard { + self.acquire().await; + OwnedMutexGuard { lock: self } + } + + async fn acquire(&self) { + self.s.acquire(1).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and + // we own it exclusively, which means that this can never happen. unreachable!() }); - MutexGuard { lock: self } } - /// Tries to acquire the lock + /// Attempts to acquire the lock, and returns [`TryLockError`] if the + /// lock is currently held somewhere else. + /// + /// [`TryLockError`]: TryLockError + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// # async fn dox() -> Result<(), tokio::sync::TryLockError> { + /// + /// let mutex = Mutex::new(1); + /// + /// let n = mutex.try_lock()?; + /// assert_eq!(*n, 1); + /// # Ok(()) + /// # } + /// ``` pub fn try_lock(&self) -> Result, TryLockError> { match self.s.try_acquire(1) { Ok(_) => Ok(MutexGuard { lock: self }), @@ -175,15 +305,56 @@ impl Mutex { } } - /// Consumes the mutex, returning the underlying data. - pub fn into_inner(self) -> T { - self.c.into_inner() + /// Attempts to acquire the lock, and returns [`TryLockError`] if the lock + /// is currently held somewhere else. + /// + /// This method is identical to [`Mutex::try_lock`], except that the + /// returned guard references the `Mutex` with an [`Arc`] rather than by + /// borrowing it. Therefore, the `Mutex` must be wrapped in an `Arc` to call + /// this method, and the guard will live for the `'static` lifetime, as it + /// keeps the `Mutex` alive by holding an `Arc`. + /// + /// [`TryLockError`]: TryLockError + /// [`Arc`]: std::sync::Arc + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// use std::sync::Arc; + /// # async fn dox() -> Result<(), tokio::sync::TryLockError> { + /// + /// let mutex = Arc::new(Mutex::new(1)); + /// + /// let n = mutex.clone().try_lock_owned()?; + /// assert_eq!(*n, 1); + /// # Ok(()) + /// # } + pub fn try_lock_owned(self: Arc) -> Result, TryLockError> { + match self.s.try_acquire(1) { + Ok(_) => Ok(OwnedMutexGuard { lock: self }), + Err(_) => Err(TryLockError(())), + } } -} -impl<'a, T> Drop for MutexGuard<'a, T> { - fn drop(&mut self) { - self.lock.s.release(1) + /// Consumes the mutex, returning the underlying data. + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Mutex::new(1); + /// + /// let n = mutex.into_inner(); + /// assert_eq!(n, 1); + /// } + /// ``` + pub fn into_inner(self) -> T + where + T: Sized, + { + self.c.into_inner() } } @@ -202,26 +373,67 @@ where } } -impl<'a, T> Deref for MutexGuard<'a, T> { +// === impl MutexGuard === + +impl Drop for MutexGuard<'_, T> { + fn drop(&mut self) { + self.lock.s.release(1) + } +} + +impl Deref for MutexGuard<'_, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.lock.c.get() } + } +} + +impl DerefMut for MutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.lock.c.get() } + } +} + +impl fmt::Debug for MutexGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl fmt::Display for MutexGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +// === impl OwnedMutexGuard === + +impl Drop for OwnedMutexGuard { + fn drop(&mut self) { + self.lock.s.release(1) + } +} + +impl Deref for OwnedMutexGuard { type Target = T; fn deref(&self) -> &Self::Target { unsafe { &*self.lock.c.get() } } } -impl<'a, T> DerefMut for MutexGuard<'a, T> { +impl DerefMut for OwnedMutexGuard { fn deref_mut(&mut self) -> &mut Self::Target { unsafe { &mut *self.lock.c.get() } } } -impl<'a, T: fmt::Debug> fmt::Debug for MutexGuard<'a, T> { +impl fmt::Debug for OwnedMutexGuard { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(&**self, f) } } -impl<'a, T: fmt::Display> fmt::Display for MutexGuard<'a, T> { +impl fmt::Display for OwnedMutexGuard { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&**self, f) } diff --git a/tokio/src/sync/oneshot.rs b/tokio/src/sync/oneshot.rs index 644832b96ce..17767e7f7f8 100644 --- a/tokio/src/sync/oneshot.rs +++ b/tokio/src/sync/oneshot.rs @@ -16,7 +16,7 @@ use std::task::{Context, Poll, Waker}; /// Sends a value to the associated `Receiver`. /// -/// Instances are created by the [`channel`](fn.channel.html) function. +/// Instances are created by the [`channel`](fn@channel) function. #[derive(Debug)] pub struct Sender { inner: Option>>, @@ -24,7 +24,7 @@ pub struct Sender { /// Receive a value from the associated `Sender`. /// -/// Instances are created by the [`channel`](fn.channel.html) function. +/// Instances are created by the [`channel`](fn@channel) function. #[derive(Debug)] pub struct Receiver { inner: Option>>, @@ -144,8 +144,11 @@ impl Sender { /// Attempts to send a value on this channel, returning it back if it could /// not be sent. /// - /// The function consumes `self` as only one value may ever be sent on a - /// one-shot channel. + /// This method consumes `self` as only one value may ever be sent on a oneshot + /// channel. It is not marked async because sending a message to an oneshot + /// channel never requires any form of waiting. Because of this, the `send` + /// method can be used in both synchronous and asynchronous code without + /// problems. /// /// A successful send occurs when it is determined that the other end of the /// channel has not hung up already. An unsuccessful send would be one where @@ -197,13 +200,14 @@ impl Sender { #[doc(hidden)] // TODO: remove pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); let inner = self.inner.as_ref().unwrap(); let mut state = State::load(&inner.state, Acquire); if state.is_closed() { + coop.made_progress(); return Poll::Ready(()); } @@ -216,6 +220,7 @@ impl Sender { if state.is_closed() { // Set the flag again so that the waker is released in drop State::set_tx_task(&inner.state); + coop.made_progress(); return Ready(()); } else { unsafe { inner.drop_tx_task() }; @@ -233,6 +238,7 @@ impl Sender { state = State::set_tx_task(&inner.state); if state.is_closed() { + coop.made_progress(); return Ready(()); } } @@ -360,7 +366,7 @@ impl Receiver { /// Prevents the associated [`Sender`] handle from sending a value. /// /// Any `send` operation which happens after calling `close` is guaranteed - /// to fail. After calling `close`, `Receiver::poll`] should be called to + /// to fail. After calling `close`, [`try_recv`] should be called to /// receive a value if one was sent **before** the call to `close` /// completed. /// @@ -368,6 +374,7 @@ impl Receiver { /// value will not be sent into the channel and never received. /// /// [`Sender`]: Sender + /// [`try_recv`]: Receiver::try_recv /// /// # Examples /// @@ -548,17 +555,19 @@ impl Inner { fn poll_recv(&self, cx: &mut Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); // Load the state let mut state = State::load(&self.state, Acquire); if state.is_complete() { + coop.made_progress(); match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), } } else if state.is_closed() { + coop.made_progress(); Ready(Err(RecvError(()))) } else { if state.is_rx_task_set() { @@ -572,6 +581,7 @@ impl Inner { // Set the flag again so that the waker is released in drop State::set_rx_task(&self.state); + coop.made_progress(); return match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), @@ -592,6 +602,7 @@ impl Inner { state = State::set_rx_task(&self.state); if state.is_complete() { + coop.made_progress(); match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), diff --git a/tokio/src/sync/rwlock.rs b/tokio/src/sync/rwlock.rs index 0f7991a5bf8..f6cbd2a0fba 100644 --- a/tokio/src/sync/rwlock.rs +++ b/tokio/src/sync/rwlock.rs @@ -1,4 +1,3 @@ -use crate::coop::CoopFutureExt; use crate::sync::batch_semaphore::{AcquireError, Semaphore}; use std::cell::UnsafeCell; use std::ops; @@ -32,8 +31,8 @@ const MAX_READS: usize = 10; /// /// The type parameter `T` represents the data that this lock protects. It is /// required that `T` satisfies [`Send`] to be shared across threads. The RAII guards -/// returned from the locking methods implement [`Deref`](https://doc.rust-lang.org/std/ops/trait.Deref.html) -/// (and [`DerefMut`](https://doc.rust-lang.org/std/ops/trait.DerefMut.html) +/// returned from the locking methods implement [`Deref`](trait@std::ops::Deref) +/// (and [`DerefMut`](trait@std::ops::DerefMut) /// for the `write` methods) to allow access to the content of the lock. /// /// # Examples @@ -62,14 +61,14 @@ const MAX_READS: usize = 10; /// } /// ``` /// -/// [`Mutex`]: struct.Mutex.html -/// [`RwLock`]: struct.RwLock.html -/// [`RwLockReadGuard`]: struct.RwLockReadGuard.html -/// [`RwLockWriteGuard`]: struct.RwLockWriteGuard.html -/// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html +/// [`Mutex`]: struct@super::Mutex +/// [`RwLock`]: struct@RwLock +/// [`RwLockReadGuard`]: struct@RwLockReadGuard +/// [`RwLockWriteGuard`]: struct@RwLockWriteGuard +/// [`Send`]: trait@std::marker::Send /// [_write-preferring_]: https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Priority_policies #[derive(Debug)] -pub struct RwLock { +pub struct RwLock { //semaphore to coordinate read and write access to T s: Semaphore, @@ -83,9 +82,9 @@ pub struct RwLock { /// This structure is created by the [`read`] method on /// [`RwLock`]. /// -/// [`read`]: struct.RwLock.html#method.read +/// [`read`]: method@RwLock::read #[derive(Debug)] -pub struct RwLockReadGuard<'a, T> { +pub struct RwLockReadGuard<'a, T: ?Sized> { permit: ReleasingPermit<'a, T>, lock: &'a RwLock, } @@ -96,32 +95,32 @@ pub struct RwLockReadGuard<'a, T> { /// This structure is created by the [`write`] and method /// on [`RwLock`]. /// -/// [`write`]: struct.RwLock.html#method.write -/// [`RwLock`]: struct.RwLock.html +/// [`write`]: method@RwLock::write +/// [`RwLock`]: struct@RwLock #[derive(Debug)] -pub struct RwLockWriteGuard<'a, T> { +pub struct RwLockWriteGuard<'a, T: ?Sized> { permit: ReleasingPermit<'a, T>, lock: &'a RwLock, } // Wrapper arround Permit that releases on Drop #[derive(Debug)] -struct ReleasingPermit<'a, T> { +struct ReleasingPermit<'a, T: ?Sized> { num_permits: u16, lock: &'a RwLock, } -impl<'a, T> ReleasingPermit<'a, T> { +impl<'a, T: ?Sized> ReleasingPermit<'a, T> { async fn acquire( lock: &'a RwLock, num_permits: u16, ) -> Result, AcquireError> { - lock.s.acquire(num_permits).cooperate().await?; + lock.s.acquire(num_permits).await?; Ok(Self { num_permits, lock }) } } -impl<'a, T> Drop for ReleasingPermit<'a, T> { +impl Drop for ReleasingPermit<'_, T> { fn drop(&mut self) { self.lock.s.release(self.num_permits as usize); } @@ -154,12 +153,12 @@ fn bounds() { // As long as T: Send + Sync, it's fine to send and share RwLock between threads. // If T were not Send, sending and sharing a RwLock would be bad, since you can access T through // RwLock. -unsafe impl Send for RwLock where T: Send {} -unsafe impl Sync for RwLock where T: Send + Sync {} -unsafe impl<'a, T> Sync for RwLockReadGuard<'a, T> where T: Send + Sync {} -unsafe impl<'a, T> Sync for RwLockWriteGuard<'a, T> where T: Send + Sync {} +unsafe impl Send for RwLock where T: ?Sized + Send {} +unsafe impl Sync for RwLock where T: ?Sized + Send + Sync {} +unsafe impl Sync for RwLockReadGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} -impl RwLock { +impl RwLock { /// Creates a new instance of an `RwLock` which is unlocked. /// /// # Examples @@ -169,7 +168,10 @@ impl RwLock { /// /// let lock = RwLock::new(5); /// ``` - pub fn new(value: T) -> RwLock { + pub fn new(value: T) -> RwLock + where + T: Sized, + { RwLock { c: UnsafeCell::new(value), s: Semaphore::new(MAX_READS), @@ -251,12 +253,15 @@ impl RwLock { } /// Consumes the lock, returning the underlying data. - pub fn into_inner(self) -> T { + pub fn into_inner(self) -> T + where + T: Sized, + { self.c.into_inner() } } -impl ops::Deref for RwLockReadGuard<'_, T> { +impl ops::Deref for RwLockReadGuard<'_, T> { type Target = T; fn deref(&self) -> &T { @@ -264,7 +269,7 @@ impl ops::Deref for RwLockReadGuard<'_, T> { } } -impl ops::Deref for RwLockWriteGuard<'_, T> { +impl ops::Deref for RwLockWriteGuard<'_, T> { type Target = T; fn deref(&self) -> &T { @@ -272,7 +277,7 @@ impl ops::Deref for RwLockWriteGuard<'_, T> { } } -impl ops::DerefMut for RwLockWriteGuard<'_, T> { +impl ops::DerefMut for RwLockWriteGuard<'_, T> { fn deref_mut(&mut self) -> &mut T { unsafe { &mut *self.lock.c.get() } } diff --git a/tokio/src/sync/semaphore.rs b/tokio/src/sync/semaphore.rs index 4cce7e8f5bc..2489d34aaaf 100644 --- a/tokio/src/sync/semaphore.rs +++ b/tokio/src/sync/semaphore.rs @@ -1,5 +1,5 @@ use super::batch_semaphore as ll; // low level implementation -use crate::coop::CoopFutureExt; +use std::sync::Arc; /// Counting semaphore performing asynchronous permit aquisition. /// @@ -18,7 +18,11 @@ pub struct Semaphore { ll_sem: ll::Semaphore, } -/// A permit from the semaphore +/// A permit from the semaphore. +/// +/// This type is created by the [`acquire`] method. +/// +/// [`acquire`]: crate::sync::Semaphore::acquire() #[must_use] #[derive(Debug)] pub struct SemaphorePermit<'a> { @@ -26,6 +30,18 @@ pub struct SemaphorePermit<'a> { permits: u16, } +/// An owned permit from the semaphore. +/// +/// This type is created by the [`acquire_owned`] method. +/// +/// [`acquire_owned`]: crate::sync::Semaphore::acquire_owned() +#[must_use] +#[derive(Debug)] +pub struct OwnedSemaphorePermit { + sem: Arc, + permits: u16, +} + /// Error returned from the [`Semaphore::try_acquire`] function. /// /// A `try_acquire` operation can only fail if the semaphore has no available @@ -51,33 +67,35 @@ fn bounds() { } impl Semaphore { - /// Creates a new semaphore with the initial number of permits + /// Creates a new semaphore with the initial number of permits. pub fn new(permits: usize) -> Self { Self { ll_sem: ll::Semaphore::new(permits), } } - /// Returns the current number of available permits + /// Returns the current number of available permits. pub fn available_permits(&self) -> usize { self.ll_sem.available_permits() } /// Adds `n` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. pub fn add_permits(&self, n: usize) { self.ll_sem.release(n); } - /// Acquires permit from the semaphore + /// Acquires permit from the semaphore. pub async fn acquire(&self) -> SemaphorePermit<'_> { - self.ll_sem.acquire(1).cooperate().await.unwrap(); + self.ll_sem.acquire(1).await.unwrap(); SemaphorePermit { sem: &self, permits: 1, } } - /// Tries to acquire a permit form the semaphore + /// Tries to acquire a permit from the semaphore. pub fn try_acquire(&self) -> Result, TryAcquireError> { match self.ll_sem.try_acquire(1) { Ok(_) => Ok(SemaphorePermit { @@ -87,6 +105,34 @@ impl Semaphore { Err(_) => Err(TryAcquireError(())), } } + + /// Acquires permit from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// + /// [`Arc`]: std::sync::Arc + pub async fn acquire_owned(self: Arc) -> OwnedSemaphorePermit { + self.ll_sem.acquire(1).await.unwrap(); + OwnedSemaphorePermit { + sem: self.clone(), + permits: 1, + } + } + + /// Tries to acquire a permit from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// + /// [`Arc`]: std::sync::Arc + pub fn try_acquire_owned(self: Arc) -> Result { + match self.ll_sem.try_acquire(1) { + Ok(_) => Ok(OwnedSemaphorePermit { + sem: self.clone(), + permits: 1, + }), + Err(_) => Err(TryAcquireError(())), + } + } } impl<'a> SemaphorePermit<'a> { @@ -98,8 +144,23 @@ impl<'a> SemaphorePermit<'a> { } } +impl OwnedSemaphorePermit { + /// Forgets the permit **without** releasing it back to the semaphore. + /// This can be used to reduce the amount of permits available from a + /// semaphore. + pub fn forget(mut self) { + self.permits = 0; + } +} + impl<'a> Drop for SemaphorePermit<'_> { fn drop(&mut self) { self.sem.add_permits(self.permits as usize); } } + +impl Drop for OwnedSemaphorePermit { + fn drop(&mut self) { + self.sem.add_permits(self.permits as usize); + } +} diff --git a/tokio/src/sync/semaphore_ll.rs b/tokio/src/sync/semaphore_ll.rs index 0bdc4e27617..25d25ac88ab 100644 --- a/tokio/src/sync/semaphore_ll.rs +++ b/tokio/src/sync/semaphore_ll.rs @@ -333,8 +333,9 @@ impl Semaphore { self.add_permits_locked(0, true); } - /// Adds `n` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. pub(crate) fn add_permits(&self, n: usize) { if n == 0 { return; @@ -749,7 +750,7 @@ impl Permit { /// Forgets the permit **without** releasing it back to the semaphore. /// /// After calling `forget`, `poll_acquire` is able to acquire new permit - /// from the sempahore. + /// from the semaphore. /// /// Repeatedly calling `forget` without associated calls to `add_permit` /// will result in the semaphore losing all permits. @@ -853,8 +854,8 @@ impl TryAcquireError { impl fmt::Display for TryAcquireError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - TryAcquireError::Closed => write!(fmt, "{}", "semaphore closed"), - TryAcquireError::NoPermits => write!(fmt, "{}", "no permits available"), + TryAcquireError::Closed => write!(fmt, "semaphore closed"), + TryAcquireError::NoPermits => write!(fmt, "no permits available"), } } } diff --git a/tokio/src/sync/tests/loom_cancellation_token.rs b/tokio/src/sync/tests/loom_cancellation_token.rs new file mode 100644 index 00000000000..e9c9f3dd980 --- /dev/null +++ b/tokio/src/sync/tests/loom_cancellation_token.rs @@ -0,0 +1,155 @@ +use crate::sync::CancellationToken; + +use loom::{future::block_on, thread}; +use tokio_test::assert_ok; + +#[test] +fn cancel_token() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn cancel_with_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + let child_token = token.child_token(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + block_on(async { + child_token.cancelled().await; + }); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_no_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(token2); + }); + + let th3 = thread::spawn(move || { + drop(token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_with_childs() { + loom::model(|| { + let token1 = CancellationToken::new(); + let child_token1 = token1.child_token(); + let child_token2 = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(child_token1); + }); + + let th3 = thread::spawn(move || { + drop(child_token2); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_and_cancel_token() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + drop(child_token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn cancel_parent_and_child() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + child_token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} diff --git a/tokio/src/sync/tests/mod.rs b/tokio/src/sync/tests/mod.rs index d571754c011..6ba8c1f9b6a 100644 --- a/tokio/src/sync/tests/mod.rs +++ b/tokio/src/sync/tests/mod.rs @@ -7,6 +7,8 @@ cfg_not_loom! { cfg_loom! { mod loom_atomic_waker; mod loom_broadcast; + #[cfg(tokio_unstable)] + mod loom_cancellation_token; mod loom_list; mod loom_mpsc; mod loom_notify; diff --git a/tokio/src/sync/tests/semaphore_batch.rs b/tokio/src/sync/tests/semaphore_batch.rs index 60f3f231e76..9342cd1cb3c 100644 --- a/tokio/src/sync/tests/semaphore_batch.rs +++ b/tokio/src/sync/tests/semaphore_batch.rs @@ -236,7 +236,7 @@ fn close_semaphore_notifies_permit2() { #[test] fn cancel_acquire_releases_permits() { let s = Semaphore::new(10); - let _permit1 = s.try_acquire(4).expect("uncontended try_acquire succeeds"); + s.try_acquire(4).expect("uncontended try_acquire succeeds"); assert_eq!(6, s.available_permits()); let mut acquire = task::spawn(s.acquire(8)); diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 402670f4511..13033d9e726 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -62,9 +62,9 @@ use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, Weak}; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; -/// Receives values from the associated [`Sender`](struct.Sender.html). +/// Receives values from the associated [`Sender`](struct@Sender). /// -/// Instances are created by the [`channel`](fn.channel.html) function. +/// Instances are created by the [`channel`](fn@channel) function. #[derive(Debug)] pub struct Receiver { /// Pointer to the shared state @@ -74,9 +74,9 @@ pub struct Receiver { inner: Watcher, } -/// Sends values to the associated [`Receiver`](struct.Receiver.html). +/// Sends values to the associated [`Receiver`](struct@Receiver). /// -/// Instances are created by the [`channel`](fn.channel.html) function. +/// Instances are created by the [`channel`](fn@channel) function. #[derive(Debug)] pub struct Sender { shared: Weak>, @@ -172,8 +172,8 @@ const CLOSED: usize = 1; /// # } /// ``` /// -/// [`Sender`]: struct.Sender.html -/// [`Receiver`]: struct.Receiver.html +/// [`Sender`]: struct@Sender +/// [`Receiver`]: struct@Receiver pub fn channel(init: T) -> (Sender, Receiver) { const VERSION_0: usize = 0; const VERSION_1: usize = 2; @@ -338,7 +338,6 @@ impl Sender { // Notify all watchers notify_all(&*shared); - // Return the old value Ok(()) } diff --git a/tokio/src/task/blocking.rs b/tokio/src/task/blocking.rs index 0069b10adaa..ed60f4c4734 100644 --- a/tokio/src/task/blocking.rs +++ b/tokio/src/task/blocking.rs @@ -1,20 +1,37 @@ use crate::task::JoinHandle; cfg_rt_threaded! { - /// Runs the provided blocking function without blocking the executor. + /// Runs the provided blocking function on the current thread without + /// blocking the executor. /// /// In general, issuing a blocking call or performing a lot of compute in a /// future without yielding is not okay, as it may prevent the executor from - /// driving other futures forward. If you run a closure through this method, - /// the current executor thread will relegate all its executor duties to another - /// (possibly new) thread, and only then poll the task. Note that this requires - /// additional synchronization. + /// driving other futures forward. This function runs the closure on the + /// current thread by having the thread temporarily cease from being a core + /// thread, and turns it into a blocking thread. See the [CPU-bound tasks + /// and blocking code][blocking] section for more information. /// - /// # Note + /// Although this function avoids starving other independently spawned + /// tasks, any other code running concurrently in the same task will be + /// suspended during the call to `block_in_place`. This can happen e.g. when + /// using the [`join!`] macro. To avoid this issue, use [`spawn_blocking`] + /// instead. /// - /// This function can only be called from a spawned task when working with - /// the [threaded scheduler](https://docs.rs/tokio/0.2.10/tokio/runtime/index.html#threaded-scheduler). - /// Consider using [tokio::task::spawn_blocking](https://docs.rs/tokio/0.2.10/tokio/task/fn.spawn_blocking.html). + /// Note that this function can only be used on the [threaded scheduler]. + /// + /// Code running behind `block_in_place` cannot be cancelled. When you shut + /// down the executor, it will wait indefinitely for all blocking operations + /// to finish. You can use [`shutdown_timeout`] to stop waiting for them + /// after a certain timeout. Be aware that this will still not cancel the + /// tasks — they are simply allowed to keep running after the method + /// returns. + /// + /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code + /// [threaded scheduler]: fn@crate::runtime::Builder::threaded_scheduler + /// [`spawn_blocking`]: fn@crate::task::spawn_blocking + /// [`join!`]: macro@join + /// [`thread::spawn`]: fn@std::thread::spawn + /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout /// /// # Examples /// @@ -32,19 +49,50 @@ cfg_rt_threaded! { where F: FnOnce() -> R, { - use crate::runtime::{enter, thread_pool}; - - enter::exit(|| thread_pool::block_in_place(f)) + crate::runtime::thread_pool::block_in_place(f) } } cfg_blocking! { /// Runs the provided closure on a thread where blocking is acceptable. /// - /// In general, issuing a blocking call or performing a lot of compute in a future without - /// yielding is not okay, as it may prevent the executor from driving other futures forward. - /// A closure that is run through this method will instead be run on a dedicated thread pool for - /// such blocking tasks without holding up the main futures executor. + /// In general, issuing a blocking call or performing a lot of compute in a + /// future without yielding is not okay, as it may prevent the executor from + /// driving other futures forward. This function runs the provided closure + /// on a thread dedicated to blocking operations. See the [CPU-bound tasks + /// and blocking code][blocking] section for more information. + /// + /// Tokio will spawn more blocking threads when they are requested through + /// this function until the upper limit configured on the [`Builder`] is + /// reached. This limit is very large by default, because `spawn_blocking` is + /// often used for various kinds of IO operations that cannot be performed + /// asynchronously. When you run CPU-bound code using `spawn_blocking`, you + /// should keep this large upper limit in mind; to run your CPU-bound + /// computations on only a few threads, you should use a separate thread + /// pool such as [rayon] rather than configuring the number of blocking + /// threads. + /// + /// This function is intended for non-async operations that eventually + /// finish on their own. If you want to spawn an ordinary thread, you should + /// use [`thread::spawn`] instead. + /// + /// Closures spawned using `spawn_blocking` cannot be cancelled. When you + /// shut down the executor, it will wait indefinitely for all blocking + /// operations to finish. You can use [`shutdown_timeout`] to stop waiting + /// for them after a certain timeout. Be aware that this will still not + /// cancel the tasks — they are simply allowed to keep running after the + /// method returns. + /// + /// Note that if you are using the [basic scheduler], this function will + /// still spawn additional threads for blocking operations. The basic + /// scheduler's single thread is only used for asynchronous code. + /// + /// [`Builder`]: struct@crate::runtime::Builder + /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code + /// [rayon]: https://docs.rs/rayon + /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler + /// [`thread::spawn`]: fn@std::thread::spawn + /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout /// /// # Examples /// @@ -66,6 +114,19 @@ cfg_blocking! { F: FnOnce() -> R + Send + 'static, R: Send + 'static, { + #[cfg(feature = "tracing")] + let f = { + let span = tracing::trace_span!( + target: "tokio::task", + "task", + kind = %"blocking", + function = %std::any::type_name::(), + ); + move || { + let _g = span.enter(); + f() + } + }; crate::runtime::spawn_blocking(f) } } diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index fcb8c789237..3c409edfb90 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -7,6 +7,7 @@ use std::cell::{Cell, RefCell}; use std::collections::VecDeque; use std::fmt; use std::future::Future; +use std::marker::PhantomData; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::Poll; @@ -104,16 +105,19 @@ cfg_rt_util! { /// } /// ``` /// - /// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html - /// [local task set]: struct.LocalSet.html - /// [`Runtime::block_on`]: ../struct.Runtime.html#method.block_on - /// [`task::spawn_local`]: fn.spawn.html + /// [`Send`]: trait@std::marker::Send + /// [local task set]: struct@LocalSet + /// [`Runtime::block_on`]: method@crate::runtime::Runtime::block_on + /// [`task::spawn_local`]: fn@spawn_local pub struct LocalSet { /// Current scheduler tick tick: Cell, /// State available from thread-local context: Context, + + /// This type should not be Send. + _not_send: PhantomData<*const ()>, } } @@ -191,6 +195,7 @@ cfg_rt_util! { F: Future + 'static, F::Output: 'static, { + let future = crate::util::trace::task(future, "local"); CURRENT.with(|maybe_cx| { let cx = maybe_cx .expect("`spawn_local` called from outside of a `task::LocalSet`"); @@ -228,6 +233,7 @@ impl LocalSet { waker: AtomicWaker::new(), }), }, + _not_send: PhantomData, } } @@ -266,12 +272,13 @@ impl LocalSet { /// }).await; /// } /// ``` - /// [`spawn_local`]: fn.spawn_local.html + /// [`spawn_local`]: fn@spawn_local pub fn spawn_local(&self, future: F) -> JoinHandle where F: Future + 'static, F::Output: 'static, { + let future = crate::util::trace::task(future, "local"); let (task, handle) = unsafe { task::joinable_local(future) }; self.context.tasks.borrow_mut().queue.push_back(task); handle @@ -335,10 +342,10 @@ impl LocalSet { /// }) /// ``` /// - /// [`spawn_local`]: fn.spawn_local.html - /// [`Runtime::block_on`]: ../struct.Runtime.html#method.block_on - /// [in-place blocking]: ../blocking/fn.in_place.html - /// [`spawn_blocking`]: ../blocking/fn.spawn_blocking.html + /// [`spawn_local`]: fn@spawn_local + /// [`Runtime::block_on`]: method@crate::runtime::Runtime::block_on + /// [in-place blocking]: fn@crate::task::block_in_place + /// [`spawn_blocking`]: fn@crate::task::spawn_blocking pub fn block_on(&self, rt: &mut crate::runtime::Runtime, future: F) -> F::Output where F: Future, @@ -372,7 +379,7 @@ impl LocalSet { /// } /// ``` /// - /// [`spawn_local`]: fn.spawn_local.html + /// [`spawn_local`]: fn@spawn_local /// [awaiting the local set]: #awaiting-a-localset pub async fn run_until(&self, future: F) -> F::Output where @@ -515,7 +522,10 @@ impl Future for RunUntil<'_, T> { .waker .register_by_ref(cx.waker()); - if let Poll::Ready(output) = me.future.poll(cx) { + let _no_blocking = crate::runtime::enter::disallow_blocking(); + let f = me.future; + + if let Poll::Ready(output) = crate::coop::budget(|| f.poll(cx)) { return Poll::Ready(output); } diff --git a/tokio/src/task/spawn.rs b/tokio/src/task/spawn.rs index fa5ff13b01e..d6e771184f2 100644 --- a/tokio/src/task/spawn.rs +++ b/tokio/src/task/spawn.rs @@ -129,6 +129,7 @@ doc_rt_core! { { let spawn_handle = runtime::context::spawn_handle() .expect("must be called from the context of Tokio runtime configured with either `basic_scheduler` or `threaded_scheduler`"); + let task = crate::util::trace::task(task, "task"); spawn_handle.spawn(task) } } diff --git a/tokio/src/task/task_local.rs b/tokio/src/task/task_local.rs index a3e7f03fe6e..1679ee3ba12 100644 --- a/tokio/src/task/task_local.rs +++ b/tokio/src/task/task_local.rs @@ -29,7 +29,7 @@ use std::{fmt, thread}; /// See [LocalKey documentation][`tokio::task::LocalKey`] for more /// information. /// -/// [`tokio::task::LocalKey`]: ../tokio/task/struct.LocalKey.html +/// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey #[macro_export] macro_rules! task_local { // empty (base case for the recursion) @@ -49,7 +49,7 @@ macro_rules! task_local { #[macro_export] macro_rules! __task_local_inner { ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => { - static $name: $crate::task::LocalKey<$t> = { + $vis static $name: $crate::task::LocalKey<$t> = { std::thread_local! { static __KEY: std::cell::RefCell> = std::cell::RefCell::new(None); } @@ -89,7 +89,7 @@ macro_rules! __task_local_inner { /// }).await; /// # } /// ``` -/// [`std::thread::LocalKey`]: https://doc.rust-lang.org/std/thread/struct.LocalKey.html +/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey pub struct LocalKey { #[doc(hidden)] pub inner: thread::LocalKey>>, @@ -219,7 +219,7 @@ impl Future for TaskLocalFuture { trait StaticLifetime: 'static {} impl StaticLifetime for T {} -/// An error returned by [`LocalKey::try_with`](struct.LocalKey.html#method.try_with). +/// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with). #[derive(Clone, Copy, Eq, PartialEq)] pub struct AccessError { _private: (), diff --git a/tokio/src/time/clock.rs b/tokio/src/time/clock.rs index 4ac24af3d05..bd67d7a31d7 100644 --- a/tokio/src/time/clock.rs +++ b/tokio/src/time/clock.rs @@ -56,8 +56,9 @@ cfg_test_util! { /// Pause time /// /// The current value of `Instant::now()` is saved and all subsequent calls - /// to `Instant::now()` will return the saved value. This is useful for - /// running tests that are dependent on time. + /// to `Instant::now()` until the timer wheel is checked again will return the saved value. + /// Once the timer wheel is checked, time will immediately advance to the next registered + /// `Delay`. This is useful for running tests that depend on time. /// /// # Panics /// diff --git a/tokio/src/time/delay.rs b/tokio/src/time/delay.rs index 8088c9955cd..744c7e16aea 100644 --- a/tokio/src/time/delay.rs +++ b/tokio/src/time/delay.rs @@ -29,10 +29,29 @@ pub fn delay_until(deadline: Instant) -> Delay { /// operates at millisecond granularity and should not be used for tasks that /// require high-resolution timers. /// +/// To run something regularly on a schedule, see [`interval`]. +/// /// # Cancellation /// /// Canceling a delay is done by dropping the returned future. No additional /// cleanup work is required. +/// +/// # Examples +/// +/// Wait 100ms and print "100 ms have elapsed". +/// +/// ``` +/// use tokio::time::{delay_for, Duration}; +/// +/// #[tokio::main] +/// async fn main() { +/// delay_for(Duration::from_millis(100)).await; +/// println!("100 ms have elapsed"); +/// } +/// ``` +/// +/// [`interval`]: crate::time::interval() +#[cfg_attr(docsrs, doc(alias = "sleep"))] pub fn delay_for(duration: Duration) -> Delay { delay_until(Instant::now() + duration) } diff --git a/tokio/src/time/delay_queue.rs b/tokio/src/time/delay_queue.rs index d0b1f192660..55ec7cd68d1 100644 --- a/tokio/src/time/delay_queue.rs +++ b/tokio/src/time/delay_queue.rs @@ -2,7 +2,7 @@ //! //! See [`DelayQueue`] for more details. //! -//! [`DelayQueue`]: struct.DelayQueue.html +//! [`DelayQueue`]: struct@DelayQueue use crate::time::wheel::{self, Wheel}; use crate::time::{delay_until, Delay, Duration, Error, Instant}; @@ -50,13 +50,12 @@ use std::task::{self, Poll}; /// /// # Implementation /// -/// The `DelayQueue` is backed by the same hashed timing wheel implementation as -/// [`Timer`] as such, it offers the same performance benefits. See [`Timer`] -/// for further implementation notes. +/// The [`DelayQueue`] is backed by a separate instance of the same timer wheel used internally by +/// Tokio's standalone timer utilities such as [`delay_for`]. Because of this, it offers the same +/// performance and scalability benefits. /// -/// State associated with each entry is stored in a [`slab`]. This allows -/// amortizing the cost of allocation. Space created for expired entries is -/// reused when inserting new entries. +/// State associated with each entry is stored in a [`slab`]. This amortizes the cost of allocation, +/// and allows reuse of the memory allocated for expired entires. /// /// Capacity can be checked using [`capacity`] and allocated preemptively by using /// the [`reserve`] method. @@ -112,16 +111,17 @@ use std::task::{self, Poll}; /// } /// ``` /// -/// [`insert`]: #method.insert -/// [`insert_at`]: #method.insert_at -/// [`Key`]: struct.Key.html +/// [`insert`]: method@Self::insert +/// [`insert_at`]: method@Self::insert_at +/// [`Key`]: struct@Key /// [`Stream`]: https://docs.rs/futures/0.1/futures/stream/trait.Stream.html -/// [`poll`]: #method.poll -/// [`Stream::poll`]: #method.poll -/// [`Timer`]: ../struct.Timer.html -/// [`slab`]: https://docs.rs/slab -/// [`capacity`]: #method.capacity -/// [`reserve`]: #method.reserve +/// [`poll`]: method@Self::poll +/// [`Stream::poll`]: method@Self::poll +/// [`DelayQueue`]: struct@DelayQueue +/// [`delay_for`]: fn@super::delay_for +/// [`slab`]: slab +/// [`capacity`]: method@Self::capacity +/// [`reserve`]: method@Self::reserve #[derive(Debug)] pub struct DelayQueue { /// Stores data associated with entries @@ -148,7 +148,7 @@ pub struct DelayQueue { /// /// Values are returned by [`DelayQueue::poll`]. /// -/// [`DelayQueue::poll`]: struct.DelayQueue.html#method.poll +/// [`DelayQueue::poll`]: method@DelayQueue::poll #[derive(Debug)] pub struct Expired { /// The data stored in the queue @@ -166,8 +166,8 @@ pub struct Expired { /// Instances of `Key` are returned by [`DelayQueue::insert`]. See [`DelayQueue`] /// documentation for more details. /// -/// [`DelayQueue`]: struct.DelayQueue.html -/// [`DelayQueue::insert`]: struct.DelayQueue.html#method.insert +/// [`DelayQueue`]: struct@DelayQueue +/// [`DelayQueue::insert`]: method@DelayQueue::insert #[derive(Debug, Clone)] pub struct Key { index: usize, @@ -295,10 +295,10 @@ impl DelayQueue { /// # } /// ``` /// - /// [`poll`]: #method.poll - /// [`remove`]: #method.remove - /// [`reset`]: #method.reset - /// [`Key`]: struct.Key.html + /// [`poll`]: method@Self::poll + /// [`remove`]: method@Self::remove + /// [`reset`]: method@Self::reset + /// [`Key`]: struct@Key /// [type]: # pub fn insert_at(&mut self, value: T, when: Instant) -> Key { assert!(self.slab.len() < MAX_ENTRIES, "max entries exceeded"); @@ -403,10 +403,10 @@ impl DelayQueue { /// # } /// ``` /// - /// [`poll`]: #method.poll - /// [`remove`]: #method.remove - /// [`reset`]: #method.reset - /// [`Key`]: struct.Key.html + /// [`poll`]: method@Self::poll + /// [`remove`]: method@Self::remove + /// [`reset`]: method@Self::reset + /// [`Key`]: struct@Key /// [type]: # pub fn insert(&mut self, value: T, timeout: Duration) -> Key { self.insert_at(value, Instant::now() + timeout) @@ -574,7 +574,7 @@ impl DelayQueue { /// /// Note that this method has no effect on the allocated capacity. /// - /// [`poll`]: #method.poll + /// [`poll`]: method@Self::poll /// /// # Examples /// diff --git a/tokio/src/time/driver/entry.rs b/tokio/src/time/driver/entry.rs index 20cc824019a..b375ee9d417 100644 --- a/tokio/src/time/driver/entry.rs +++ b/tokio/src/time/driver/entry.rs @@ -30,7 +30,8 @@ pub(crate) struct Entry { /// Timer internals. Using a weak pointer allows the timer to shutdown /// without all `Delay` instances having completed. /// - /// When `None`, the entry has not yet been linked with a timer instance. + /// When empty, it means that the entry has not yet been linked with a + /// timer instance. inner: Weak, /// Tracks the entry state. This value contains the following information: @@ -266,8 +267,9 @@ impl Entry { let when = inner.normalize_deadline(deadline); let elapsed = inner.elapsed(); + let next = if when <= elapsed { ELAPSED } else { when }; + let mut curr = entry.state.load(SeqCst); - let mut notify; loop { // In these two cases, there is no work to do when resetting the @@ -278,16 +280,6 @@ impl Entry { return; } - let next; - - if when <= elapsed { - next = ELAPSED; - notify = !is_elapsed(curr); - } else { - next = when; - notify = true; - } - let actual = entry.state.compare_and_swap(curr, next, SeqCst); if curr == actual { @@ -297,7 +289,16 @@ impl Entry { curr = actual; } - if notify { + // If the state has transitioned to 'elapsed' then wake the task as + // this entry is ready to be polled. + if !is_elapsed(curr) && is_elapsed(next) { + entry.waker.wake(); + } + + // The driver tracks all non-elapsed entries; notify the driver that it + // should update its state for this entry unless the entry had already + // elapsed and remains elapsed. + if !is_elapsed(curr) || !is_elapsed(next) { let _ = inner.queue(entry); } } diff --git a/tokio/src/time/driver/mod.rs b/tokio/src/time/driver/mod.rs index 4616816f3f4..554042fccdc 100644 --- a/tokio/src/time/driver/mod.rs +++ b/tokio/src/time/driver/mod.rs @@ -26,29 +26,29 @@ use std::sync::Arc; use std::usize; use std::{cmp, fmt}; -/// Time implementation that drives [`Delay`], [`Interval`], and [`Timeout`]. +/// Time implementation that drives [`Delay`][delay], [`Interval`][interval], and [`Timeout`][timeout]. /// /// A `Driver` instance tracks the state necessary for managing time and -/// notifying the [`Delay`] instances once their deadlines are reached. +/// notifying the [`Delay`][delay] instances once their deadlines are reached. /// -/// It is expected that a single instance manages many individual [`Delay`] +/// It is expected that a single instance manages many individual [`Delay`][delay] /// instances. The `Driver` implementation is thread-safe and, as such, is able /// to handle callers from across threads. /// -/// After creating the `Driver` instance, the caller must repeatedly call -/// [`turn`]. The time driver will perform no work unless [`turn`] is called -/// repeatedly. +/// After creating the `Driver` instance, the caller must repeatedly call `park` +/// or `park_timeout`. The time driver will perform no work unless `park` or +/// `park_timeout` is called repeatedly. /// /// The driver has a resolution of one millisecond. Any unit of time that falls /// between milliseconds are rounded up to the next millisecond. /// -/// When an instance is dropped, any outstanding [`Delay`] instance that has not +/// When an instance is dropped, any outstanding [`Delay`][delay] instance that has not /// elapsed will be notified with an error. At this point, calling `poll` on the -/// [`Delay`] instance will result in `Err` being returned. +/// [`Delay`][delay] instance will result in panic. /// /// # Implementation /// -/// THe time driver is based on the [paper by Varghese and Lauck][paper]. +/// The time driver is based on the [paper by Varghese and Lauck][paper]. /// /// A hashed timing wheel is a vector of slots, where each slot handles a time /// slice. As time progresses, the timer walks over the slot for the current @@ -73,9 +73,14 @@ use std::{cmp, fmt}; /// When the timer processes entries at level zero, it will notify all the /// `Delay` instances as their deadlines have been reached. For all higher /// levels, all entries will be redistributed across the wheel at the next level -/// down. Eventually, as time progresses, entries will [`Delay`] instances will +/// down. Eventually, as time progresses, entries will [`Delay`][delay] instances will /// either be canceled (dropped) or their associated entries will reach level /// zero and be notified. +/// +/// [paper]: http://www.cs.columbia.edu/~nahum/w6998/papers/ton97-timing-wheels.pdf +/// [delay]: crate::time::Delay +/// [timeout]: crate::time::Timeout +/// [interval]: crate::time::Interval #[derive(Debug)] pub(crate) struct Driver { /// Shared state @@ -119,7 +124,7 @@ where T: Park, { /// Creates a new `Driver` instance that uses `park` to block the current - /// thread and `now` to get the current `Instant`. + /// thread and `clock` to get the current `Instant`. /// /// Specifying the source of time is useful when testing. pub(crate) fn new(park: T, clock: Clock) -> Driver { diff --git a/tokio/src/time/driver/registration.rs b/tokio/src/time/driver/registration.rs index b77357e7353..3a0b34501b0 100644 --- a/tokio/src/time/driver/registration.rs +++ b/tokio/src/time/driver/registration.rs @@ -40,9 +40,12 @@ impl Registration { pub(crate) fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); - self.entry.poll_elapsed(cx) + self.entry.poll_elapsed(cx).map(move |r| { + coop.made_progress(); + r + }) } } diff --git a/tokio/src/time/interval.rs b/tokio/src/time/interval.rs index 090e2d1f05a..1fa21e66418 100644 --- a/tokio/src/time/interval.rs +++ b/tokio/src/time/interval.rs @@ -33,6 +33,37 @@ use std::task::{Context, Poll}; /// // approximately 20ms have elapsed. /// } /// ``` +/// +/// A simple example using `interval` to execute a task every two seconds. +/// +/// The difference between `interval` and [`delay_for`] is that an `interval` +/// measures the time since the last tick, which means that `.tick().await` +/// may wait for a shorter time than the duration specified for the interval +/// if some time has passed between calls to `.tick().await`. +/// +/// If the tick in the example below was replaced with [`delay_for`], the task +/// would only be executed once every three seconds, and not every two +/// seconds. +/// +/// ``` +/// use tokio::time; +/// +/// async fn task_that_takes_a_second() { +/// println!("hello"); +/// time::delay_for(time::Duration::from_secs(1)).await +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let mut interval = time::interval(time::Duration::from_secs(2)); +/// for _i in 0..5 { +/// interval.tick().await; +/// task_that_takes_a_second().await; +/// } +/// } +/// ``` +/// +/// [`delay_for`]: crate::time::delay_for() pub fn interval(period: Duration) -> Interval { assert!(period > Duration::new(0, 0), "`period` must be non-zero."); diff --git a/tokio/src/time/mod.rs b/tokio/src/time/mod.rs index 7070d6b2573..c532b2c175f 100644 --- a/tokio/src/time/mod.rs +++ b/tokio/src/time/mod.rs @@ -24,7 +24,7 @@ //! //! # Examples //! -//! Wait 100ms and print "Hello World!" +//! Wait 100ms and print "100 ms have elapsed" //! //! ``` //! use tokio::time::delay_for; @@ -58,6 +58,38 @@ //! } //! # } //! ``` +//! +//! A simple example using [`interval`] to execute a task every two seconds. +//! +//! The difference between [`interval`] and [`delay_for`] is that an +//! [`interval`] measures the time since the last tick, which means that +//! `.tick().await` may wait for a shorter time than the duration specified +//! for the interval if some time has passed between calls to `.tick().await`. +//! +//! If the tick in the example below was replaced with [`delay_for`], the task +//! would only be executed once every three seconds, and not every two +//! seconds. +//! +//! ``` +//! use tokio::time; +//! +//! async fn task_that_takes_a_second() { +//! println!("hello"); +//! time::delay_for(time::Duration::from_secs(1)).await +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let mut interval = time::interval(time::Duration::from_secs(2)); +//! for _i in 0..5 { +//! interval.tick().await; +//! task_that_takes_a_second().await; +//! } +//! } +//! ``` +//! +//! [`delay_for`]: crate::time::delay_for() +//! [`interval`]: crate::time::interval() mod clock; pub(crate) use self::clock::Clock; diff --git a/tokio/src/time/throttle.rs b/tokio/src/time/throttle.rs index 435bef63815..d53a6f76211 100644 --- a/tokio/src/time/throttle.rs +++ b/tokio/src/time/throttle.rs @@ -16,7 +16,7 @@ use pin_project_lite::pin_project; /// # Example /// /// Create a throttled stream. -/// ```rust,norun +/// ```rust,no_run /// use std::time::Duration; /// use tokio::stream::StreamExt; /// use tokio::time::throttle; diff --git a/tokio/src/time/timeout.rs b/tokio/src/time/timeout.rs index 7a37254404d..efc3dc5c069 100644 --- a/tokio/src/time/timeout.rs +++ b/tokio/src/time/timeout.rs @@ -2,10 +2,11 @@ //! //! See [`Timeout`] documentation for more details. //! -//! [`Timeout`]: struct.Timeout.html +//! [`Timeout`]: struct@Timeout use crate::time::{delay_until, Delay, Duration, Instant}; +use pin_project_lite::pin_project; use std::fmt; use std::future::Future; use std::pin::Pin; @@ -99,12 +100,16 @@ where } } -/// Future returned by [`timeout`](timeout) and [`timeout_at`](timeout_at). -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[derive(Debug)] -pub struct Timeout { - value: T, - delay: Delay, +pin_project! { + /// Future returned by [`timeout`](timeout) and [`timeout_at`](timeout_at). + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[derive(Debug)] + pub struct Timeout { + #[pin] + value: T, + #[pin] + delay: Delay, + } } /// Error returned by `Timeout`. @@ -146,24 +151,18 @@ where { type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - // First, try polling the future + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let me = self.project(); - // Safety: we never move `self.value` - unsafe { - let p = self.as_mut().map_unchecked_mut(|me| &mut me.value); - if let Poll::Ready(v) = p.poll(cx) { - return Poll::Ready(Ok(v)); - } + // First, try polling the future + if let Poll::Ready(v) = me.value.poll(cx) { + return Poll::Ready(Ok(v)); } // Now check the timer - // Safety: X_X! - unsafe { - match self.map_unchecked_mut(|me| &mut me.delay).poll(cx) { - Poll::Ready(()) => Poll::Ready(Err(Elapsed(()))), - Poll::Pending => Poll::Pending, - } + match me.delay.poll(cx) { + Poll::Ready(()) => Poll::Ready(Err(Elapsed(()))), + Poll::Pending => Poll::Pending, } } } diff --git a/tokio/src/util/intrusive_double_linked_list.rs b/tokio/src/util/intrusive_double_linked_list.rs new file mode 100644 index 00000000000..083fa31d3ec --- /dev/null +++ b/tokio/src/util/intrusive_double_linked_list.rs @@ -0,0 +1,788 @@ +//! An intrusive double linked list of data + +#![allow(dead_code, unreachable_pub)] + +use core::{ + marker::PhantomPinned, + ops::{Deref, DerefMut}, + ptr::NonNull, +}; + +/// A node which carries data of type `T` and is stored in an intrusive list +#[derive(Debug)] +pub struct ListNode { + /// The previous node in the list. `None` if there is no previous node. + prev: Option>>, + /// The next node in the list. `None` if there is no previous node. + next: Option>>, + /// The data which is associated to this list item + data: T, + /// Prevents `ListNode`s from being `Unpin`. They may never be moved, since + /// the list semantics require addresses to be stable. + _pin: PhantomPinned, +} + +impl ListNode { + /// Creates a new node with the associated data + pub fn new(data: T) -> ListNode { + Self { + prev: None, + next: None, + data, + _pin: PhantomPinned, + } + } +} + +impl Deref for ListNode { + type Target = T; + + fn deref(&self) -> &T { + &self.data + } +} + +impl DerefMut for ListNode { + fn deref_mut(&mut self) -> &mut T { + &mut self.data + } +} + +/// An intrusive linked list of nodes, where each node carries associated data +/// of type `T`. +#[derive(Debug)] +pub struct LinkedList { + head: Option>>, + tail: Option>>, +} + +impl LinkedList { + /// Creates an empty linked list + pub fn new() -> Self { + LinkedList:: { + head: None, + tail: None, + } + } + + /// Adds a node at the front of the linked list. + /// Safety: This function is only safe as long as `node` is guaranteed to + /// get removed from the list before it gets moved or dropped. + /// In addition to this `node` may not be added to another other list before + /// it is removed from the current one. + pub unsafe fn add_front(&mut self, node: &mut ListNode) { + node.next = self.head; + node.prev = None; + if let Some(mut head) = self.head { + head.as_mut().prev = Some(node.into()) + }; + self.head = Some(node.into()); + if self.tail.is_none() { + self.tail = Some(node.into()); + } + } + + /// Inserts a node into the list in a way that the list keeps being sorted. + /// Safety: This function is only safe as long as `node` is guaranteed to + /// get removed from the list before it gets moved or dropped. + /// In addition to this `node` may not be added to another other list before + /// it is removed from the current one. + pub unsafe fn add_sorted(&mut self, node: &mut ListNode) + where + T: PartialOrd, + { + if self.head.is_none() { + // First node in the list + self.head = Some(node.into()); + self.tail = Some(node.into()); + return; + } + + let mut prev: Option>> = None; + let mut current = self.head; + + while let Some(mut current_node) = current { + if node.data < current_node.as_ref().data { + // Need to insert before the current node + current_node.as_mut().prev = Some(node.into()); + match prev { + Some(mut prev) => { + prev.as_mut().next = Some(node.into()); + } + None => { + // We are inserting at the beginning of the list + self.head = Some(node.into()); + } + } + node.next = current; + node.prev = prev; + return; + } + prev = current; + current = current_node.as_ref().next; + } + + // We looped through the whole list and the nodes data is bigger or equal + // than everything we found up to now. + // Insert at the end. Since we checked before that the list isn't empty, + // tail always has a value. + node.prev = self.tail; + node.next = None; + self.tail.as_mut().unwrap().as_mut().next = Some(node.into()); + self.tail = Some(node.into()); + } + + /// Returns the first node in the linked list without removing it from the list + /// The function is only safe as long as valid pointers are stored inside + /// the linked list. + /// The returned pointer is only guaranteed to be valid as long as the list + /// is not mutated + pub fn peek_first(&self) -> Option<&mut ListNode> { + // Safety: When the node was inserted it was promised that it is alive + // until it gets removed from the list. + // The returned node has a pointer which constrains it to the lifetime + // of the list. This is ok, since the Node is supposed to outlive + // its insertion in the list. + unsafe { + self.head + .map(|mut node| &mut *(node.as_mut() as *mut ListNode)) + } + } + + /// Returns the last node in the linked list without removing it from the list + /// The function is only safe as long as valid pointers are stored inside + /// the linked list. + /// The returned pointer is only guaranteed to be valid as long as the list + /// is not mutated + pub fn peek_last(&self) -> Option<&mut ListNode> { + // Safety: When the node was inserted it was promised that it is alive + // until it gets removed from the list. + // The returned node has a pointer which constrains it to the lifetime + // of the list. This is ok, since the Node is supposed to outlive + // its insertion in the list. + unsafe { + self.tail + .map(|mut node| &mut *(node.as_mut() as *mut ListNode)) + } + } + + /// Removes the first node from the linked list + pub fn remove_first(&mut self) -> Option<&mut ListNode> { + #![allow(clippy::debug_assert_with_mut_call)] + + // Safety: When the node was inserted it was promised that it is alive + // until it gets removed from the list + unsafe { + let mut head = self.head?; + self.head = head.as_mut().next; + + let first_ref = head.as_mut(); + match first_ref.next { + None => { + // This was the only node in the list + debug_assert_eq!(Some(first_ref.into()), self.tail); + self.tail = None; + } + Some(mut next) => { + next.as_mut().prev = None; + } + } + + first_ref.prev = None; + first_ref.next = None; + Some(&mut *(first_ref as *mut ListNode)) + } + } + + /// Removes the last node from the linked list and returns it + pub fn remove_last(&mut self) -> Option<&mut ListNode> { + #![allow(clippy::debug_assert_with_mut_call)] + + // Safety: When the node was inserted it was promised that it is alive + // until it gets removed from the list + unsafe { + let mut tail = self.tail?; + self.tail = tail.as_mut().prev; + + let last_ref = tail.as_mut(); + match last_ref.prev { + None => { + // This was the last node in the list + debug_assert_eq!(Some(last_ref.into()), self.head); + self.head = None; + } + Some(mut prev) => { + prev.as_mut().next = None; + } + } + + last_ref.prev = None; + last_ref.next = None; + Some(&mut *(last_ref as *mut ListNode)) + } + } + + /// Returns whether the linked list doesn not contain any node + pub fn is_empty(&self) -> bool { + if self.head.is_some() { + return false; + } + + debug_assert!(self.tail.is_none()); + true + } + + /// Removes the given `node` from the linked list. + /// Returns whether the `node` was removed. + /// It is also only safe if it is known that the `node` is either part of this + /// list, or of no list at all. If `node` is part of another list, the + /// behavior is undefined. + pub unsafe fn remove(&mut self, node: &mut ListNode) -> bool { + #![allow(clippy::debug_assert_with_mut_call)] + + match node.prev { + None => { + // This might be the first node in the list. If it is not, the + // node is not in the list at all. Since our precondition is that + // the node must either be in this list or in no list, we check that + // the node is really in no list. + if self.head != Some(node.into()) { + debug_assert!(node.next.is_none()); + return false; + } + self.head = node.next; + } + Some(mut prev) => { + debug_assert_eq!(prev.as_ref().next, Some(node.into())); + prev.as_mut().next = node.next; + } + } + + match node.next { + None => { + // This must be the last node in our list. Otherwise the list + // is inconsistent. + debug_assert_eq!(self.tail, Some(node.into())); + self.tail = node.prev; + } + Some(mut next) => { + debug_assert_eq!(next.as_mut().prev, Some(node.into())); + next.as_mut().prev = node.prev; + } + } + + node.next = None; + node.prev = None; + + true + } + + /// Drains the list iby calling a callback on each list node + /// + /// The method does not return an iterator since stopping or deferring + /// draining the list is not permitted. If the method would push nodes to + /// an iterator we could not guarantee that the nodes do not get utilized + /// after having been removed from the list anymore. + pub fn drain(&mut self, mut func: F) + where + F: FnMut(&mut ListNode), + { + let mut current = self.head; + self.head = None; + self.tail = None; + + while let Some(mut node) = current { + // Safety: The nodes have not been removed from the list yet and must + // therefore contain valid data. The nodes can also not be added to + // the list again during iteration, since the list is mutably borrowed. + unsafe { + let node_ref = node.as_mut(); + current = node_ref.next; + + node_ref.next = None; + node_ref.prev = None; + + // Note: We do not reset the pointers from the next element in the + // list to the current one since we will iterate over the whole + // list anyway, and therefore clean up all pointers. + + func(node_ref); + } + } + } + + /// Drains the list in reverse order by calling a callback on each list node + /// + /// The method does not return an iterator since stopping or deferring + /// draining the list is not permitted. If the method would push nodes to + /// an iterator we could not guarantee that the nodes do not get utilized + /// after having been removed from the list anymore. + pub fn reverse_drain(&mut self, mut func: F) + where + F: FnMut(&mut ListNode), + { + let mut current = self.tail; + self.head = None; + self.tail = None; + + while let Some(mut node) = current { + // Safety: The nodes have not been removed from the list yet and must + // therefore contain valid data. The nodes can also not be added to + // the list again during iteration, since the list is mutably borrowed. + unsafe { + let node_ref = node.as_mut(); + current = node_ref.prev; + + node_ref.next = None; + node_ref.prev = None; + + // Note: We do not reset the pointers from the next element in the + // list to the current one since we will iterate over the whole + // list anyway, and therefore clean up all pointers. + + func(node_ref); + } + } + } +} + +#[cfg(all(test, feature = "std"))] // Tests make use of Vec at the moment +mod tests { + use super::*; + + fn collect_list(mut list: LinkedList) -> Vec { + let mut result = Vec::new(); + list.drain(|node| { + result.push(**node); + }); + result + } + + fn collect_reverse_list(mut list: LinkedList) -> Vec { + let mut result = Vec::new(); + list.reverse_drain(|node| { + result.push(**node); + }); + result + } + + unsafe fn add_nodes(list: &mut LinkedList, nodes: &mut [&mut ListNode]) { + for node in nodes.iter_mut() { + list.add_front(node); + } + } + + unsafe fn assert_clean(node: &mut ListNode) { + assert!(node.next.is_none()); + assert!(node.prev.is_none()); + } + + #[test] + fn insert_and_iterate() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut setup = |list: &mut LinkedList| { + assert_eq!(true, list.is_empty()); + list.add_front(&mut c); + assert_eq!(31, **list.peek_first().unwrap()); + assert_eq!(false, list.is_empty()); + list.add_front(&mut b); + assert_eq!(7, **list.peek_first().unwrap()); + list.add_front(&mut a); + assert_eq!(5, **list.peek_first().unwrap()); + }; + + let mut list = LinkedList::new(); + setup(&mut list); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31].to_vec(), items); + + let mut list = LinkedList::new(); + setup(&mut list); + let items: Vec = collect_reverse_list(list); + assert_eq!([31, 7, 5].to_vec(), items); + } + } + + #[test] + fn add_sorted() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + let mut d = ListNode::new(99); + + let mut list = LinkedList::new(); + list.add_sorted(&mut a); + let items: Vec = collect_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + list.add_sorted(&mut a); + let items: Vec = collect_reverse_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]); + list.add_sorted(&mut a); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]); + list.add_sorted(&mut a); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]); + list.add_sorted(&mut b); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]); + list.add_sorted(&mut b); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]); + list.add_sorted(&mut c); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]); + list.add_sorted(&mut c); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + list.add_sorted(&mut d); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + list.add_sorted(&mut d); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + } + } + + #[test] + fn drain_and_collect() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + + let taken_items: Vec = collect_list(list); + assert_eq!([5, 7, 31].to_vec(), taken_items); + } + } + + #[test] + fn peek_last() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + + let last = list.peek_last(); + assert_eq!(31, **last.unwrap()); + list.remove_last(); + + let last = list.peek_last(); + assert_eq!(7, **last.unwrap()); + list.remove_last(); + + let last = list.peek_last(); + assert_eq!(5, **last.unwrap()); + list.remove_last(); + + let last = list.peek_last(); + assert!(last.is_none()); + } + } + + #[test] + fn remove_first() { + unsafe { + // We iterate forward and backwards through the manipulated lists + // to make sure pointers in both directions are still ok. + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_list(list); + assert_eq!([7, 31].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert_eq!([31, 7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_list(list); + assert_eq!([7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert_eq!([7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(list.is_empty()); + let items: Vec = collect_list(list); + assert!(items.is_empty()); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert!(items.is_empty()); + } + } + + #[test] + fn remove_last() { + unsafe { + // We iterate forward and backwards through the manipulated lists + // to make sure pointers in both directions are still ok. + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_list(list); + assert_eq!([5, 7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert_eq!([7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(list.is_empty()); + let items: Vec = collect_list(list); + assert!(items.is_empty()); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert!(items.is_empty()); + } + } + + #[test] + fn remove_by_address() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + { + // Remove first + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean((&mut a).into()); + // a should be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(Some((&mut b).into()), list.head); + assert_eq!(Some((&mut c).into()), b.next); + assert_eq!(Some((&mut b).into()), c.prev); + let items: Vec = collect_list(list); + assert_eq!([7, 31].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean((&mut a).into()); + // a should be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(Some((&mut c).into()), b.next); + assert_eq!(Some((&mut b).into()), c.prev); + let items: Vec = collect_reverse_list(list); + assert_eq!([31, 7].to_vec(), items); + } + + { + // Remove middle + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean((&mut b).into()); + assert_eq!(Some((&mut c).into()), a.next); + assert_eq!(Some((&mut a).into()), c.prev); + let items: Vec = collect_list(list); + assert_eq!([5, 31].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean((&mut b).into()); + assert_eq!(Some((&mut c).into()), a.next); + assert_eq!(Some((&mut a).into()), c.prev); + let items: Vec = collect_reverse_list(list); + assert_eq!([31, 5].to_vec(), items); + } + + { + // Remove last + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut c)); + assert_clean((&mut c).into()); + assert!(b.next.is_none()); + assert_eq!(Some((&mut b).into()), list.tail); + let items: Vec = collect_list(list); + assert_eq!([5, 7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut c)); + assert_clean((&mut c).into()); + assert!(b.next.is_none()); + assert_eq!(Some((&mut b).into()), list.tail); + let items: Vec = collect_reverse_list(list); + assert_eq!([7, 5].to_vec(), items); + } + + { + // Remove first of two + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean((&mut a).into()); + // a should be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(Some((&mut b).into()), list.head); + assert_eq!(Some((&mut b).into()), list.tail); + assert!(b.next.is_none()); + assert!(b.prev.is_none()); + let items: Vec = collect_list(list); + assert_eq!([7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean((&mut a).into()); + // a should be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(Some((&mut b).into()), list.head); + assert_eq!(Some((&mut b).into()), list.tail); + assert!(b.next.is_none()); + assert!(b.prev.is_none()); + let items: Vec = collect_reverse_list(list); + assert_eq!([7].to_vec(), items); + } + + { + // Remove last of two + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean((&mut b).into()); + assert_eq!(Some((&mut a).into()), list.head); + assert_eq!(Some((&mut a).into()), list.tail); + assert!(a.next.is_none()); + assert!(a.prev.is_none()); + let items: Vec = collect_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean((&mut b).into()); + assert_eq!(Some((&mut a).into()), list.head); + assert_eq!(Some((&mut a).into()), list.tail); + assert!(a.next.is_none()); + assert!(a.prev.is_none()); + let items: Vec = collect_reverse_list(list); + assert_eq!([5].to_vec(), items); + } + + { + // Remove last item + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean((&mut a).into()); + assert!(list.head.is_none()); + assert!(list.tail.is_none()); + let items: Vec = collect_list(list); + assert!(items.is_empty()); + } + + { + // Remove missing + let mut list = LinkedList::new(); + list.add_front(&mut b); + list.add_front(&mut a); + assert_eq!(false, list.remove(&mut c)); + } + } + } +} diff --git a/tokio/src/util/mod.rs b/tokio/src/util/mod.rs index a093395c020..6dda08ca411 100644 --- a/tokio/src/util/mod.rs +++ b/tokio/src/util/mod.rs @@ -19,6 +19,10 @@ cfg_rt_threaded! { pub(crate) use try_lock::TryLock; } +pub(crate) mod trace; + #[cfg(any(feature = "macros", feature = "stream"))] #[cfg_attr(not(feature = "macros"), allow(unreachable_pub))] pub use rand::thread_rng_n; + +pub(crate) mod intrusive_double_linked_list; diff --git a/tokio/src/util/trace.rs b/tokio/src/util/trace.rs new file mode 100644 index 00000000000..d8c6120d97c --- /dev/null +++ b/tokio/src/util/trace.rs @@ -0,0 +1,57 @@ +cfg_trace! { + cfg_rt_core! { + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + use pin_project_lite::pin_project; + + use tracing::Span; + + pin_project! { + /// A future that has been instrumented with a `tracing` span. + #[derive(Debug, Clone)] + pub(crate) struct Instrumented { + #[pin] + inner: T, + span: Span, + } + } + + impl Future for Instrumented { + type Output = T::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let _enter = this.span.enter(); + this.inner.poll(cx) + } + } + + impl Instrumented { + pub(crate) fn new(inner: T, span: Span) -> Self { + Self { inner, span } + } + } + + #[inline] + pub(crate) fn task(task: F, kind: &'static str) -> Instrumented { + let span = tracing::trace_span!( + target: "tokio::task", + "task", + %kind, + future = %std::any::type_name::(), + ); + Instrumented::new(task, span) + } + } +} + +cfg_not_trace! { + cfg_rt_core! { + #[inline] + pub(crate) fn task(task: F, _: &'static str) -> F { + // nop + task + } + } +} diff --git a/tokio/tests/async_send_sync.rs b/tokio/tests/async_send_sync.rs index f50b41a0890..afe053b1010 100644 --- a/tokio/tests/async_send_sync.rs +++ b/tokio/tests/async_send_sync.rs @@ -41,6 +41,44 @@ macro_rules! into_todo { x }}; } +macro_rules! assert_value { + ($type:ty: Send & Sync) => { + #[allow(unreachable_code)] + #[allow(unused_variables)] + const _: fn() = || { + let f: $type = todo!(); + require_send(&f); + require_sync(&f); + }; + }; + ($type:ty: !Send & Sync) => { + #[allow(unreachable_code)] + #[allow(unused_variables)] + const _: fn() = || { + let f: $type = todo!(); + AmbiguousIfSend::some_item(&f); + require_sync(&f); + }; + }; + ($type:ty: Send & !Sync) => { + #[allow(unreachable_code)] + #[allow(unused_variables)] + const _: fn() = || { + let f: $type = todo!(); + require_send(&f); + AmbiguousIfSync::some_item(&f); + }; + }; + ($type:ty: !Send & !Sync) => { + #[allow(unreachable_code)] + #[allow(unused_variables)] + const _: fn() = || { + let f: $type = todo!(); + AmbiguousIfSend::some_item(&f); + AmbiguousIfSync::some_item(&f); + }; + }; +} macro_rules! async_assert_fn { ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): Send & Sync) => { #[allow(unreachable_code)] @@ -165,6 +203,9 @@ async_assert_fn!(tokio::sync::Barrier::wait(_): Send & Sync); async_assert_fn!(tokio::sync::Mutex::lock(_): Send & Sync); async_assert_fn!(tokio::sync::Mutex>::lock(_): Send & Sync); async_assert_fn!(tokio::sync::Mutex>::lock(_): !Send & !Sync); +async_assert_fn!(tokio::sync::Mutex::lock_owned(_): Send & Sync); +async_assert_fn!(tokio::sync::Mutex>::lock_owned(_): Send & Sync); +async_assert_fn!(tokio::sync::Mutex>::lock_owned(_): !Send & !Sync); async_assert_fn!(tokio::sync::Notify::notified(_): Send & !Sync); async_assert_fn!(tokio::sync::RwLock::read(_): Send & Sync); async_assert_fn!(tokio::sync::RwLock::write(_): Send & Sync); @@ -206,6 +247,7 @@ async_assert_fn!(tokio::task::LocalKey>::scope(_, Rc, BoxFutureSync async_assert_fn!(tokio::task::LocalKey>::scope(_, Rc, BoxFutureSend<()>): !Send & !Sync); async_assert_fn!(tokio::task::LocalKey>::scope(_, Rc, BoxFuture<()>): !Send & !Sync); async_assert_fn!(tokio::task::LocalSet::run_until(_, BoxFutureSync<()>): !Send & !Sync); +assert_value!(tokio::task::LocalSet: !Send & !Sync); async_assert_fn!(tokio::time::advance(Duration): Send & Sync); async_assert_fn!(tokio::time::delay_for(Duration): Send & Sync); @@ -217,3 +259,6 @@ async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSync<()>): Send & Syn async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSend<()>): Send & !Sync); async_assert_fn!(tokio::time::timeout_at(Instant, BoxFuture<()>): !Send & !Sync); async_assert_fn!(tokio::time::Interval::tick(_): Send & Sync); + +#[cfg(tokio_unstable)] +assert_value!(tokio::sync::CancellationToken: Send & Sync); diff --git a/tokio/tests/fs_dir.rs b/tokio/tests/fs_dir.rs index eaff59da4f9..6355ef05fcb 100644 --- a/tokio/tests/fs_dir.rs +++ b/tokio/tests/fs_dir.rs @@ -2,7 +2,7 @@ #![cfg(feature = "full")] use tokio::fs; -use tokio_test::assert_ok; +use tokio_test::{assert_err, assert_ok}; use std::sync::{Arc, Mutex}; use tempfile::tempdir; @@ -28,6 +28,23 @@ async fn create_all() { assert!(new_dir_2.is_dir()); } +#[tokio::test] +async fn build_dir() { + let base_dir = tempdir().unwrap(); + let new_dir = base_dir.path().join("foo").join("bar"); + let new_dir_2 = new_dir.clone(); + + assert_ok!(fs::DirBuilder::new().recursive(true).create(new_dir).await); + + assert!(new_dir_2.is_dir()); + assert_err!( + fs::DirBuilder::new() + .recursive(false) + .create(new_dir_2) + .await + ); +} + #[tokio::test] async fn remove() { let base_dir = tempdir().unwrap(); diff --git a/tokio/tests/fs_file_mocked.rs b/tokio/tests/fs_file_mocked.rs index 0c5722404ea..2e7e8b7cf48 100644 --- a/tokio/tests/fs_file_mocked.rs +++ b/tokio/tests/fs_file_mocked.rs @@ -257,27 +257,27 @@ fn flush_while_idle() { #[test] fn read_with_buffer_larger_than_max() { // Chunks - let a = 16 * 1024; - let b = a * 2; - let c = a * 3; - let d = a * 4; + let chunk_a = 16 * 1024; + let chunk_b = chunk_a * 2; + let chunk_c = chunk_a * 3; + let chunk_d = chunk_a * 4; - assert_eq!(d / 1024, 64); + assert_eq!(chunk_d / 1024, 64); let mut data = vec![]; - for i in 0..(d - 1) { + for i in 0..(chunk_d - 1) { data.push((i % 151) as u8); } let (mock, file) = sys::File::mock(); - mock.read(&data[0..a]) - .read(&data[a..b]) - .read(&data[b..c]) - .read(&data[c..]); + mock.read(&data[0..chunk_a]) + .read(&data[chunk_a..chunk_b]) + .read(&data[chunk_b..chunk_c]) + .read(&data[chunk_c..]); let mut file = File::from_std(file); - let mut actual = vec![0; d]; + let mut actual = vec![0; chunk_d]; let mut pos = 0; while pos < data.len() { @@ -288,7 +288,7 @@ fn read_with_buffer_larger_than_max() { assert!(t.is_woken()); let n = assert_ready_ok!(t.poll()); - assert!(n <= a); + assert!(n <= chunk_a); pos += n; } @@ -300,23 +300,23 @@ fn read_with_buffer_larger_than_max() { #[test] fn write_with_buffer_larger_than_max() { // Chunks - let a = 16 * 1024; - let b = a * 2; - let c = a * 3; - let d = a * 4; + let chunk_a = 16 * 1024; + let chunk_b = chunk_a * 2; + let chunk_c = chunk_a * 3; + let chunk_d = chunk_a * 4; - assert_eq!(d / 1024, 64); + assert_eq!(chunk_d / 1024, 64); let mut data = vec![]; - for i in 0..(d - 1) { + for i in 0..(chunk_d - 1) { data.push((i % 151) as u8); } let (mock, file) = sys::File::mock(); - mock.write(&data[0..a]) - .write(&data[a..b]) - .write(&data[b..c]) - .write(&data[c..]); + mock.write(&data[0..chunk_a]) + .write(&data[chunk_a..chunk_b]) + .write(&data[chunk_b..chunk_c]) + .write(&data[chunk_c..]); let mut file = File::from_std(file); @@ -325,17 +325,17 @@ fn write_with_buffer_larger_than_max() { let mut first = true; while !rem.is_empty() { - let mut t = task::spawn(file.write(rem)); + let mut task = task::spawn(file.write(rem)); if !first { - assert_pending!(t.poll()); + assert_pending!(task.poll()); pool::run_one(); - assert!(t.is_woken()); + assert!(task.is_woken()); } first = false; - let n = assert_ready_ok!(t.poll()); + let n = assert_ready_ok!(task.poll()); rem = &rem[n..]; } diff --git a/tokio/tests/io_read_line.rs b/tokio/tests/io_read_line.rs index 57ae37cef3e..15841c9b49d 100644 --- a/tokio/tests/io_read_line.rs +++ b/tokio/tests/io_read_line.rs @@ -1,8 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::AsyncBufReadExt; -use tokio_test::assert_ok; +use std::io::ErrorKind; +use tokio::io::{AsyncBufReadExt, BufReader, Error}; +use tokio_test::{assert_ok, io::Builder}; use std::io::Cursor; @@ -27,3 +28,80 @@ async fn read_line() { assert_eq!(n, 0); assert_eq!(buf, ""); } + +#[tokio::test] +async fn read_line_not_all_ready() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"ld\nFizzBuz") + .read(b"z\n1\n2") + .build(); + + let mut read = BufReader::new(mock); + + let mut line = "We say ".to_string(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, "Hello World\n".len()); + assert_eq!(line.as_str(), "We say Hello World\n"); + + line = "I solve ".to_string(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, "FizzBuzz\n".len()); + assert_eq!(line.as_str(), "I solve FizzBuzz\n"); + + line.clear(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, 2); + assert_eq!(line.as_str(), "1\n"); + + line.clear(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, 1); + assert_eq!(line.as_str(), "2"); +} + +#[tokio::test] +async fn read_line_invalid_utf8() { + let mock = Builder::new().read(b"Hello Wor\xffld.\n").build(); + + let mut read = BufReader::new(mock); + + let mut line = "Foo".to_string(); + let err = read.read_line(&mut line).await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::InvalidData); + assert_eq!(err.to_string(), "stream did not contain valid UTF-8"); + assert_eq!(line.as_str(), "Foo"); +} + +#[tokio::test] +async fn read_line_fail() { + let mock = Builder::new() + .read(b"Hello Wor") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let mut line = "Foo".to_string(); + let err = read.read_line(&mut line).await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); + assert_eq!(line.as_str(), "FooHello Wor"); +} + +#[tokio::test] +async fn read_line_fail_and_utf8_fail() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"\xff\xff\xff") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let mut line = "Foo".to_string(); + let err = read.read_line(&mut line).await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); + assert_eq!(line.as_str(), "Foo"); +} diff --git a/tokio/tests/io_read_until.rs b/tokio/tests/io_read_until.rs index 4e0e0d10d34..61800a0d9c1 100644 --- a/tokio/tests/io_read_until.rs +++ b/tokio/tests/io_read_until.rs @@ -1,8 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::AsyncBufReadExt; -use tokio_test::assert_ok; +use std::io::ErrorKind; +use tokio::io::{AsyncBufReadExt, BufReader, Error}; +use tokio_test::{assert_ok, io::Builder}; #[tokio::test] async fn read_until() { @@ -21,3 +22,53 @@ async fn read_until() { assert_eq!(n, 0); assert_eq!(buf, []); } + +#[tokio::test] +async fn read_until_not_all_ready() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"ld#Fizz\xffBuz") + .read(b"z#1#2") + .build(); + + let mut read = BufReader::new(mock); + + let mut chunk = b"We say ".to_vec(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, b"Hello World#".len()); + assert_eq!(chunk, b"We say Hello World#"); + + chunk = b"I solve ".to_vec(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, b"Fizz\xffBuzz\n".len()); + assert_eq!(chunk, b"I solve Fizz\xffBuzz#"); + + chunk.clear(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, 2); + assert_eq!(chunk, b"1#"); + + chunk.clear(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, 1); + assert_eq!(chunk, b"2"); +} + +#[tokio::test] +async fn read_until_fail() { + let mock = Builder::new() + .read(b"Hello \xffWor") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let mut chunk = b"Foo".to_vec(); + let err = read + .read_until(b'#', &mut chunk) + .await + .expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); + assert_eq!(chunk, b"FooHello \xffWor"); +} diff --git a/tokio/tests/io_write_int.rs b/tokio/tests/io_write_int.rs new file mode 100644 index 00000000000..48a583d8c3f --- /dev/null +++ b/tokio/tests/io_write_int.rs @@ -0,0 +1,37 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[tokio::test] +async fn write_int_should_err_if_write_count_0() { + struct Wr {} + + impl AsyncWrite for Wr { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Ok(0).into() + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + } + + let mut wr = Wr {}; + + // should be ok just to test these 2, other cases actually expanded by same macro. + assert!(wr.write_i8(0).await.is_err()); + assert!(wr.write_i32(12).await.is_err()); +} diff --git a/tokio/tests/macros_join.rs b/tokio/tests/macros_join.rs index d9b748d9a7b..169e898f97d 100644 --- a/tokio/tests/macros_join.rs +++ b/tokio/tests/macros_join.rs @@ -1,3 +1,4 @@ +#![allow(clippy::blacklisted_name)] use tokio::sync::oneshot; use tokio_test::{assert_pending, assert_ready, task}; diff --git a/tokio/tests/macros_select.rs b/tokio/tests/macros_select.rs index c08e816a015..6f027f3bdfc 100644 --- a/tokio/tests/macros_select.rs +++ b/tokio/tests/macros_select.rs @@ -1,3 +1,4 @@ +#![allow(clippy::blacklisted_name)] use tokio::sync::{mpsc, oneshot}; use tokio::task; use tokio_test::{assert_ok, assert_pending, assert_ready}; diff --git a/tokio/tests/macros_test.rs b/tokio/tests/macros_test.rs new file mode 100644 index 00000000000..8e68b8a4417 --- /dev/null +++ b/tokio/tests/macros_test.rs @@ -0,0 +1,19 @@ +use tokio::test; + +#[test] +async fn test_macro_can_be_used_via_use() { + tokio::spawn(async { + assert_eq!(1 + 1, 2); + }) + .await + .unwrap(); +} + +#[tokio::test] +async fn test_macro_is_resilient_to_shadowing() { + tokio::spawn(async { + assert_eq!(1 + 1, 2); + }) + .await + .unwrap(); +} diff --git a/tokio/tests/macros_try_join.rs b/tokio/tests/macros_try_join.rs index faa55421a2b..a9251532664 100644 --- a/tokio/tests/macros_try_join.rs +++ b/tokio/tests/macros_try_join.rs @@ -1,3 +1,4 @@ +#![allow(clippy::blacklisted_name)] use tokio::sync::oneshot; use tokio_test::{assert_pending, assert_ready, task}; diff --git a/tokio/tests/read_to_string.rs b/tokio/tests/read_to_string.rs new file mode 100644 index 00000000000..db3fa1bf4bd --- /dev/null +++ b/tokio/tests/read_to_string.rs @@ -0,0 +1,49 @@ +use std::io; +use tokio::io::AsyncReadExt; +use tokio_test::io::Builder; + +#[tokio::test] +async fn to_string_does_not_truncate_on_utf8_error() { + let data = vec![0xff, 0xff, 0xff]; + + let mut s = "abc".to_string(); + + match AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s).await { + Ok(len) => panic!("Should fail: {} bytes.", len), + Err(err) if err.to_string() == "stream did not contain valid UTF-8" => {} + Err(err) => panic!("Fail: {}.", err), + } + + assert_eq!(s, "abc"); +} + +#[tokio::test] +async fn to_string_does_not_truncate_on_io_error() { + let mut mock = Builder::new() + .read(b"def") + .read_error(io::Error::new(io::ErrorKind::Other, "whoops")) + .build(); + let mut s = "abc".to_string(); + + match AsyncReadExt::read_to_string(&mut mock, &mut s).await { + Ok(len) => panic!("Should fail: {} bytes.", len), + Err(err) if err.to_string() == "whoops" => {} + Err(err) => panic!("Fail: {}.", err), + } + + assert_eq!(s, "abc"); +} + +#[tokio::test] +async fn to_string_appends() { + let data = b"def".to_vec(); + + let mut s = "abc".to_string(); + + let len = AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s) + .await + .unwrap(); + + assert_eq!(len, 3); + assert_eq!(s, "abcdef"); +} diff --git a/tokio/tests/rt_basic.rs b/tokio/tests/rt_basic.rs index b9e373b88f8..0885992d7d2 100644 --- a/tokio/tests/rt_basic.rs +++ b/tokio/tests/rt_basic.rs @@ -29,6 +29,7 @@ fn spawned_task_does_not_progress_without_block_on() { #[test] fn no_extra_poll() { + use pin_project_lite::pin_project; use std::pin::Pin; use std::sync::{ atomic::{AtomicUsize, Ordering::SeqCst}, @@ -37,9 +38,12 @@ fn no_extra_poll() { use std::task::{Context, Poll}; use tokio::stream::{Stream, StreamExt}; - struct TrackPolls { - npolls: Arc, - s: S, + pin_project! { + struct TrackPolls { + npolls: Arc, + #[pin] + s: S, + } } impl Stream for TrackPolls @@ -48,11 +52,9 @@ fn no_extra_poll() { { type Item = S::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // safety: we do not move s - let this = unsafe { self.get_unchecked_mut() }; + let this = self.project(); this.npolls.fetch_add(1, SeqCst); - // safety: we are pinned, and so is s - unsafe { Pin::new_unchecked(&mut this.s) }.poll_next(cx) + this.s.poll_next(cx) } } @@ -65,7 +67,7 @@ fn no_extra_poll() { let mut rt = rt(); - rt.spawn(async move { while let Some(_) = rx.next().await {} }); + rt.spawn(async move { while rx.next().await.is_some() {} }); rt.block_on(async { tokio::task::yield_now().await; }); diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs index 8dc0da3c5a1..71101d46cef 100644 --- a/tokio/tests/rt_common.rs +++ b/tokio/tests/rt_common.rs @@ -82,6 +82,18 @@ rt_test! { assert!(win); } + #[test] + fn block_on_handle_sync() { + let rt = rt(); + + let mut win = false; + rt.handle().block_on(async { + win = true; + }); + + assert!(win); + } + #[test] fn block_on_async() { let mut rt = rt(); @@ -100,6 +112,24 @@ rt_test! { assert_eq!(out, "ZOMG"); } + #[test] + fn block_on_handle_async() { + let rt = rt(); + + let out = rt.handle().block_on(async { + let (tx, rx) = oneshot::channel(); + + thread::spawn(move || { + thread::sleep(Duration::from_millis(50)); + tx.send("ZOMG").unwrap(); + }); + + assert_ok!(rx.await) + }); + + assert_eq!(out, "ZOMG"); + } + #[test] fn spawn_one_bg() { let mut rt = rt(); @@ -571,6 +601,19 @@ rt_test! { } #[test] + // IOCP requires setting the "max thread" concurrency value. The sane, + // default, is to set this to the number of cores. Threads that poll I/O + // become associated with the IOCP handle. Once those threads sleep for any + // reason (mutex), they yield their ownership. + // + // This test hits an edge case on windows where more threads than cores are + // created, none of those threads ever yield due to being at capacity, so + // IOCP gets "starved". + // + // For now, this is a very edge case that is probably not a real production + // concern. There also isn't a great/obvious solution to take. For now, the + // test is disabled. + #[cfg(not(windows))] fn io_driver_called_when_under_load() { let mut rt = rt(); diff --git a/tokio/tests/rt_threaded.rs b/tokio/tests/rt_threaded.rs index 9c95afd5ae2..b5ec96dec35 100644 --- a/tokio/tests/rt_threaded.rs +++ b/tokio/tests/rt_threaded.rs @@ -7,6 +7,7 @@ use tokio::runtime::{self, Runtime}; use tokio::sync::oneshot; use tokio_test::{assert_err, assert_ok}; +use futures::future::poll_fn; use std::future::Future; use std::pin::Pin; use std::sync::atomic::AtomicUsize; @@ -322,6 +323,74 @@ fn multi_threadpool() { done_rx.recv().unwrap(); } +// When `block_in_place` returns, it attempts to reclaim the yielded runtime +// worker. In this case, the remainder of the task is on the runtime worker and +// must take part in the cooperative task budgeting system. +// +// The test ensures that, when this happens, attempting to consume from a +// channel yields occasionally even if there are values ready to receive. +#[test] +fn coop_and_block_in_place() { + use tokio::sync::mpsc; + + let mut rt = tokio::runtime::Builder::new() + .threaded_scheduler() + // Setting max threads to 1 prevents another thread from claiming the + // runtime worker yielded as part of `block_in_place` and guarantees the + // same thread will reclaim the worker at the end of the + // `block_in_place` call. + .max_threads(1) + .build() + .unwrap(); + + rt.block_on(async move { + let (mut tx, mut rx) = mpsc::channel(1024); + + // Fill the channel + for _ in 0..1024 { + tx.send(()).await.unwrap(); + } + + drop(tx); + + tokio::spawn(async move { + // Block in place without doing anything + tokio::task::block_in_place(|| {}); + + // Receive all the values, this should trigger a `Pending` as the + // coop limit will be reached. + poll_fn(|cx| { + while let Poll::Ready(v) = { + tokio::pin! { + let fut = rx.recv(); + } + + Pin::new(&mut fut).poll(cx) + } { + if v.is_none() { + panic!("did not yield"); + } + } + + Poll::Ready(()) + }) + .await + }) + .await + .unwrap(); + }); +} + +// Testing this does not panic +#[test] +fn max_threads() { + let _rt = tokio::runtime::Builder::new() + .threaded_scheduler() + .max_threads(1) + .build() + .unwrap(); +} + fn rt() -> Runtime { Runtime::new().unwrap() } diff --git a/tokio/tests/stream_chain.rs b/tokio/tests/stream_chain.rs index 0e14618b49b..98461a8ccb3 100644 --- a/tokio/tests/stream_chain.rs +++ b/tokio/tests/stream_chain.rs @@ -69,3 +69,27 @@ async fn pending_first() { assert_eq!(stream.size_hint(), (0, None)); assert_eq!(None, assert_ready!(stream.poll_next())); } + +#[test] +fn size_overflow() { + struct Monster; + + impl tokio::stream::Stream for Monster { + type Item = (); + fn poll_next( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + panic!() + } + + fn size_hint(&self) -> (usize, Option) { + (usize::max_value(), Some(usize::max_value())) + } + } + + let m1 = Monster; + let m2 = Monster; + let m = m1.chain(m2); + assert_eq!(m.size_hint(), (usize::max_value(), None)); +} diff --git a/tokio/tests/stream_merge.rs b/tokio/tests/stream_merge.rs index f0168d72eec..45ecdcb6625 100644 --- a/tokio/tests/stream_merge.rs +++ b/tokio/tests/stream_merge.rs @@ -52,3 +52,27 @@ async fn merge_async_streams() { assert!(rx.is_woken()); assert_eq!(None, assert_ready!(rx.poll_next())); } + +#[test] +fn size_overflow() { + struct Monster; + + impl tokio::stream::Stream for Monster { + type Item = (); + fn poll_next( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + panic!() + } + + fn size_hint(&self) -> (usize, Option) { + (usize::max_value(), Some(usize::max_value())) + } + } + + let m1 = Monster; + let m2 = Monster; + let m = m1.merge(m2); + assert_eq!(m.size_hint(), (usize::max_value(), None)); +} diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index e9e7b366104..e37695b37d9 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -40,7 +40,16 @@ macro_rules! assert_lagged { }; } -trait AssertSend: Send {} +macro_rules! assert_closed { + ($e:expr) => { + match assert_err!($e) { + broadcast::TryRecvError::Closed => {} + _ => panic!("did not lag"), + } + }; +} + +trait AssertSend: Send + Sync {} impl AssertSend for broadcast::Sender {} impl AssertSend for broadcast::Receiver {} @@ -81,10 +90,13 @@ fn send_two_recv() { } #[tokio::test] -async fn send_recv_stream() { +async fn send_recv_into_stream_ready() { use tokio::stream::StreamExt; - let (tx, mut rx) = broadcast::channel::(8); + let (tx, rx) = broadcast::channel::(8); + tokio::pin! { + let rx = rx.into_stream(); + } assert_ok!(tx.send(1)); assert_ok!(tx.send(2)); @@ -97,6 +109,26 @@ async fn send_recv_stream() { assert_eq!(None, rx.next().await); } +#[tokio::test] +async fn send_recv_into_stream_pending() { + use tokio::stream::StreamExt; + + let (tx, rx) = broadcast::channel::(8); + + tokio::pin! { + let rx = rx.into_stream(); + } + + let mut recv = task::spawn(rx.next()); + assert_pending!(recv.poll()); + + assert_ok!(tx.send(1)); + + assert!(recv.is_woken()); + let val = assert_ready!(recv.poll()); + assert_eq!(val, Some(Ok(1))); +} + #[test] fn send_recv_bounded() { let (tx, mut rx) = broadcast::channel(16); @@ -151,6 +183,23 @@ fn send_two_recv_bounded() { assert_eq!(val2, "world"); } +#[test] +fn change_tasks() { + let (tx, mut rx) = broadcast::channel(1); + + let mut recv = Box::pin(rx.recv()); + + let mut task1 = task::spawn(&mut recv); + assert_pending!(task1.poll()); + + let mut task2 = task::spawn(&mut recv); + assert_pending!(task2.poll()); + + tx.send("hello").unwrap(); + + assert!(task2.is_woken()); +} + #[test] fn send_slow_rx() { let (tx, mut rx1) = broadcast::channel(16); @@ -229,7 +278,8 @@ fn lagging_rx() { assert_ok!(tx.send("three")); // Lagged too far - assert_lagged!(rx2.try_recv(), 1); + let x = dbg!(rx2.try_recv()); + assert_lagged!(x, 1); // Calling again gets the next value assert_eq!("two", assert_recv!(rx2)); @@ -349,6 +399,131 @@ fn unconsumed_messages_are_dropped() { assert_eq!(1, Arc::strong_count(&msg)); } +#[test] +fn single_capacity_recvs() { + let (tx, mut rx) = broadcast::channel(1); + + assert_ok!(tx.send(1)); + + assert_eq!(assert_recv!(rx), 1); + assert_empty!(rx); +} + +#[test] +fn single_capacity_recvs_after_drop_1() { + let (tx, mut rx) = broadcast::channel(1); + + assert_ok!(tx.send(1)); + drop(tx); + + assert_eq!(assert_recv!(rx), 1); + assert_closed!(rx.try_recv()); +} + +#[test] +fn single_capacity_recvs_after_drop_2() { + let (tx, mut rx) = broadcast::channel(1); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + drop(tx); + + assert_lagged!(rx.try_recv(), 1); + assert_eq!(assert_recv!(rx), 2); + assert_closed!(rx.try_recv()); +} + +#[test] +fn dropping_sender_does_not_overwrite() { + let (tx, mut rx) = broadcast::channel(2); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + drop(tx); + + assert_eq!(assert_recv!(rx), 1); + assert_eq!(assert_recv!(rx), 2); + assert_closed!(rx.try_recv()); +} + +#[test] +fn lagging_receiver_recovers_after_wrap_closed_1() { + let (tx, mut rx) = broadcast::channel(2); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + assert_ok!(tx.send(3)); + drop(tx); + + assert_lagged!(rx.try_recv(), 1); + assert_eq!(assert_recv!(rx), 2); + assert_eq!(assert_recv!(rx), 3); + assert_closed!(rx.try_recv()); +} + +#[test] +fn lagging_receiver_recovers_after_wrap_closed_2() { + let (tx, mut rx) = broadcast::channel(2); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + assert_ok!(tx.send(3)); + assert_ok!(tx.send(4)); + drop(tx); + + assert_lagged!(rx.try_recv(), 2); + assert_eq!(assert_recv!(rx), 3); + assert_eq!(assert_recv!(rx), 4); + assert_closed!(rx.try_recv()); +} + +#[test] +fn lagging_receiver_recovers_after_wrap_open() { + let (tx, mut rx) = broadcast::channel(2); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + assert_ok!(tx.send(3)); + + assert_lagged!(rx.try_recv(), 1); + assert_eq!(assert_recv!(rx), 2); + assert_eq!(assert_recv!(rx), 3); + assert_empty!(rx); +} + +#[tokio::test] +async fn send_recv_stream_ready_deprecated() { + use tokio::stream::StreamExt; + + let (tx, mut rx) = broadcast::channel::(8); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + + assert_eq!(Some(Ok(1)), rx.next().await); + assert_eq!(Some(Ok(2)), rx.next().await); + + drop(tx); + + assert_eq!(None, rx.next().await); +} + +#[tokio::test] +async fn send_recv_stream_pending_deprecated() { + use tokio::stream::StreamExt; + + let (tx, mut rx) = broadcast::channel::(8); + + let mut recv = task::spawn(rx.next()); + assert_pending!(recv.poll()); + + assert_ok!(tx.send(1)); + + assert!(recv.is_woken()); + let val = assert_ready!(recv.poll()); + assert_eq!(val, Some(Ok(1))); +} + fn is_closed(err: broadcast::RecvError) -> bool { match err { broadcast::RecvError::Closed => true, diff --git a/tokio/tests/sync_cancellation_token.rs b/tokio/tests/sync_cancellation_token.rs new file mode 100644 index 00000000000..de543c94b1f --- /dev/null +++ b/tokio/tests/sync_cancellation_token.rs @@ -0,0 +1,220 @@ +#![cfg(tokio_unstable)] + +use tokio::pin; +use tokio::sync::CancellationToken; + +use core::future::Future; +use core::task::{Context, Poll}; +use futures_test::task::new_count_waker; + +#[test] +fn cancel_token() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + assert_eq!(false, token.is_cancelled()); + + let wait_fut = token.cancelled(); + pin!(wait_fut); + + assert_eq!( + Poll::Pending, + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + let wait_fut_2 = token.cancelled(); + pin!(wait_fut_2); + + token.cancel(); + assert_eq!(wake_counter, 1); + assert_eq!(true, token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_child_token_through_parent() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token = token.child_token(); + assert!(!child_token.is_cancelled()); + + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + token.cancel(); + assert_eq!(wake_counter, 2); + assert_eq!(true, token.is_cancelled()); + assert_eq!(true, child_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_child_token_without_parent() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token_1 = token.child_token(); + + let child_fut = child_token_1.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + child_token_1.cancel(); + assert_eq!(wake_counter, 1); + assert_eq!(false, token.is_cancelled()); + assert_eq!(true, child_token_1.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + let child_token_2 = token.child_token(); + let child_fut_2 = child_token_2.cancelled(); + pin!(child_fut_2); + + assert_eq!( + Poll::Pending, + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + token.cancel(); + assert_eq!(wake_counter, 3); + assert_eq!(true, token.is_cancelled()); + assert_eq!(true, child_token_2.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn create_child_token_after_parent_was_cancelled() { + for drop_child_first in [true, false].iter().cloned() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + token.cancel(); + + let child_token = token.child_token(); + assert!(child_token.is_cancelled()); + + { + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + drop(child_fut); + drop(parent_fut); + } + + if drop_child_first { + drop(child_token); + drop(token); + } else { + drop(token); + drop(child_token); + } + } +} + +#[test] +fn drop_multiple_child_tokens() { + for drop_first_child_first in &[true, false] { + let token = CancellationToken::new(); + let mut child_tokens = [None, None, None]; + for i in 0..child_tokens.len() { + child_tokens[i] = Some(token.child_token()); + } + + assert!(!token.is_cancelled()); + assert!(!child_tokens[0].as_ref().unwrap().is_cancelled()); + + for i in 0..child_tokens.len() { + if *drop_first_child_first { + child_tokens[i] = None; + } else { + child_tokens[child_tokens.len() - 1 - i] = None; + } + assert!(!token.is_cancelled()); + } + + drop(token); + } +} + +#[test] +fn drop_parent_before_child_tokens() { + let token = CancellationToken::new(); + let child1 = token.child_token(); + let child2 = token.child_token(); + + drop(token); + assert!(!child1.is_cancelled()); + + drop(child1); + drop(child2); +} diff --git a/tokio/tests/sync_mutex_owned.rs b/tokio/tests/sync_mutex_owned.rs new file mode 100644 index 00000000000..394a6708bd2 --- /dev/null +++ b/tokio/tests/sync_mutex_owned.rs @@ -0,0 +1,121 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::sync::Mutex; +use tokio::time::{interval, timeout}; +use tokio_test::task::spawn; +use tokio_test::{assert_pending, assert_ready}; + +use std::sync::Arc; +use std::time::Duration; + +#[test] +fn straight_execution() { + let l = Arc::new(Mutex::new(100)); + + { + let mut t = spawn(l.clone().lock_owned()); + let mut g = assert_ready!(t.poll()); + assert_eq!(&*g, &100); + *g = 99; + } + { + let mut t = spawn(l.clone().lock_owned()); + let mut g = assert_ready!(t.poll()); + assert_eq!(&*g, &99); + *g = 98; + } + { + let mut t = spawn(l.lock_owned()); + let g = assert_ready!(t.poll()); + assert_eq!(&*g, &98); + } +} + +#[test] +fn readiness() { + let l = Arc::new(Mutex::new(100)); + let mut t1 = spawn(l.clone().lock_owned()); + let mut t2 = spawn(l.lock_owned()); + + let g = assert_ready!(t1.poll()); + + // We can't now acquire the lease since it's already held in g + assert_pending!(t2.poll()); + + // But once g unlocks, we can acquire it + drop(g); + assert!(t2.is_woken()); + assert_ready!(t2.poll()); +} + +#[tokio::test] +/// Ensure a mutex is unlocked if a future holding the lock +/// is aborted prematurely. +async fn aborted_future_1() { + let m1: Arc> = Arc::new(Mutex::new(0)); + { + let m2 = m1.clone(); + // Try to lock mutex in a future that is aborted prematurely + timeout(Duration::from_millis(1u64), async move { + let mut iv = interval(Duration::from_millis(1000)); + m2.lock_owned().await; + iv.tick().await; + iv.tick().await; + }) + .await + .unwrap_err(); + } + // This should succeed as there is no lock left for the mutex. + timeout(Duration::from_millis(1u64), async move { + m1.lock_owned().await; + }) + .await + .expect("Mutex is locked"); +} + +#[tokio::test] +/// This test is similar to `aborted_future_1` but this time the +/// aborted future is waiting for the lock. +async fn aborted_future_2() { + let m1: Arc> = Arc::new(Mutex::new(0)); + { + // Lock mutex + let _lock = m1.clone().lock_owned().await; + { + let m2 = m1.clone(); + // Try to lock mutex in a future that is aborted prematurely + timeout(Duration::from_millis(1u64), async move { + m2.lock_owned().await; + }) + .await + .unwrap_err(); + } + } + // This should succeed as there is no lock left for the mutex. + timeout(Duration::from_millis(1u64), async move { + m1.lock_owned().await; + }) + .await + .expect("Mutex is locked"); +} + +#[test] +fn try_lock_owned() { + let m: Arc> = Arc::new(Mutex::new(0)); + { + let g1 = m.clone().try_lock_owned(); + assert_eq!(g1.is_ok(), true); + let g2 = m.clone().try_lock_owned(); + assert_eq!(g2.is_ok(), false); + } + let g3 = m.try_lock_owned(); + assert_eq!(g3.is_ok(), true); +} + +#[tokio::test] +async fn debug_format() { + let s = "debug"; + let m = Arc::new(Mutex::new(s.to_string())); + assert_eq!(format!("{:?}", s), format!("{:?}", m.lock_owned().await)); +} diff --git a/tokio/tests/sync_semaphore_owned.rs b/tokio/tests/sync_semaphore_owned.rs new file mode 100644 index 00000000000..8ed6209f3b9 --- /dev/null +++ b/tokio/tests/sync_semaphore_owned.rs @@ -0,0 +1,75 @@ +#![cfg(feature = "full")] + +use std::sync::Arc; +use tokio::sync::Semaphore; + +#[test] +fn try_acquire() { + let sem = Arc::new(Semaphore::new(1)); + { + let p1 = sem.clone().try_acquire_owned(); + assert!(p1.is_ok()); + let p2 = sem.clone().try_acquire_owned(); + assert!(p2.is_err()); + } + let p3 = sem.try_acquire_owned(); + assert!(p3.is_ok()); +} + +#[tokio::test] +async fn acquire() { + let sem = Arc::new(Semaphore::new(1)); + let p1 = sem.clone().try_acquire_owned().unwrap(); + let sem_clone = sem.clone(); + let j = tokio::spawn(async move { + let _p2 = sem_clone.acquire_owned().await; + }); + drop(p1); + j.await.unwrap(); +} + +#[tokio::test] +async fn add_permits() { + let sem = Arc::new(Semaphore::new(0)); + let sem_clone = sem.clone(); + let j = tokio::spawn(async move { + let _p2 = sem_clone.acquire_owned().await; + }); + sem.add_permits(1); + j.await.unwrap(); +} + +#[test] +fn forget() { + let sem = Arc::new(Semaphore::new(1)); + { + let p = sem.clone().try_acquire_owned().unwrap(); + assert_eq!(sem.available_permits(), 0); + p.forget(); + assert_eq!(sem.available_permits(), 0); + } + assert_eq!(sem.available_permits(), 0); + assert!(sem.try_acquire_owned().is_err()); +} + +#[tokio::test] +async fn stresstest() { + let sem = Arc::new(Semaphore::new(5)); + let mut join_handles = Vec::new(); + for _ in 0..1000 { + let sem_clone = sem.clone(); + join_handles.push(tokio::spawn(async move { + let _p = sem_clone.acquire_owned().await; + })); + } + for j in join_handles { + j.await.unwrap(); + } + // there should be exactly 5 semaphores available now + let _p1 = sem.clone().try_acquire_owned().unwrap(); + let _p2 = sem.clone().try_acquire_owned().unwrap(); + let _p3 = sem.clone().try_acquire_owned().unwrap(); + let _p4 = sem.clone().try_acquire_owned().unwrap(); + let _p5 = sem.clone().try_acquire_owned().unwrap(); + assert!(sem.try_acquire_owned().is_err()); +} diff --git a/tokio/tests/task_blocking.rs b/tokio/tests/task_blocking.rs index 4cd83d8a0d6..50c070a355a 100644 --- a/tokio/tests/task_blocking.rs +++ b/tokio/tests/task_blocking.rs @@ -1,7 +1,7 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::task; +use tokio::{runtime, task}; use tokio_test::assert_ok; use std::thread; @@ -27,3 +27,152 @@ async fn basic_blocking() { assert_eq!(out, "hello"); } } + +#[tokio::test(threaded_scheduler)] +async fn block_in_blocking() { + // Run a few times + for _ in 0..100 { + let out = assert_ok!( + tokio::spawn(async { + assert_ok!( + task::spawn_blocking(|| { + task::block_in_place(|| { + thread::sleep(Duration::from_millis(5)); + }); + "hello" + }) + .await + ) + }) + .await + ); + + assert_eq!(out, "hello"); + } +} + +#[tokio::test(threaded_scheduler)] +async fn block_in_block() { + // Run a few times + for _ in 0..100 { + let out = assert_ok!( + tokio::spawn(async { + task::block_in_place(|| { + task::block_in_place(|| { + thread::sleep(Duration::from_millis(5)); + }); + "hello" + }) + }) + .await + ); + + assert_eq!(out, "hello"); + } +} + +#[tokio::test(basic_scheduler)] +#[should_panic] +async fn no_block_in_basic_scheduler() { + task::block_in_place(|| {}); +} + +#[test] +fn yes_block_in_threaded_block_on() { + let mut rt = runtime::Builder::new() + .threaded_scheduler() + .build() + .unwrap(); + rt.block_on(async { + task::block_in_place(|| {}); + }); +} + +#[test] +#[should_panic] +fn no_block_in_basic_block_on() { + let mut rt = runtime::Builder::new().basic_scheduler().build().unwrap(); + rt.block_on(async { + task::block_in_place(|| {}); + }); +} + +#[test] +fn can_enter_basic_rt_from_within_block_in_place() { + let mut outer = tokio::runtime::Builder::new() + .threaded_scheduler() + .build() + .unwrap(); + + outer.block_on(async { + tokio::task::block_in_place(|| { + let mut inner = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + + inner.block_on(async {}) + }) + }); +} + +#[test] +fn useful_panic_message_when_dropping_rt_in_rt() { + use std::panic::{catch_unwind, AssertUnwindSafe}; + + let mut outer = tokio::runtime::Builder::new() + .threaded_scheduler() + .build() + .unwrap(); + + let result = catch_unwind(AssertUnwindSafe(|| { + outer.block_on(async { + let _ = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + }); + })); + + assert!(result.is_err()); + let err = result.unwrap_err(); + let err: &'static str = err.downcast_ref::<&'static str>().unwrap(); + + assert!( + err.find("Cannot drop a runtime").is_some(), + "Wrong panic message: {:?}", + err + ); +} + +#[test] +fn can_shutdown_with_zero_timeout_in_runtime() { + let mut outer = tokio::runtime::Builder::new() + .threaded_scheduler() + .build() + .unwrap(); + + outer.block_on(async { + let rt = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + rt.shutdown_timeout(Duration::from_nanos(0)); + }); +} + +#[test] +fn can_shutdown_now_in_runtime() { + let mut outer = tokio::runtime::Builder::new() + .threaded_scheduler() + .build() + .unwrap(); + + outer.block_on(async { + let rt = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + rt.shutdown_background(); + }); +} diff --git a/tokio/tests/task_local_set.rs b/tokio/tests/task_local_set.rs index 1a10fefa68e..bf80b8ee5f5 100644 --- a/tokio/tests/task_local_set.rs +++ b/tokio/tests/task_local_set.rs @@ -312,28 +312,17 @@ fn drop_cancels_tasks() { assert_eq!(1, Rc::strong_count(&rc1)); } -#[test] -fn drop_cancels_remote_tasks() { - // This test reproduces issue #1885. +/// Runs a test function in a separate thread, and panics if the test does not +/// complete within the specified timeout, or if the test function panics. +/// +/// This is intended for running tests whose failure mode is a hang or infinite +/// loop that cannot be detected otherwise. +fn with_timeout(timeout: Duration, f: impl FnOnce() + Send + 'static) { use std::sync::mpsc::RecvTimeoutError; let (done_tx, done_rx) = std::sync::mpsc::channel(); let thread = std::thread::spawn(move || { - let (tx, mut rx) = mpsc::channel::<()>(1024); - - let mut rt = rt(); - - let local = LocalSet::new(); - local.spawn_local(async move { while let Some(_) = rx.recv().await {} }); - local.block_on(&mut rt, async { - time::delay_for(Duration::from_millis(1)).await; - }); - - drop(tx); - - // This enters an infinite loop if the remote notified tasks are not - // properly cancelled. - drop(local); + f(); // Send a message on the channel so that the test thread can // determine if we have entered an infinite loop: @@ -349,10 +338,11 @@ fn drop_cancels_remote_tasks() { // // Note that it should definitely complete in under a minute, but just // in case CI is slow, we'll give it a long timeout. - match done_rx.recv_timeout(Duration::from_secs(60)) { + match done_rx.recv_timeout(timeout) { Err(RecvTimeoutError::Timeout) => panic!( - "test did not complete within 60 seconds, \ - we have (probably) entered an infinite loop!" + "test did not complete within {:?} seconds, \ + we have (probably) entered an infinite loop!", + timeout, ), // Did the test thread panic? We'll find out for sure when we `join` // with it. @@ -366,6 +356,49 @@ fn drop_cancels_remote_tasks() { thread.join().expect("test thread should not panic!") } +#[test] +fn drop_cancels_remote_tasks() { + // This test reproduces issue #1885. + with_timeout(Duration::from_secs(60), || { + let (tx, mut rx) = mpsc::channel::<()>(1024); + + let mut rt = rt(); + + let local = LocalSet::new(); + local.spawn_local(async move { while rx.recv().await.is_some() {} }); + local.block_on(&mut rt, async { + time::delay_for(Duration::from_millis(1)).await; + }); + + drop(tx); + + // This enters an infinite loop if the remote notified tasks are not + // properly cancelled. + drop(local); + }); +} + +#[test] +fn local_tasks_wake_join_all() { + // This test reproduces issue #2460. + with_timeout(Duration::from_secs(60), || { + use futures::future::join_all; + use tokio::task::LocalSet; + + let mut rt = rt(); + let set = LocalSet::new(); + let mut handles = Vec::new(); + + for _ in 1..=128 { + handles.push(set.spawn_local(async move { + tokio::task::spawn_local(async move {}).await.unwrap(); + })); + } + + rt.block_on(set.run_until(join_all(handles))); + }); +} + #[tokio::test] async fn local_tasks_are_polled_after_tick() { // Reproduces issues #1899 and #1900 diff --git a/tokio/tests/tcp_accept.rs b/tokio/tests/tcp_accept.rs index ff62fb96a2b..9f5b441468d 100644 --- a/tokio/tests/tcp_accept.rs +++ b/tokio/tests/tcp_accept.rs @@ -39,6 +39,7 @@ test_accept! { (ip_port_tuple, ("127.0.0.1".parse::().unwrap(), 0)), } +use pin_project_lite::pin_project; use std::pin::Pin; use std::sync::{ atomic::{AtomicUsize, Ordering::SeqCst}, @@ -47,9 +48,12 @@ use std::sync::{ use std::task::{Context, Poll}; use tokio::stream::{Stream, StreamExt}; -struct TrackPolls { - npolls: Arc, - s: S, +pin_project! { + struct TrackPolls { + npolls: Arc, + #[pin] + s: S, + } } impl Stream for TrackPolls @@ -58,11 +62,9 @@ where { type Item = S::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // safety: we do not move s - let this = unsafe { self.get_unchecked_mut() }; + let this = self.project(); this.npolls.fetch_add(1, SeqCst); - // safety: we are pinned, and so is s - unsafe { Pin::new_unchecked(&mut this.s) }.poll_next(cx) + this.s.poll_next(cx) } } @@ -80,7 +82,7 @@ async fn no_extra_poll() { s: listener.incoming(), }; assert_ok!(tx.send(Arc::clone(&incoming.npolls))); - while let Some(_) = incoming.next().await { + while incoming.next().await.is_some() { accepted_tx.send(()).unwrap(); } }); diff --git a/tokio/tests/tcp_into_split.rs b/tokio/tests/tcp_into_split.rs new file mode 100644 index 00000000000..86ed461923d --- /dev/null +++ b/tokio/tests/tcp_into_split.rs @@ -0,0 +1,131 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use std::io::{Error, ErrorKind, Result}; +use std::io::{Read, Write}; +use std::{net, thread}; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::try_join; + +#[tokio::test] +async fn split() -> Result<()> { + const MSG: &[u8] = b"split"; + + let mut listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let (stream1, (mut stream2, _)) = try_join! { + TcpStream::connect(&addr), + listener.accept(), + }?; + let (mut read_half, mut write_half) = stream1.into_split(); + + let ((), (), ()) = try_join! { + async { + let len = stream2.write(MSG).await?; + assert_eq!(len, MSG.len()); + + let mut read_buf = vec![0u8; 32]; + let read_len = stream2.read(&mut read_buf).await?; + assert_eq!(&read_buf[..read_len], MSG); + Result::Ok(()) + }, + async { + let len = write_half.write(MSG).await?; + assert_eq!(len, MSG.len()); + Ok(()) + }, + async { + let mut read_buf = vec![0u8; 32]; + let peek_len1 = read_half.peek(&mut read_buf[..]).await?; + let peek_len2 = read_half.peek(&mut read_buf[..]).await?; + assert_eq!(peek_len1, peek_len2); + + let read_len = read_half.read(&mut read_buf[..]).await?; + assert_eq!(peek_len1, read_len); + assert_eq!(&read_buf[..read_len], MSG); + Ok(()) + }, + }?; + + Ok(()) +} + +#[tokio::test] +async fn reunite() -> Result<()> { + let listener = net::TcpListener::bind("127.0.0.1:0")?; + let addr = listener.local_addr()?; + + let handle = thread::spawn(move || { + drop(listener.accept().unwrap()); + drop(listener.accept().unwrap()); + }); + + let stream1 = TcpStream::connect(&addr).await?; + let (read1, write1) = stream1.into_split(); + + let stream2 = TcpStream::connect(&addr).await?; + let (_, write2) = stream2.into_split(); + + let read1 = match read1.reunite(write2) { + Ok(_) => panic!("Reunite should not succeed"), + Err(err) => err.0, + }; + + read1.reunite(write1).expect("Reunite should succeed"); + + handle.join().unwrap(); + Ok(()) +} + +/// Test that dropping the write half actually closes the stream. +#[tokio::test] +async fn drop_write() -> Result<()> { + const MSG: &[u8] = b"split"; + + let listener = net::TcpListener::bind("127.0.0.1:0")?; + let addr = listener.local_addr()?; + + let handle = thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + stream.write_all(MSG).unwrap(); + + let mut read_buf = [0u8; 32]; + let res = match stream.read(&mut read_buf) { + Ok(0) => Ok(()), + Ok(len) => Err(Error::new( + ErrorKind::Other, + format!("Unexpected read: {} bytes.", len), + )), + Err(err) => Err(err), + }; + + drop(stream); + + res + }); + + let stream = TcpStream::connect(&addr).await?; + let (mut read_half, write_half) = stream.into_split(); + + let mut read_buf = [0u8; 32]; + let read_len = read_half.read(&mut read_buf[..]).await?; + assert_eq!(&read_buf[..read_len], MSG); + + // drop it while the read is in progress + std::thread::spawn(move || { + thread::sleep(std::time::Duration::from_millis(50)); + drop(write_half); + }); + + match read_half.read(&mut read_buf[..]).await { + Ok(0) => {} + Ok(len) => panic!("Unexpected read: {} bytes.", len), + Err(err) => panic!("Unexpected error: {}.", err), + } + + handle.join().unwrap().unwrap(); + Ok(()) +} diff --git a/tokio/tests/tcp_split.rs b/tokio/tests/tcp_split.rs index 42f797708cc..7171dac4635 100644 --- a/tokio/tests/tcp_split.rs +++ b/tokio/tests/tcp_split.rs @@ -17,7 +17,7 @@ async fn split() -> Result<()> { let handle = thread::spawn(move || { let (mut stream, _) = listener.accept().unwrap(); - stream.write(MSG).unwrap(); + stream.write_all(MSG).unwrap(); let mut read_buf = [0u8; 32]; let read_len = stream.read(&mut read_buf).unwrap(); diff --git a/tokio/tests/time_delay.rs b/tokio/tests/time_delay.rs index e763ae03bec..e4804ec6740 100644 --- a/tokio/tests/time_delay.rs +++ b/tokio/tests/time_delay.rs @@ -2,7 +2,7 @@ #![cfg(feature = "full")] use tokio::time::{self, Duration, Instant}; -use tokio_test::{assert_pending, task}; +use tokio_test::{assert_pending, assert_ready, task}; macro_rules! assert_elapsed { ($now:expr, $ms:expr) => {{ @@ -137,6 +137,26 @@ async fn reset_future_delay_after_fire() { assert_elapsed!(now, 110); } +#[tokio::test] +async fn reset_delay_to_past() { + time::pause(); + + let now = Instant::now(); + + let mut delay = task::spawn(time::delay_until(now + ms(100))); + assert_pending!(delay.poll()); + + time::delay_for(ms(50)).await; + + assert!(!delay.is_woken()); + + delay.reset(now + ms(40)); + + assert!(delay.is_woken()); + + assert_ready!(delay.poll()); +} + #[test] #[should_panic] fn creating_delay_outside_of_context() { diff --git a/tokio/tests/time_delay_queue.rs b/tokio/tests/time_delay_queue.rs index 214b9ebee68..f04576d5a18 100644 --- a/tokio/tests/time_delay_queue.rs +++ b/tokio/tests/time_delay_queue.rs @@ -1,3 +1,4 @@ +#![allow(clippy::blacklisted_name)] #![warn(rust_2018_idioms)] #![cfg(feature = "full")] @@ -443,6 +444,103 @@ async fn insert_after_ready_poll() { assert_eq!("3", res[2]); } +#[tokio::test] +async fn reset_later_after_slot_starts() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let foo = queue.insert_at("foo", now + ms(100)); + + assert_pending!(poll!(queue)); + + delay_for(ms(80)).await; + + assert!(!queue.is_woken()); + + // At this point the queue hasn't been polled, so `elapsed` on the wheel + // for the queue is still at 0 and hence the 1ms resolution slots cover + // [0-64). Resetting the time on the entry to 120 causes it to get put in + // the [64-128) slot. As the queue knows that the first entry is within + // that slot, but doesn't know when, it must wake immediately to advance + // the wheel. + queue.reset_at(&foo, now + ms(120)); + assert!(queue.is_woken()); + + assert_pending!(poll!(queue)); + + delay_for(ms(39)).await; + assert!(!queue.is_woken()); + + delay_for(ms(1)).await; + assert!(queue.is_woken()); + + let entry = assert_ready_ok!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); +} + +#[tokio::test] +async fn reset_earlier_after_slot_starts() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let foo = queue.insert_at("foo", now + ms(200)); + + assert_pending!(poll!(queue)); + + delay_for(ms(80)).await; + + assert!(!queue.is_woken()); + + // At this point the queue hasn't been polled, so `elapsed` on the wheel + // for the queue is still at 0 and hence the 1ms resolution slots cover + // [0-64). Resetting the time on the entry to 120 causes it to get put in + // the [64-128) slot. As the queue knows that the first entry is within + // that slot, but doesn't know when, it must wake immediately to advance + // the wheel. + queue.reset_at(&foo, now + ms(120)); + assert!(queue.is_woken()); + + assert_pending!(poll!(queue)); + + delay_for(ms(39)).await; + assert!(!queue.is_woken()); + + delay_for(ms(1)).await; + assert!(queue.is_woken()); + + let entry = assert_ready_ok!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); +} + +#[tokio::test] +async fn insert_in_past_after_poll_fires_immediately() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo", now + ms(200)); + + assert_pending!(poll!(queue)); + + delay_for(ms(80)).await; + + assert!(!queue.is_woken()); + queue.insert_at("bar", now + ms(40)); + + assert!(queue.is_woken()); + + let entry = assert_ready_ok!(poll!(queue)).into_inner(); + assert_eq!(entry, "bar"); +} + fn ms(n: u64) -> Duration { Duration::from_millis(n) }