diff --git a/.cirrus.yml b/.cirrus.yml index c2344c0c3cb..fa035e81cd3 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -19,18 +19,6 @@ task: - | echo "~~~~ rustc --version ~~~~" rustc --version - - # Remove any existing patch statements - mv Cargo.toml Cargo.toml.bck - sed -n '/\[patch.crates-io\]/q;p' Cargo.toml.bck > Cargo.toml - - # Patch all crates - cat ci/patch.toml >> Cargo.toml - - # Print `Cargo.toml` for debugging - echo "~~~~ Cargo.toml ~~~~" - cat Cargo.toml - echo "~~~~~~~~~~~~~~~~~~~~" test_script: - . $HOME/.cargo/env - cargo test --all @@ -39,4 +27,4 @@ task: # i686_test_script: # - . $HOME/.cargo/env # - | - # cargo test --all --exclude tokio-tls --exclude tokio-macros --target i686-unknown-freebsd + # cargo test --all --exclude tokio-macros --target i686-unknown-freebsd diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md deleted file mode 100644 index 269780639f3..00000000000 --- a/.github/ISSUE_TEMPLATE.md +++ /dev/null @@ -1,51 +0,0 @@ - - -## Version - - - -## Platform - - - -## Subcrates - - - -## Description - - diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000000..98cf6116a72 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,36 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: A-tokio, C-bug +assignees: '' + +--- + +**Version** +List the versions of all `tokio` crates you are using. The easiest way to get +this information is using `cargo-tree`. + +`cargo install cargo-tree` +(see install here: https://github.com/sfackler/cargo-tree) + +Then: + +`cargo tree | grep tokio` + +**Platform** +The output of `uname -a` (UNIX), or version and 32 or 64-bit (Windows) + +**Description** +Enter your issue details here. +One way to structure the description: + +[short summary of the bug] + +I tried this code: + +[code sample that causes the bug] + +I expected to see this happen: [explanation] + +Instead, this happened: [explanation] diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000000..e90a4933174 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: A-tokio, C-feature-request +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md new file mode 100644 index 00000000000..03112f55412 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.md @@ -0,0 +1,16 @@ +--- +name: Question +about: Please use the discussions tab for questions +title: '' +labels: '' +assignees: '' + +--- + +Please post your question as a discussion here: +https://github.com/tokio-rs/tokio/discussions + + +You may also be able to find help here: +https://discord.gg/tokio +https://users.rust-lang.org/ diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index bbff0233b35..6b3db9ae609 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -5,6 +5,9 @@ the requirements below. Bug fixes and new features should include tests. Contributors guide: https://github.com/tokio-rs/tokio/blob/master/CONTRIBUTING.md + +The contributors guide includes instructions for running rustfmt and building the +documentation, which requires special commands beyond `cargo fmt` and `cargo doc`. --> ## Motivation diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml new file mode 100644 index 00000000000..a901a0fd014 --- /dev/null +++ b/.github/workflows/audit.yml @@ -0,0 +1,22 @@ +name: Security Audit + +on: + push: + branches: + - master + paths: + - '**/Cargo.toml' + schedule: + - cron: '0 2 * * *' # run at 2 AM UTC + +jobs: + security-audit: + runs-on: ubuntu-latest + if: "!contains(github.event.head_commit.message, 'ci skip')" + steps: + - uses: actions/checkout@v2 + + - name: Audit Check + uses: actions-rs/audit-check@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000000..054cf1166f0 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,240 @@ +on: + push: + branches: ["master"] + pull_request: + branches: ["master"] + +name: CI + +env: + RUSTFLAGS: -Dwarnings + RUST_BACKTRACE: 1 + nightly: nightly-2020-07-12 + minrust: 1.39.0 + +jobs: + # Depends on all action sthat are required for a "successful" CI run. + tests-pass: + name: all systems go + runs-on: ubuntu-latest + needs: + - test + - test-unstable + - miri + - cross + - features + - minrust + - fmt + - clippy + - docs + - loom + steps: + - run: exit 0 + + test: + name: test tokio full + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - windows-latest + - ubuntu-latest + - macos-latest + steps: + - uses: actions/checkout@v2 + - name: Install Rust + run: rustup update stable + - name: Install cargo-hack + run: cargo install cargo-hack + + # Run `tokio` with `full` features. This excludes testing utilities which + # can alter the runtime behavior of Tokio. + - name: test tokio full + run: cargo test --features full + working-directory: tokio + + # Check `tokio` with `full + parking_lot` to make sure it compiles. + - name: check tokio full,parking_lot + run: cargo check --features full,parking_lot + working-directory: tokio + + # Test **all** crates in the workspace with all features. + - name: test all --all-features + run: cargo test --workspace --all-features + + # Run integration tests for each feature + - name: test tests-integration --each-feature + run: cargo hack test --each-feature + working-directory: tests-integration + + # Run macro build tests + - name: test tests-build --each-feature + run: cargo hack test --each-feature + working-directory: tests-build + + test-unstable: + name: test tokio full --unstable + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - windows-latest + - ubuntu-latest + - macos-latest + steps: + - uses: actions/checkout@v2 + - name: Install Rust + run: rustup update stable + + # Run `tokio` with "unstable" cfg flag. + - name: test tokio full --cfg unstable + run: cargo test --features full + working-directory: tokio + env: + RUSTFLAGS: '--cfg tokio_unstable' + + miri: + name: miri + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + toolchain: ${{ env.nightly }} + override: true + - name: Install Miri + run: | + set -e + rustup component add miri + cargo miri setup + rm -rf tokio/tests + + - name: miri + run: cargo miri test --features rt-core,rt-threaded,rt-util,sync -- -- task + working-directory: tokio + + cross: + name: cross + runs-on: ubuntu-latest + strategy: + matrix: + target: + - i686-unknown-linux-gnu + - powerpc-unknown-linux-gnu + - powerpc64-unknown-linux-gnu + - mips-unknown-linux-gnu + - arm-linux-androideabi + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + toolchain: stable + target: ${{ matrix.target }} + override: true + - uses: actions-rs/cargo@v1 + with: + use-cross: true + command: check + args: --workspace --target ${{ matrix.target }} + + features: + name: features + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + toolchain: ${{ env.nightly }} + override: true + - name: Install cargo-hack + run: cargo install cargo-hack + + - name: check --each-feature + run: cargo hack check --all --each-feature -Z avoid-dev-deps + + # Try with unstable feature flags + - name: check --each-feature --unstable + run: cargo hack check --all --each-feature -Z avoid-dev-deps + env: + RUSTFLAGS: --cfg tokio_unstable + + minrust: + name: minrust + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + toolchain: ${{ env.minrust }} + override: true + + - name: "test --workspace --all-features" + run: cargo check --workspace --all-features + + fmt: + name: fmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Install Rust + run: rustup update stable + - name: Install rustfmt + run: rustup component add rustfmt + + # Check fmt + - name: "rustfmt --check" + # Workaround for rust-lang/cargo#7732 + run: rustfmt --check --edition 2018 $(find . -name '*.rs' -print) + + clippy: + name: clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Install Rust + run: rustup update stable + - name: Install clippy + run: rustup component add clippy + + # Run clippy + - name: "clippy --all" + run: cargo clippy --all --tests + + docs: + name: docs + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + toolchain: ${{ env.nightly }} + override: true + + - name: "doc --lib --all-features" + run: cargo doc --lib --no-deps --all-features + env: + RUSTDOCFLAGS: --cfg docsrs + + loom: + name: loom + runs-on: ubuntu-latest + strategy: + matrix: + scope: + - --skip loom_pool + - loom_pool::group_a + - loom_pool::group_b + - loom_pool::group_c + - loom_pool::group_d + steps: + - uses: actions/checkout@v2 + - name: Install Rust + run: rustup update stable + + - name: loom ${{ matrix.scope }} + run: cargo test --lib --release --features full -- --nocapture $SCOPE + working-directory: tokio + env: + RUSTFLAGS: --cfg loom --cfg tokio_unstable + LOOM_MAX_PREEMPTIONS: 2 + SCOPE: ${{ matrix.scope }} diff --git a/.github/workflows/pr-audit.yml b/.github/workflows/pr-audit.yml new file mode 100644 index 00000000000..26c0ee2f119 --- /dev/null +++ b/.github/workflows/pr-audit.yml @@ -0,0 +1,32 @@ +name: Pull Request Security Audit + +on: + push: + paths: + - '**/Cargo.toml' + pull_request: + paths: + - '**/Cargo.toml' + +jobs: + security-audit: + runs-on: ubuntu-latest + if: "!contains(github.event.head_commit.message, 'ci skip')" + steps: + - uses: actions/checkout@v2 + + - name: Install cargo-audit + uses: actions-rs/cargo@v1 + with: + command: install + args: cargo-audit + + - name: Generate lockfile + uses: actions-rs/cargo@v1 + with: + command: generate-lockfile + + - name: Audit dependencies + uses: actions-rs/cargo@v1 + with: + command: audit diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0aa58273081..70bc4559b22 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,12 +15,14 @@ It should be considered a map to help you navigate the process. The [dev channel][dev] is available for any concerns not covered in this guide, please join us! -[dev]: https://discord.gg/6yGkFeN +[dev]: https://discord.gg/tokio ## Conduct The Tokio project adheres to the [Rust Code of Conduct][coc]. This describes -the _minimum_ behavior expected from all contributors. Instances of violations of the Code of Conduct can be reported by contacting the project team at [moderation@tokio.rs](mailto:moderation@tokio.rs). +the _minimum_ behavior expected from all contributors. Instances of violations of the +Code of Conduct can be reported by contacting the project team at +[moderation@tokio.rs](mailto:moderation@tokio.rs). [coc]: https://github.com/rust-lang/rust/blob/master/CODE_OF_CONDUCT.md @@ -29,8 +31,8 @@ the _minimum_ behavior expected from all contributors. Instances of violations o For any issue, there are fundamentally three ways an individual can contribute: 1. By opening the issue for discussion: For instance, if you believe that you - have uncovered a bug in Tokio, creating a new issue in the tokio-rs/tokio - issue tracker is the way to report it. + have discovered a bug in Tokio, creating a new issue in [the tokio-rs/tokio + issue tracker][issue] is the way to report it. 2. By helping to triage the issue: This can be done by providing supporting details (a test case that demonstrates a bug), providing @@ -42,21 +44,25 @@ For any issue, there are fundamentally three ways an individual can contribute: often, by opening a Pull Request that changes some bit of something in Tokio in a concrete and reviewable manner. +[issue]: https://github.com/tokio-rs/tokio/issues + **Anybody can participate in any stage of contribution**. We urge you to participate in the discussion around bugs and participate in reviewing PRs. ### Asking for General Help If you have reviewed existing documentation and still have questions or are -having problems, you can open an issue asking for help. +having problems, you can [open a discussion] asking for help. In exchange for receiving help, we ask that you contribute back a documentation PR that helps others avoid the problems that you encountered. +[open a discussion]: https://github.com/tokio-rs/tokio/discussions/new + ### Submitting a Bug Report -When opening a new issue in the Tokio issue tracker, users will be presented -with a [basic template][template] that should be filled in. If you believe that you have +When opening a new issue in the Tokio issue tracker, you will be presented +with a basic template that should be filled in. If you believe that you have uncovered a bug, please fill out this form, following the template to the best of your ability. Do not worry if you cannot answer every detail, just fill in what you can. @@ -72,7 +78,6 @@ cases should be limited, as much as possible, to using only Tokio APIs. See [How to create a Minimal, Complete, and Verifiable example][mcve]. [mcve]: https://stackoverflow.com/help/mcve -[template]: .github/PULL_REQUEST_TEMPLATE.md ### Triaging a Bug Report @@ -132,8 +137,13 @@ RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features ``` The `cargo fmt` command does not work on the Tokio codebase. You can use the command below instead: + ``` +# Mac or Linux rustfmt --check --edition 2018 $(find . -name '*.rs' -print) + +# Powershell +Get-ChildItem . -Filter "*.rs" -Recurse | foreach { rustfmt --check --edition 2018 $_.FullName } ``` The `--check` argument prints the things that need to be fixed. If you remove it, `rustfmt` will update your files locally instead. @@ -250,7 +260,7 @@ That said, if you have a number of commits that are "checkpoints" and don't represent a single logical change, please squash those together. Note that multiple commits often get squashed when they are landed (see the -notes about [commit squashing]). +notes about [commit squashing](#commit-squashing)). #### Commit message guidelines @@ -321,7 +331,7 @@ in order to evaluate whether the changes are correct and necessary. Keep an eye out for comments from code owners to provide guidance on conflicting feedback. -**Once the PR is open, do not rebase the commits**. See [Commit Squashing] for +**Once the PR is open, do not rebase the commits**. See [Commit Squashing](#commit-squashing) for more details. ### Commit Squashing diff --git a/Cargo.toml b/Cargo.toml index ebae0d3db6b..39d2936645b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ members = [ "tokio", "tokio-macros", "tokio-test", - "tokio-tls", "tokio-util", # Internal diff --git a/README.md b/README.md index 5a661998f5a..da9078c5824 100644 --- a/README.md +++ b/README.md @@ -20,14 +20,14 @@ the Rust programming language. It is: [crates-badge]: https://img.shields.io/crates/v/tokio.svg [crates-url]: https://crates.io/crates/tokio [mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg -[mit-url]: LICENSE +[mit-url]: https://github.com/tokio-rs/tokio/blob/master/LICENSE [azure-badge]: https://dev.azure.com/tokio-rs/Tokio/_apis/build/status/tokio-rs.tokio?branchName=master [azure-url]: https://dev.azure.com/tokio-rs/Tokio/_build/latest?definitionId=1&branchName=master [discord-badge]: https://img.shields.io/discord/500028886025895936.svg?logo=discord&style=flat-square [discord-url]: https://discord.gg/tokio [Website](https://tokio.rs) | -[Guides](https://tokio.rs/docs/overview/) | +[Guides](https://tokio.rs/tokio/tutorial) | [API Docs](https://docs.rs/tokio/latest/tokio) | [Roadmap](https://github.com/tokio-rs/tokio/blob/master/ROADMAP.md) | [Chat](https://discord.gg/tokio) @@ -90,19 +90,27 @@ async fn main() -> Result<(), Box> { } ``` -More examples can be found [here](examples). +More examples can be found [here][examples]. For a larger "real world" example, see the +[mini-redis] repository. + +[examples]: https://github.com/tokio-rs/tokio/tree/master/examples +[mini-redis]: https://github.com/tokio-rs/mini-redis/ + +To see a list of the available features flags that can be enabled, check our +[docs][feature-flag-docs]. ## Getting Help First, see if the answer to your question can be found in the [Guides] or the [API documentation]. If the answer is not there, there is an active community in the [Tokio Discord server][chat]. We would be happy to try to answer your -question. Last, if that doesn't work, try opening an [issue] with the question. +question. You can also ask your question on [the discussions page][discussions]. -[Guides]: https://tokio.rs/docs/overview/ +[Guides]: https://tokio.rs/tokio/tutorial [API documentation]: https://docs.rs/tokio/latest/tokio [chat]: https://discord.gg/tokio -[issue]: https://github.com/tokio-rs/tokio/issues/new +[discussions]: https://github.com/tokio-rs/tokio/discussions +[feature-flag-docs]: https://docs.rs/tokio/#feature-flags ## Contributing @@ -149,15 +157,15 @@ several other libraries, including: ## Supported Rust Versions -Tokio is built against the latest stable, nightly, and beta Rust releases. The -minimum version supported is the stable release from three months before the -current stable release version. For example, if the latest stable Rust is 1.29, -the minimum version supported is 1.26. The current Tokio version is not -guaranteed to build on Rust versions earlier than the minimum supported version. +Tokio is built against the latest stable release. The minimum supported version is 1.39. +The current Tokio version is not guaranteed to build on Rust versions earlier than the +minimum supported version. ## License -This project is licensed under the [MIT license](LICENSE). +This project is licensed under the [MIT license]. + +[MIT license]: https://github.com/tokio-rs/tokio/blob/master/LICENSE ### Contribution diff --git a/azure-pipelines.yml b/azure-pipelines.yml deleted file mode 100644 index cc50f3c88ca..00000000000 --- a/azure-pipelines.yml +++ /dev/null @@ -1,121 +0,0 @@ -trigger: ["master"] -pr: ["master"] - -variables: - RUSTFLAGS: -Dwarnings - nightly: nightly-2020-01-25 - -jobs: -# Test top level crate -- template: ci/azure-test-stable.yml - parameters: - name: test_tokio - rust: stable - displayName: Test tokio - cross: true - crates: - - tokio - - tests-integration - -# Test sub crates -- template: ci/azure-test-stable.yml - parameters: - name: test_linux - displayName: Test sub crates - - rust: stable - crates: - - tokio-macros - - tokio-test - - tokio-tls - - tokio-util - - examples - -# Run integration tests -- template: ci/azure-test-integration.yml - parameters: - name: test_integration - displayName: Integration tests - rust: stable - -# Run tests from `tests-build`. This requires a different process -- template: ci/azure-test-build.yml - parameters: - name: test_build - displayName: Test build permutations - rust: stable - -# Run miri tests -- template: ci/azure-miri.yml - parameters: - name: miri - -# Try cross compiling -- template: ci/azure-cross-compile.yml - parameters: - name: cross - rust: stable - -# Check each feature works properly -- template: ci/azure-check-features.yml - parameters: - rust: $(nightly) - name: check_features - -# This represents the minimum Rust version supported by -# Tokio. Updating this should be done in a dedicated PR and -# cannot be greater than two 0.x releases prior to the -# current stable. -# -# Tests are not run as tests may require newer versions of -# rust. -- template: ci/azure-check-minrust.yml - parameters: - name: minrust - rust: 1.39.0 - -# Check formatting -- template: ci/azure-rustfmt.yml - parameters: - rust: stable - name: rustfmt - -# Apply clippy lints to all crates -- template: ci/azure-clippy.yml - parameters: - rust: stable - name: clippy - -# Check doc generation -- template: ci/azure-check-docs.yml - parameters: - rust: $(nightly) - name: docs - -# - template: ci/azure-tsan.yml -# parameters: -# name: tsan -# rust: stable - -# Run loom tests -- template: ci/azure-loom.yml - parameters: - name: loom - rust: stable - -- template: ci/azure-deploy-docs.yml - parameters: - rust: stable - dependsOn: - - rustfmt - - docs - - clippy - - test_tokio - - test_linux - - test_integration - - test_build - - loom - - miri - - cross - - minrust - - check_features -# - tsan diff --git a/ci/azure-cargo-check.yml b/ci/azure-cargo-check.yml deleted file mode 100644 index 5bf7af20150..00000000000 --- a/ci/azure-cargo-check.yml +++ /dev/null @@ -1,29 +0,0 @@ -parameters: - noDefaultFeatures: '--no-default-features' - -jobs: -- job: ${{ parameters.name }} - displayName: ${{ parameters.displayName }} - pool: - vmImage: ubuntu-16.04 - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - - template: azure-is-release.yml - - - ${{ each crate in parameters.crates }}: - - ${{ each feature in crate.value }}: - - script: cargo check ${{ parameters.noDefaultFeatures }} --features ${{ feature }} - displayName: Check `${{ crate.key }}`, features = ${{ feature }} - workingDirectory: $(Build.SourcesDirectory)/${{ crate.key }} - condition: and(succeeded(), not(variables['isRelease'])) - - - template: azure-patch-crates.yml - - - ${{ each crate in parameters.crates }}: - - ${{ each feature in crate.value }}: - - script: cargo check ${{ parameters.noDefaultFeatures }} --features ${{ feature }} - displayName: Check `${{ crate.key }}`, features = ${{ feature }} - workingDirectory: $(Build.SourcesDirectory)/${{ crate.key }} diff --git a/ci/azure-check-docs.yml b/ci/azure-check-docs.yml deleted file mode 100644 index 6f94f339ee4..00000000000 --- a/ci/azure-check-docs.yml +++ /dev/null @@ -1,15 +0,0 @@ -jobs: -# Check docs -- job: ${{ parameters.name }} - displayName: Check docs - pool: - vmImage: ubuntu-16.04 - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - - script: | - RUSTDOCFLAGS="--cfg docsrs" cargo doc --lib --no-deps --all-features - displayName: Check docs - diff --git a/ci/azure-check-features.yml b/ci/azure-check-features.yml deleted file mode 100644 index f5985843e10..00000000000 --- a/ci/azure-check-features.yml +++ /dev/null @@ -1,32 +0,0 @@ -jobs: -- job: ${{ parameters.name }} - displayName: Check features - strategy: - matrix: - Linux: - vmImage: ubuntu-16.04 - MacOS: - vmImage: macos-latest - Windows: - vmImage: vs2017-win2016 - pool: - vmImage: $(vmImage) - - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - - template: azure-patch-crates.yml - - - script: cargo install cargo-hack - displayName: Install cargo-hack - - # Check each feature works properly - # * --each-feature - # run for each feature which includes --no-default-features and default features of package - # * -Z avoid-dev-deps - # build without dev-dependencies to avoid https://github.com/rust-lang/cargo/issues/4866 - # tracking-issue: https://github.com/rust-lang/cargo/issues/5133 - - script: cargo hack check --all --each-feature -Z avoid-dev-deps - displayName: cargo hack check --all --each-feature diff --git a/ci/azure-check-minrust.yml b/ci/azure-check-minrust.yml deleted file mode 100644 index 1a28f53bda1..00000000000 --- a/ci/azure-check-minrust.yml +++ /dev/null @@ -1,14 +0,0 @@ -jobs: -- job: ${{ parameters.name }} - displayName: Min supported Rust version - pool: - vmImage: ubuntu-16.04 - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - - template: azure-patch-crates.yml - - - script: cargo check --all - displayName: cargo check --all diff --git a/ci/azure-clippy.yml b/ci/azure-clippy.yml deleted file mode 100644 index 58ab318f718..00000000000 --- a/ci/azure-clippy.yml +++ /dev/null @@ -1,16 +0,0 @@ -jobs: -- job: ${{ parameters.name }} - displayName: Clippy - pool: - vmImage: ubuntu-16.04 - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - script: | - rustup component add clippy - cargo clippy --version - displayName: Install clippy - - script: | - cargo clippy --all --all-features - displayName: cargo clippy --all diff --git a/ci/azure-cross-compile.yml b/ci/azure-cross-compile.yml deleted file mode 100644 index 74acaee282f..00000000000 --- a/ci/azure-cross-compile.yml +++ /dev/null @@ -1,44 +0,0 @@ -jobs: -- job: ${{ parameters.name }} - displayName: ${{ parameters.displayName }} - strategy: - matrix: - i686: - vmImage: ubuntu-16.04 - target: i686-unknown-linux-gnu - powerpc: - vmImage: ubuntu-16.04 - target: powerpc-unknown-linux-gnu - powerpc64: - vmImage: ubuntu-16.04 - target: powerpc64-unknown-linux-gnu - mips: - vmImage: ubuntu-16.04 - target: mips-unknown-linux-gnu - arm: - vmImage: ubuntu-16.04 - target: arm-linux-androideabi - pool: - vmImage: $(vmImage) - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - - script: sudo apt-get update - displayName: apt-get update - - - script: sudo apt-get install gcc-multilib - displayName: Install gcc-multilib - - - script: cargo install cross - displayName: Install cross - - # Always patch - - template: azure-patch-crates.yml - - - script: cross check --all --exclude tokio-tls --target $(target) - displayName: Check source - - # - script: cross check --tests --all --exclude tokio-tls --target $(target) - # displayName: Check tests diff --git a/ci/azure-deploy-docs.yml b/ci/azure-deploy-docs.yml deleted file mode 100644 index 77ec1b0f64a..00000000000 --- a/ci/azure-deploy-docs.yml +++ /dev/null @@ -1,39 +0,0 @@ -parameters: - dependsOn: [] - -jobs: -- job: documentation - displayName: 'Deploy API Documentation' - condition: and(succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/master')) - pool: - vmImage: 'Ubuntu 16.04' - dependsOn: - - ${{ parameters.dependsOn }} - steps: - - template: azure-install-rust.yml - parameters: - # rust_version: stable - rust_version: ${{ parameters.rust }} - - script: | - cargo doc --all --no-deps --all-features - cp -R target/doc '$(Build.BinariesDirectory)' - displayName: 'Generate Documentation' - - script: | - set -e - - git --version - ls -la - git init - git config user.name 'Deployment Bot (from Azure Pipelines)' - git config user.email 'deploy@tokio-rs.com' - git config --global credential.helper 'store --file ~/.my-credentials' - printf "protocol=https\nhost=github.com\nusername=carllerche\npassword=%s\n\n" "$GITHUB_TOKEN" | git credential-store --file ~/.my-credentials store - git remote add origin https://github.com/tokio-rs/tokio - git checkout -b gh-pages - git add . - git commit -m 'Deploy Tokio API documentation' - git push -f origin gh-pages - env: - GITHUB_TOKEN: $(githubPersonalToken) - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Deploy Documentation' diff --git a/ci/azure-install-rust.yml b/ci/azure-install-rust.yml deleted file mode 100644 index 4cf5ca3f905..00000000000 --- a/ci/azure-install-rust.yml +++ /dev/null @@ -1,40 +0,0 @@ -steps: - # Linux and macOS. - - script: | - set -e - - if [ "$RUSTUP_TOOLCHAIN" == "nightly" ]; then - echo "++ getting latest miri version" - export RUSTUP_TOOLCHAIN="nightly-$(curl -s https://rust-lang.github.io/rustup-components-history/x86_64-unknown-linux-gnu/miri)" - echo "$RUSTUP_TOOLCHAIN" - fi - - curl https://sh.rustup.rs -sSf | sh -s -- -y --profile minimal --default-toolchain none - export PATH=$PATH:$HOME/.cargo/bin - rustup toolchain install $RUSTUP_TOOLCHAIN - rustup default $RUSTUP_TOOLCHAIN - echo "##vso[task.setvariable variable=PATH;]$PATH:$HOME/.cargo/bin" - env: - RUSTUP_TOOLCHAIN: ${{parameters.rust_version}} - displayName: "Install rust (*nix)" - condition: not(eq(variables['Agent.OS'], 'Windows_NT')) - - # Windows. - - script: | - curl -sSf -o rustup-init.exe https://win.rustup.rs - rustup-init.exe -y --profile minimal --default-toolchain none - set PATH=%PATH%;%USERPROFILE%\.cargo\bin - rustup toolchain install %RUSTUP_TOOLCHAIN% - rustup default %RUSTUP_TOOLCHAIN% - echo "##vso[task.setvariable variable=PATH;]%PATH%;%USERPROFILE%\.cargo\bin" - env: - RUSTUP_TOOLCHAIN: ${{parameters.rust_version}} - displayName: "Install rust (windows)" - condition: eq(variables['Agent.OS'], 'Windows_NT') - - # All platforms. - - script: | - rustup toolchain list - rustc -Vv - cargo -V - displayName: Query rust and cargo versions diff --git a/ci/azure-is-release.yml b/ci/azure-is-release.yml deleted file mode 100644 index d7271b345fe..00000000000 --- a/ci/azure-is-release.yml +++ /dev/null @@ -1,9 +0,0 @@ -steps: - - bash: | - set -e - - if git log --no-merges -1 --format='%B' | grep -qF '[ci-release]'; then - echo "##vso[task.setvariable variable=isRelease]true" - fi - failOnStderr: true - displayName: Check if release commit diff --git a/ci/azure-loom.yml b/ci/azure-loom.yml deleted file mode 100644 index 001aedec263..00000000000 --- a/ci/azure-loom.yml +++ /dev/null @@ -1,29 +0,0 @@ -jobs: -- job: ${{ parameters.name }} - displayName: Loom tests - strategy: - matrix: - rest: - scope: --skip loom_pool - pool_group_a: - scope: loom_pool::group_a - pool_group_b: - scope: loom_pool::group_b - pool_group_c: - scope: loom_pool::group_c - pool_group_d: - scope: loom_pool::group_d - pool: - vmImage: ubuntu-16.04 - - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - - script: RUSTFLAGS="--cfg loom" cargo test --lib --release --features "full" -- --nocapture $(scope) - env: - LOOM_MAX_PREEMPTIONS: 2 - CI: 'True' - displayName: $(scope) - workingDirectory: $(Build.SourcesDirectory)/tokio diff --git a/ci/azure-miri.yml b/ci/azure-miri.yml deleted file mode 100644 index fb886edc7d4..00000000000 --- a/ci/azure-miri.yml +++ /dev/null @@ -1,23 +0,0 @@ -jobs: -- job: ${{ parameters.name }} - displayName: Miri - pool: - vmImage: ubuntu-16.04 - - steps: - - template: azure-install-rust.yml - parameters: - rust_version: nightly - - - script: | - rustup component add miri - cargo miri setup - rm -rf $(Build.SourcesDirectory)/tokio/tests - displayName: Install miri - - # TODO: enable all tests once they pass - - script: cargo miri test --features rt-core,rt-threaded,rt-util,sync -- -- task - env: - CI: 'True' - displayName: cargo miri test - workingDirectory: $(Build.SourcesDirectory)/tokio diff --git a/ci/azure-patch-crates.yml b/ci/azure-patch-crates.yml deleted file mode 100644 index 7bf96e60220..00000000000 --- a/ci/azure-patch-crates.yml +++ /dev/null @@ -1,16 +0,0 @@ -steps: - - script: | - set -e - - # Remove any existing patch statements - mv Cargo.toml Cargo.toml.bck - sed -n '/\[patch.crates-io\]/q;p' Cargo.toml.bck > Cargo.toml - - # Patch all crates - cat ci/patch.toml >> Cargo.toml - - # Print `Cargo.toml` for debugging - echo "~~~~ Cargo.toml ~~~~" - cat Cargo.toml - echo "~~~~~~~~~~~~~~~~~~~~" - displayName: Patch Cargo.toml diff --git a/ci/azure-rustfmt.yml b/ci/azure-rustfmt.yml deleted file mode 100644 index f8e79d7209a..00000000000 --- a/ci/azure-rustfmt.yml +++ /dev/null @@ -1,18 +0,0 @@ -jobs: -# Check formatting -- job: ${{ parameters.name }} - displayName: Check rustfmt - pool: - vmImage: ubuntu-16.04 - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - script: | - rustup component add rustfmt - cargo fmt --version - displayName: Install rustfmt - - script: | - # Workaround for rust-lang/cargo#7732 - rustfmt --check --edition 2018 $(find . -name '*.rs' -print) - displayName: Check formatting diff --git a/ci/azure-test-build.yml b/ci/azure-test-build.yml deleted file mode 100644 index 944b9e24f62..00000000000 --- a/ci/azure-test-build.yml +++ /dev/null @@ -1,17 +0,0 @@ -jobs: -- job: ${{ parameters.name }} - displayName: ${{ parameters.displayName }} - pool: - vmImage: 'Ubuntu 16.04' - - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - - script: cargo install cargo-hack - displayName: Install cargo-hack - - - script: cargo hack test --each-feature - displayName: cargo hack test --each-feature - workingDirectory: $(Build.SourcesDirectory)/tests-build diff --git a/ci/azure-test-integration.yml b/ci/azure-test-integration.yml deleted file mode 100644 index f498a649679..00000000000 --- a/ci/azure-test-integration.yml +++ /dev/null @@ -1,28 +0,0 @@ -jobs: -- job: ${{ parameters.name }} - displayName: ${{ parameters.displayName }} - strategy: - matrix: - Linux: - vmImage: ubuntu-16.04 - MacOS: - vmImage: macos-latest - Windows: - vmImage: vs2017-win2016 - pool: - vmImage: $(vmImage) - - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - - script: cargo install cargo-hack - displayName: Install cargo-hack - - # Run with all crate features - - script: cargo hack test --each-feature - env: - CI: 'True' - displayName: cargo hack test --each-feature - workingDirectory: $(Build.SourcesDirectory)/tests-integration diff --git a/ci/azure-test-nightly.yml b/ci/azure-test-nightly.yml deleted file mode 100644 index bbb44442619..00000000000 --- a/ci/azure-test-nightly.yml +++ /dev/null @@ -1,19 +0,0 @@ -jobs: -- job: ${{ parameters.name }} - displayName: ${{ parameters.displayName }} - pool: - vmImage: ubuntu-16.04 - - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - - template: azure-patch-crates.yml - - - script: cargo check --all - displayName: cargo check --all - - # Check benches - - script: cargo check --benches --all - displayName: Check benchmarks diff --git a/ci/azure-test-stable.yml b/ci/azure-test-stable.yml deleted file mode 100644 index ce22c942f38..00000000000 --- a/ci/azure-test-stable.yml +++ /dev/null @@ -1,47 +0,0 @@ -jobs: -- job: ${{ parameters.name }} - displayName: ${{ parameters.displayName }} - strategy: - matrix: - Linux: - vmImage: ubuntu-16.04 - - ${{ if parameters.cross }}: - MacOS: - vmImage: macos-latest - Windows: - vmImage: vs2017-win2016 - pool: - vmImage: $(vmImage) - - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - - template: azure-is-release.yml - - - ${{ each crate in parameters.crates }}: - # Run with all crate features - - script: cargo test --all-features - env: - RUST_BACKTRACE: 1 - CI: 'True' - displayName: ${{ crate }} - cargo test --all-features - workingDirectory: $(Build.SourcesDirectory)/${{ crate }} - - # Check benches - - script: cargo check --all-features --benches - displayName: ${{ crate }} - cargo check --benches - workingDirectory: $(Build.SourcesDirectory)/${{ crate }} - - - template: azure-patch-crates.yml - - - ${{ each crate in parameters.crates }}: - # Run with all crate features - - script: cargo test --all-features - env: - RUST_BACKTRACE: 1 - CI: 'True' - displayName: ${{ crate }} - cargo test --all-features - workingDirectory: $(Build.SourcesDirectory)/${{ crate }} diff --git a/ci/azure-tsan.yml b/ci/azure-tsan.yml deleted file mode 100644 index 0104697e19b..00000000000 --- a/ci/azure-tsan.yml +++ /dev/null @@ -1,34 +0,0 @@ -jobs: -- job: ${{ parameters.name }} - displayName: TSAN - strategy: - matrix: - Timer: - cmd: cargo test -p tokio-timer --test hammer - pool: - vmImage: ubuntu-16.04 - steps: - - template: azure-install-rust.yml - parameters: - rust_version: ${{ parameters.rust }} - - - template: azure-patch-crates.yml - - script: | - set -e - - # Make sure the benchmarks compile - export ASAN_OPTIONS="detect_odr_violation=0 detect_leaks=0" - export TSAN_OPTIONS="suppressions=`pwd`/ci/tsan" - export RUST_BACKTRACE=1 - - # Run address sanitizer - RUSTFLAGS="-Z sanitizer=address" \ - $(cmd) --target x86_64-unknown-linux-gnu - - # Run thread sanitizer - RUSTFLAGS="-Z sanitizer=thread" \ - $(cmd) --target x86_64-unknown-linux-gnu - displayName: TSAN / MSAN - env: - TSAN: yes - diff --git a/ci/patch.toml b/ci/patch.toml deleted file mode 100644 index 22311cf9a76..00000000000 --- a/ci/patch.toml +++ /dev/null @@ -1,8 +0,0 @@ -# Patch dependencies to run all tests against versions of the crate in the -# repository. -[patch.crates-io] -tokio = { path = "tokio" } -tokio-macros = { path = "tokio-macros" } -tokio-test = { path = "tokio-test" } -tokio-tls = { path = "tokio-tls" } -tokio-util = { path = "tokio-util" } diff --git a/ci/tsan b/ci/tsan deleted file mode 100644 index 3791c27038a..00000000000 --- a/ci/tsan +++ /dev/null @@ -1,39 +0,0 @@ -# TSAN suppressions file for Tokio - -# TSAN does not understand fences and `Arc::drop` is implemented using a fence. -# This causes many false positives. -race:Arc*drop -race:Weak*drop - -# `std` mpsc is not used in any Tokio code base. This race is triggered by some -# rust runtime logic. -race:std*mpsc_queue -race:std*lang_start -race:drop*std::thread* - -# Probably more fences in std. -race:__call_tls_dtors - -# The epoch-based GC uses fences. -race:crossbeam_epoch - -# Push and steal operations in crossbeam-deque may cause data races, but such -# data races are safe. If a data race happens, the value read by `steal` is -# forgotten and the steal operation is then retried. -race:crossbeam_deque*push -race:crossbeam_deque*steal - -# This filters out expected data race in the Treiber stack implementations. -# Treiber stacks are inherently racy. The pop operation will attempt to access -# the "next" pointer on the node it is attempting to pop. However, at this -# point it has not gained ownership of the node and another thread might beat -# it and take ownership of the node first (touching the next pointer). The -# original pop operation will fail due to the ABA guard, but tsan still picks -# up the access on the next pointer. -race:Backup::next_sleeper -race:Backup::set_next_sleeper -race:WorkerEntry::set_next_sleeper - -# This ignores a false positive caused by `thread::park()`/`thread::unpark()`. -# See: https://github.com/rust-lang/rust/pull/54806#issuecomment-436193353 -race:pthread_cond_destroy diff --git a/examples/Cargo.toml b/examples/Cargo.toml index c3dc6091518..fe3c90f9a56 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -7,7 +7,9 @@ edition = "2018" # If you copy one of the examples into a new project, you should be using # [dependencies] instead. [dev-dependencies] -tokio = { version = "0.2.0", path = "../tokio", features = ["full"] } +tokio = { version = "0.2.0", path = "../tokio", features = ["full", "tracing"] } +tracing = "0.1" +tracing-subscriber = { version = "0.2.7", default-features = false, features = ["fmt", "ansi", "env-filter", "chrono", "tracing-log"] } tokio-util = { version = "0.3.0", path = "../tokio-util", features = ["full"] } bytes = "0.5" futures = "0.3.0" diff --git a/examples/README.md b/examples/README.md index 15b06c092b5..caab606bbd0 100644 --- a/examples/README.md +++ b/examples/README.md @@ -19,5 +19,5 @@ If you've got an example you'd like to see here, please feel free to open an issue. Otherwise if you've got an example you'd like to add, please feel free to make a PR! -[tokioweb]: https://tokio.rs/docs/overview/ +[tokioweb]: https://tokio.rs/tokio/tutorial [redis]: https://github.com/tokio-rs/mini-redis diff --git a/examples/chat.rs b/examples/chat.rs index b3fb727a2cc..c4b8c6a2afc 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -43,6 +43,26 @@ use std::task::{Context, Poll}; #[tokio::main] async fn main() -> Result<(), Box> { + use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; + // Configure a `tracing` subscriber that logs traces emitted by the chat + // server. + tracing_subscriber::fmt() + // Filter what traces are displayed based on the RUST_LOG environment + // variable. + // + // Traces emitted by the example code will always be displayed. You + // can set `RUST_LOG=tokio=trace` to enable additional traces emitted by + // Tokio itself. + .with_env_filter(EnvFilter::from_default_env().add_directive("chat=info".parse()?)) + // Log events when `tracing` spans are created, entered, exited, or + // closed. When Tokio's internal tracing support is enabled (as + // described above), this can be used to track the lifecycle of spawned + // tasks on the Tokio runtime. + .with_span_events(FmtSpan::FULL) + // Set this subscriber as the default, to collect all traces emitted by + // the program. + .init(); + // Create the shared state. This is how all the peers communicate. // // The server task will hold a handle to this. For every new client, the @@ -59,7 +79,7 @@ async fn main() -> Result<(), Box> { // Note that this is the Tokio TcpListener, which is fully async. let mut listener = TcpListener::bind(&addr).await?; - println!("server running on {}", addr); + tracing::info!("server running on {}", addr); loop { // Asynchronously wait for an inbound TcpStream. @@ -70,8 +90,9 @@ async fn main() -> Result<(), Box> { // Spawn our handler to be run asynchronously. tokio::spawn(async move { + tracing::debug!("accepted connection"); if let Err(e) = process(state, stream, addr).await { - println!("an error occurred; error = {:?}", e); + tracing::info!("an error occurred; error = {:?}", e); } }); } @@ -200,7 +221,7 @@ async fn process( Some(Ok(line)) => line, // We didn't get a line so we return early here. _ => { - println!("Failed to get username from {}. Client disconnected.", addr); + tracing::error!("Failed to get username from {}. Client disconnected.", addr); return Ok(()); } }; @@ -212,7 +233,7 @@ async fn process( { let mut state = state.lock().await; let msg = format!("{} has joined the chat", username); - println!("{}", msg); + tracing::info!("{}", msg); state.broadcast(addr, &msg).await; } @@ -233,9 +254,10 @@ async fn process( peer.lines.send(&msg).await?; } Err(e) => { - println!( + tracing::error!( "an error occurred while processing messages for {}; error = {:?}", - username, e + username, + e ); } } @@ -248,7 +270,7 @@ async fn process( state.peers.remove(&addr); let msg = format!("{} has left the chat", username); - println!("{}", msg); + tracing::info!("{}", msg); state.broadcast(addr, &msg).await; } diff --git a/examples/echo-udp.rs b/examples/echo-udp.rs index d8b2af9cbb4..bc688b9b79b 100644 --- a/examples/echo-udp.rs +++ b/examples/echo-udp.rs @@ -15,7 +15,6 @@ use std::error::Error; use std::net::SocketAddr; use std::{env, io}; -use tokio; use tokio::net::UdpSocket; struct Server { diff --git a/examples/echo.rs b/examples/echo.rs index 35b122794cc..f30680748db 100644 --- a/examples/echo.rs +++ b/examples/echo.rs @@ -21,7 +21,6 @@ #![warn(rust_2018_idioms)] -use tokio; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; diff --git a/examples/proxy.rs b/examples/proxy.rs index f7a9111f6a2..144f0179fc3 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -23,6 +23,7 @@ #![warn(rust_2018_idioms)] use tokio::io; +use tokio::io::AsyncWriteExt; use tokio::net::{TcpListener, TcpStream}; use futures::future::try_join; @@ -63,8 +64,15 @@ async fn transfer(mut inbound: TcpStream, proxy_addr: String) -> Result<(), Box< let (mut ri, mut wi) = inbound.split(); let (mut ro, mut wo) = outbound.split(); - let client_to_server = io::copy(&mut ri, &mut wo); - let server_to_client = io::copy(&mut ro, &mut wi); + let client_to_server = async { + io::copy(&mut ri, &mut wo).await?; + wo.shutdown().await + }; + + let server_to_client = async { + io::copy(&mut ro, &mut wi).await?; + wi.shutdown().await + }; try_join(client_to_server, server_to_client).await?; diff --git a/examples/tinyhttp.rs b/examples/tinyhttp.rs index 732da0d64ee..4870aea2f83 100644 --- a/examples/tinyhttp.rs +++ b/examples/tinyhttp.rs @@ -18,7 +18,6 @@ use futures::SinkExt; use http::{header::HeaderValue, Request, Response, StatusCode}; #[macro_use] extern crate serde_derive; -use serde_json; use std::{env, error::Error, fmt, io}; use tokio::net::{TcpListener, TcpStream}; use tokio::stream::StreamExt; diff --git a/tokio-macros/src/entry.rs b/tokio-macros/src/entry.rs index 6a58b791ed8..2681f50d9c0 100644 --- a/tokio-macros/src/entry.rs +++ b/tokio-macros/src/entry.rs @@ -142,7 +142,7 @@ fn parse_knobs( let header = { if is_test { quote! { - #[test] + #[::core::prelude::v1::test] } } else { quote! {} @@ -334,14 +334,14 @@ pub(crate) mod old { let result = match runtime { Runtime::Threaded => quote! { - #[test] + #[::core::prelude::v1::test] #(#attrs)* #vis fn #name() #ret { tokio::runtime::Runtime::new().unwrap().block_on(async { #body }) } }, Runtime::Basic | Runtime::Auto => quote! { - #[test] + #[::core::prelude::v1::test] #(#attrs)* #vis fn #name() #ret { tokio::runtime::Builder::new() diff --git a/tokio-macros/src/lib.rs b/tokio-macros/src/lib.rs index 9fdfb5bd769..64cdc4f1bc4 100644 --- a/tokio-macros/src/lib.rs +++ b/tokio-macros/src/lib.rs @@ -24,7 +24,9 @@ mod select; use proc_macro::TokenStream; -/// Marks async function to be executed by selected runtime. +/// Marks async function to be executed by selected runtime. This macro helps set up a `Runtime` +/// without requiring the user to use [Runtime](../tokio/runtime/struct.Runtime.html) or +/// [Builder](../tokio/runtime/struct.builder.html) directly. /// /// ## Options: /// @@ -47,21 +49,62 @@ use proc_macro::TokenStream; /// } /// ``` /// +/// Equivalent code not using `#[tokio::main]` +/// +/// ```rust +/// fn main() { +/// tokio::runtime::Builder::new() +/// .threaded_scheduler() +/// .enable_all() +/// .build() +/// .unwrap() +/// .block_on(async { +/// println!("Hello world"); +/// }) +/// } +/// ``` +/// /// ### Set number of core threads /// /// ```rust -/// #[tokio::main(core_threads = 1)] +/// #[tokio::main(core_threads = 2)] /// async fn main() { /// println!("Hello world"); /// } /// ``` +/// +/// Equivalent code not using `#[tokio::main]` +/// +/// ```rust +/// fn main() { +/// tokio::runtime::Builder::new() +/// .threaded_scheduler() +/// .core_threads(2) +/// .enable_all() +/// .build() +/// .unwrap() +/// .block_on(async { +/// println!("Hello world"); +/// }) +/// } +/// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub fn main_threaded(args: TokenStream, item: TokenStream) -> TokenStream { entry::main(args, item, true) } -/// Marks async function to be executed by selected runtime. +/// Marks async function to be executed by selected runtime. This macro helps set up a `Runtime` +/// without requiring the user to use [Runtime](../tokio/runtime/struct.Runtime.html) or +/// [Builder](../tokio/runtime/struct.builder.html) directly. /// /// ## Options: /// @@ -83,6 +126,18 @@ pub fn main_threaded(args: TokenStream, item: TokenStream) -> TokenStream { /// } /// ``` /// +/// Equivalent code not using `#[tokio::main]` +/// +/// ```rust +/// fn main() { +/// tokio::runtime::Runtime::new() +/// .unwrap() +/// .block_on(async { +/// println!("Hello world"); +/// }) +/// } +/// ``` +/// /// ### Select runtime /// /// ```rust @@ -91,13 +146,38 @@ pub fn main_threaded(args: TokenStream, item: TokenStream) -> TokenStream { /// println!("Hello world"); /// } /// ``` +/// +/// Equivalent code not using `#[tokio::main]` +/// +/// ```rust +/// fn main() { +/// tokio::runtime::Builder::new() +/// .basic_scheduler() +/// .enable_all() +/// .build() +/// .unwrap() +/// .block_on(async { +/// println!("Hello world"); +/// }) +/// } +/// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub fn main(args: TokenStream, item: TokenStream) -> TokenStream { entry::old::main(args, item) } -/// Marks async function to be executed by selected runtime. +/// Marks async function to be executed by selected runtime. This macro helps set up a `Runtime` +/// without requiring the user to use [Runtime](../tokio/runtime/struct.Runtime.html) or +/// [Builder](../tokio/runtime/struct.builder.html) directly. /// /// ## Options: /// @@ -117,6 +197,29 @@ pub fn main(args: TokenStream, item: TokenStream) -> TokenStream { /// println!("Hello world"); /// } /// ``` +/// +/// Equivalent code not using `#[tokio::main]` +/// +/// ```rust +/// fn main() { +/// tokio::runtime::Builder::new() +/// .basic_scheduler() +/// .enable_all() +/// .build() +/// .unwrap() +/// .block_on(async { +/// println!("Hello world"); +/// }) +/// } +/// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub fn main_basic(args: TokenStream, item: TokenStream) -> TokenStream { @@ -149,6 +252,14 @@ pub fn main_basic(args: TokenStream, item: TokenStream) -> TokenStream { /// assert!(true); /// } /// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] pub fn test_threaded(args: TokenStream, item: TokenStream) -> TokenStream { entry::test(args, item, true) @@ -180,6 +291,14 @@ pub fn test_threaded(args: TokenStream, item: TokenStream) -> TokenStream { /// assert!(true); /// } /// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { entry::old::test(args, item) @@ -199,6 +318,14 @@ pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { /// assert!(true); /// } /// ``` +/// +/// ### NOTE: +/// +/// If you rename the tokio crate in your dependencies this macro +/// will not work. If you must rename the 0.2 version of tokio because +/// you're also using the 0.1 version of tokio, you _must_ make the +/// tokio 0.2 crate available as `tokio` in the module where this +/// macro is expanded. #[proc_macro_attribute] pub fn test_basic(args: TokenStream, item: TokenStream) -> TokenStream { entry::test(args, item, false) diff --git a/tokio-test/src/io.rs b/tokio-test/src/io.rs index 8af6a9f6490..26ef57e47a2 100644 --- a/tokio-test/src/io.rs +++ b/tokio-test/src/io.rs @@ -12,7 +12,7 @@ //! //! # Usage //! -//! Attempting to write data that the mock isn't expected will result in a +//! Attempting to write data that the mock isn't expecting will result in a //! panic. //! //! [`AsyncRead`]: tokio::io::AsyncRead diff --git a/tokio-test/src/task.rs b/tokio-test/src/task.rs index 82d29134c12..728117cc158 100644 --- a/tokio-test/src/task.rs +++ b/tokio-test/src/task.rs @@ -46,21 +46,11 @@ const SLEEP: usize = 2; impl Spawn { /// Consumes `self` returning the inner value - pub fn into_inner(mut self) -> T + pub fn into_inner(self) -> T where T: Unpin, { - drop(self.task); - - // Pin::into_inner is unstable, so we work around it - // - // Safety: `T` is bound by `Unpin`. - unsafe { - let ptr = Pin::get_mut(self.future.as_mut()) as *mut T; - let future = Box::from_raw(ptr); - mem::forget(self.future); - *future - } + *Pin::into_inner(self.future) } /// Returns `true` if the inner future has received a wake notification diff --git a/tokio-tls/CHANGELOG.md b/tokio-tls/CHANGELOG.md deleted file mode 100644 index 82c0d50460f..00000000000 --- a/tokio-tls/CHANGELOG.md +++ /dev/null @@ -1,39 +0,0 @@ -# 0.3.0 (November 26, 2019) - -- Updates for tokio 0.2 release - -# 0.3.0-alpha.6 (September 30, 2019) - -- Move to `futures-*-preview 0.3.0-alpha.19` -- Move to `pin-project 0.4` - -# 0.3.0-alpha.5 (September 19, 2019) - -### Added -- `TlsStream::get_ref` and `TlsStream::get_mut` ([#1537]). - -# 0.3.0-alpha.4 (August 30, 2019) - -### Changed -- Track `tokio` 0.2.0-alpha.4 - -# 0.3.0-alpha.2 (August 17, 2019) - -### Changed -- Update `futures` dependency to 0.3.0-alpha.18. - -# 0.3.0-alpha.1 (August 8, 2019) - -### Changed -- Switch to `async`, `await`, and `std::future`. - -# 0.2.1 (January 6, 2019) - -* Implement `Clone` for `TlsConnector` and `TlsAcceptor` ([#777]) - -# 0.2.0 (August 8, 2018) - -* Initial release with `tokio` support. - -[#1537]: https://github.com/tokio-rs/tokio/pull/1537 -[#777]: https://github.com/tokio-rs/tokio/pull/777 diff --git a/tokio-tls/Cargo.toml b/tokio-tls/Cargo.toml deleted file mode 100644 index a9877926933..00000000000 --- a/tokio-tls/Cargo.toml +++ /dev/null @@ -1,63 +0,0 @@ -[package] -name = "tokio-tls" -# When releasing to crates.io: -# - Remove path dependencies -# - Update html_root_url. -# - Update doc url -# - Cargo.toml -# - README.md -# - Update CHANGELOG.md. -# - Create "v0.3.x" git tag. -version = "0.3.0" -edition = "2018" -authors = ["Tokio Contributors "] -license = "MIT" -repository = "https://github.com/tokio-rs/tokio" -homepage = "https://tokio.rs" -documentation = "https://docs.rs/tokio-tls/0.3.0-alpha.6/tokio_tls/" -description = """ -An implementation of TLS/SSL streams for Tokio giving an implementation of TLS -for nonblocking I/O streams. -""" -categories = ["asynchronous", "network-programming"] - -[badges] -travis-ci = { repository = "tokio-rs/tokio-tls" } - -[dependencies] -native-tls = "0.2" -tokio = { version = "0.2.0", path = "../tokio" } - -[dev-dependencies] -tokio = { version = "0.2.0", path = "../tokio", features = ["macros", "stream", "rt-core", "io-util", "net"] } -tokio-util = { version = "0.3.0", path = "../tokio-util", features = ["full"] } - -cfg-if = "0.1" -env_logger = { version = "0.6", default-features = false } -futures = { version = "0.3.0", features = ["async-await"] } - -[target.'cfg(all(not(target_os = "macos"), not(windows), not(target_os = "ios")))'.dev-dependencies] -openssl = "0.10" - -[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dev-dependencies] -security-framework = "0.2" - -[target.'cfg(windows)'.dev-dependencies] -schannel = "0.1" - -[target.'cfg(windows)'.dev-dependencies.winapi] -version = "0.3" -features = [ - "lmcons", - "basetsd", - "minwinbase", - "minwindef", - "ntdef", - "sysinfoapi", - "timezoneapi", - "wincrypt", - "winerror", -] - -[package.metadata.docs.rs] -all-features = true diff --git a/tokio-tls/LICENSE b/tokio-tls/LICENSE deleted file mode 100644 index cdb28b4b56a..00000000000 --- a/tokio-tls/LICENSE +++ /dev/null @@ -1,25 +0,0 @@ -Copyright (c) 2019 Tokio Contributors - -Permission is hereby granted, free of charge, to any -person obtaining a copy of this software and associated -documentation files (the "Software"), to deal in the -Software without restriction, including without -limitation the rights to use, copy, modify, merge, -publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software -is furnished to do so, subject to the following -conditions: - -The above copyright notice and this permission notice -shall be included in all copies or substantial portions -of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. diff --git a/tokio-tls/README.md b/tokio-tls/README.md deleted file mode 100644 index 455612be8b8..00000000000 --- a/tokio-tls/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# tokio-tls - -An implementation of TLS/SSL streams for Tokio built on top of the [`native-tls` -crate] - -## License - -This project is licensed under the [MIT license](./LICENSE). - -### Contribution - -Unless you explicitly state otherwise, any contribution intentionally submitted -for inclusion in Tokio by you, shall be licensed as MIT, without any additional -terms or conditions. diff --git a/tokio-tls/examples/download-rust-lang.rs b/tokio-tls/examples/download-rust-lang.rs deleted file mode 100644 index 324c077539a..00000000000 --- a/tokio-tls/examples/download-rust-lang.rs +++ /dev/null @@ -1,40 +0,0 @@ -// #![warn(rust_2018_idioms)] - -use native_tls::TlsConnector; -use std::error::Error; -use std::net::ToSocketAddrs; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; -use tokio_tls; - -#[tokio::main] -async fn main() -> Result<(), Box> { - let addr = "www.rust-lang.org:443" - .to_socket_addrs()? - .next() - .ok_or("failed to resolve www.rust-lang.org")?; - - let socket = TcpStream::connect(&addr).await?; - let cx = TlsConnector::builder().build()?; - let cx = tokio_tls::TlsConnector::from(cx); - - let mut socket = cx.connect("www.rust-lang.org", socket).await?; - - socket - .write_all( - "\ - GET / HTTP/1.0\r\n\ - Host: www.rust-lang.org\r\n\ - \r\n\ - " - .as_bytes(), - ) - .await?; - - let mut data = Vec::new(); - socket.read_to_end(&mut data).await?; - - // println!("data: {:?}", &data); - println!("{}", String::from_utf8_lossy(&data[..])); - Ok(()) -} diff --git a/tokio-tls/examples/identity.p12 b/tokio-tls/examples/identity.p12 deleted file mode 100644 index d16abb8c706..00000000000 Binary files a/tokio-tls/examples/identity.p12 and /dev/null differ diff --git a/tokio-tls/examples/tls-echo.rs b/tokio-tls/examples/tls-echo.rs deleted file mode 100644 index 96309567403..00000000000 --- a/tokio-tls/examples/tls-echo.rs +++ /dev/null @@ -1,55 +0,0 @@ -#![warn(rust_2018_idioms)] - -// A tiny async TLS echo server with Tokio -use native_tls; -use native_tls::Identity; -use tokio; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpListener; -use tokio_tls; - -/** -an example to setup a tls server. -how to test: -wget https://127.0.0.1:12345 --no-check-certificate -*/ -#[tokio::main] -async fn main() -> Result<(), Box> { - // Bind the server's socket - let addr = "127.0.0.1:12345".to_string(); - let mut tcp: TcpListener = TcpListener::bind(&addr).await?; - - // Create the TLS acceptor. - let der = include_bytes!("identity.p12"); - let cert = Identity::from_pkcs12(der, "mypass")?; - let tls_acceptor = - tokio_tls::TlsAcceptor::from(native_tls::TlsAcceptor::builder(cert).build()?); - loop { - // Asynchronously wait for an inbound socket. - let (socket, remote_addr) = tcp.accept().await?; - let tls_acceptor = tls_acceptor.clone(); - println!("accept connection from {}", remote_addr); - tokio::spawn(async move { - // Accept the TLS connection. - let mut tls_stream = tls_acceptor.accept(socket).await.expect("accept error"); - // In a loop, read data from the socket and write the data back. - - let mut buf = [0; 1024]; - let n = tls_stream - .read(&mut buf) - .await - .expect("failed to read data from socket"); - - if n == 0 { - return; - } - println!("read={}", unsafe { - String::from_utf8_unchecked(buf[0..n].into()) - }); - tls_stream - .write_all(&buf[0..n]) - .await - .expect("failed to write data to socket"); - }); - } -} diff --git a/tokio-tls/src/lib.rs b/tokio-tls/src/lib.rs deleted file mode 100644 index 2770650934b..00000000000 --- a/tokio-tls/src/lib.rs +++ /dev/null @@ -1,361 +0,0 @@ -#![doc(html_root_url = "https://docs.rs/tokio-tls/0.3.0")] -#![warn( - missing_debug_implementations, - missing_docs, - rust_2018_idioms, - unreachable_pub -)] -#![deny(intra_doc_link_resolution_failure)] -#![doc(test( - no_crate_inject, - attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) -))] - -//! Async TLS streams -//! -//! This library is an implementation of TLS streams using the most appropriate -//! system library by default for negotiating the connection. That is, on -//! Windows this library uses SChannel, on OSX it uses SecureTransport, and on -//! other platforms it uses OpenSSL. -//! -//! Each TLS stream implements the `Read` and `Write` traits to interact and -//! interoperate with the rest of the futures I/O ecosystem. Client connections -//! initiated from this crate verify hostnames automatically and by default. -//! -//! This crate primarily exports this ability through two newtypes, -//! `TlsConnector` and `TlsAcceptor`. These newtypes augment the -//! functionality provided by the `native-tls` crate, on which this crate is -//! built. Configuration of TLS parameters is still primarily done through the -//! `native-tls` crate. - -use tokio::io::{AsyncRead, AsyncWrite}; - -use native_tls::{Error, HandshakeError, MidHandshakeTlsStream}; -use std::fmt; -use std::future::Future; -use std::io::{self, Read, Write}; -use std::marker::Unpin; -use std::mem::MaybeUninit; -use std::pin::Pin; -use std::ptr::null_mut; -use std::task::{Context, Poll}; - -#[derive(Debug)] -struct AllowStd { - inner: S, - context: *mut (), -} - -/// A wrapper around an underlying raw stream which implements the TLS or SSL -/// protocol. -/// -/// A `TlsStream` represents a handshake that has been completed successfully -/// and both the server and the client are ready for receiving and sending -/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written -/// to a `TlsStream` are encrypted when passing through to `S`. -#[derive(Debug)] -pub struct TlsStream(native_tls::TlsStream>); - -/// A wrapper around a `native_tls::TlsConnector`, providing an async `connect` -/// method. -#[derive(Clone)] -pub struct TlsConnector(native_tls::TlsConnector); - -/// A wrapper around a `native_tls::TlsAcceptor`, providing an async `accept` -/// method. -#[derive(Clone)] -pub struct TlsAcceptor(native_tls::TlsAcceptor); - -struct MidHandshake(Option>>); - -enum StartedHandshake { - Done(TlsStream), - Mid(MidHandshakeTlsStream>), -} - -struct StartedHandshakeFuture(Option>); -struct StartedHandshakeFutureInner { - f: F, - stream: S, -} - -struct Guard<'a, S>(&'a mut TlsStream) -where - AllowStd: Read + Write; - -impl Drop for Guard<'_, S> -where - AllowStd: Read + Write, -{ - fn drop(&mut self) { - (self.0).0.get_mut().context = null_mut(); - } -} - -// *mut () context is neither Send nor Sync -unsafe impl Send for AllowStd {} -unsafe impl Sync for AllowStd {} - -impl AllowStd -where - S: Unpin, -{ - fn with_context(&mut self, f: F) -> R - where - F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, - { - unsafe { - assert!(!self.context.is_null()); - let waker = &mut *(self.context as *mut _); - f(waker, Pin::new(&mut self.inner)) - } - } -} - -impl Read for AllowStd -where - S: AsyncRead + Unpin, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } -} - -impl Write for AllowStd -where - S: AsyncWrite + Unpin, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self.with_context(|ctx, stream| stream.poll_flush(ctx)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } -} - -fn cvt(r: io::Result) -> Poll> { - match r { - Ok(v) => Poll::Ready(Ok(v)), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - Err(e) => Poll::Ready(Err(e)), - } -} - -impl TlsStream { - fn with_context(&mut self, ctx: &mut Context<'_>, f: F) -> R - where - F: FnOnce(&mut native_tls::TlsStream>) -> R, - AllowStd: Read + Write, - { - self.0.get_mut().context = ctx as *mut _ as *mut (); - let g = Guard(self); - f(&mut (g.0).0) - } - - /// Returns a shared reference to the inner stream. - pub fn get_ref(&self) -> &S - where - S: AsyncRead + AsyncWrite + Unpin, - { - &self.0.get_ref().inner - } - - /// Returns a mutable reference to the inner stream. - pub fn get_mut(&mut self) -> &mut S - where - S: AsyncRead + AsyncWrite + Unpin, - { - &mut self.0.get_mut().inner - } -} - -impl AsyncRead for TlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - // Note that this does not forward to `S` because the buffer is - // unconditionally filled in by OpenSSL, not the actual object `S`. - // We're decrypting bytes from `S` into the buffer above! - false - } - - fn poll_read( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.with_context(ctx, |s| cvt(s.read(buf))) - } -} - -impl AsyncWrite for TlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_write( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.with_context(ctx, |s| cvt(s.write(buf))) - } - - fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { - self.with_context(ctx, |s| cvt(s.flush())) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { - match self.with_context(ctx, |s| s.shutdown()) { - Ok(()) => Poll::Ready(Ok(())), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - Err(e) => Poll::Ready(Err(e)), - } - } -} - -async fn handshake(f: F, stream: S) -> Result, Error> -where - F: FnOnce( - AllowStd, - ) -> Result>, HandshakeError>> - + Unpin, - S: AsyncRead + AsyncWrite + Unpin, -{ - let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream })); - - match start.await { - Err(e) => Err(e), - Ok(StartedHandshake::Done(s)) => Ok(s), - Ok(StartedHandshake::Mid(s)) => MidHandshake(Some(s)).await, - } -} - -impl Future for StartedHandshakeFuture -where - F: FnOnce( - AllowStd, - ) -> Result>, HandshakeError>> - + Unpin, - S: Unpin, - AllowStd: Read + Write, -{ - type Output = Result, Error>; - - fn poll( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - ) -> Poll, Error>> { - let inner = self.0.take().expect("future polled after completion"); - let stream = AllowStd { - inner: inner.stream, - context: ctx as *mut _ as *mut (), - }; - - match (inner.f)(stream) { - Ok(mut s) => { - s.get_mut().context = null_mut(); - Poll::Ready(Ok(StartedHandshake::Done(TlsStream(s)))) - } - Err(HandshakeError::WouldBlock(mut s)) => { - s.get_mut().context = null_mut(); - Poll::Ready(Ok(StartedHandshake::Mid(s))) - } - Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), - } - } -} - -impl TlsConnector { - /// Connects the provided stream with this connector, assuming the provided - /// domain. - /// - /// This function will internally call `TlsConnector::connect` to connect - /// the stream and returns a future representing the resolution of the - /// connection operation. The returned future will resolve to either - /// `TlsStream` or `Error` depending if it's successful or not. - /// - /// This is typically used for clients who have already established, for - /// example, a TCP connection to a remote server. That stream is then - /// provided here to perform the client half of a connection to a - /// TLS-powered server. - pub async fn connect(&self, domain: &str, stream: S) -> Result, Error> - where - S: AsyncRead + AsyncWrite + Unpin, - { - handshake(move |s| self.0.connect(domain, s), stream).await - } -} - -impl fmt::Debug for TlsConnector { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TlsConnector").finish() - } -} - -impl From for TlsConnector { - fn from(inner: native_tls::TlsConnector) -> TlsConnector { - TlsConnector(inner) - } -} - -impl TlsAcceptor { - /// Accepts a new client connection with the provided stream. - /// - /// This function will internally call `TlsAcceptor::accept` to connect - /// the stream and returns a future representing the resolution of the - /// connection operation. The returned future will resolve to either - /// `TlsStream` or `Error` depending if it's successful or not. - /// - /// This is typically used after a new socket has been accepted from a - /// `TcpListener`. That socket is then passed to this function to perform - /// the server half of accepting a client connection. - pub async fn accept(&self, stream: S) -> Result, Error> - where - S: AsyncRead + AsyncWrite + Unpin, - { - handshake(move |s| self.0.accept(s), stream).await - } -} - -impl fmt::Debug for TlsAcceptor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TlsAcceptor").finish() - } -} - -impl From for TlsAcceptor { - fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor { - TlsAcceptor(inner) - } -} - -impl Future for MidHandshake { - type Output = Result, Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut_self = self.get_mut(); - let mut s = mut_self.0.take().expect("future polled after completion"); - - s.get_mut().context = cx as *mut _ as *mut (); - match s.handshake() { - Ok(stream) => Poll::Ready(Ok(TlsStream(stream))), - Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), - Err(HandshakeError::WouldBlock(mut s)) => { - s.get_mut().context = null_mut(); - mut_self.0 = Some(s); - Poll::Pending - } - } - } -} diff --git a/tokio-tls/tests/bad.rs b/tokio-tls/tests/bad.rs deleted file mode 100644 index 87d4ca2708e..00000000000 --- a/tokio-tls/tests/bad.rs +++ /dev/null @@ -1,124 +0,0 @@ -#![warn(rust_2018_idioms)] - -use cfg_if::cfg_if; -use env_logger; -use native_tls::TlsConnector; -use std::io::{self, Error}; -use std::net::ToSocketAddrs; -use tokio::net::TcpStream; -use tokio_tls; - -macro_rules! t { - ($e:expr) => { - match $e { - Ok(e) => e, - Err(e) => panic!("{} failed with {:?}", stringify!($e), e), - } - }; -} - -cfg_if! { - if #[cfg(feature = "force-rustls")] { - fn verify_failed(err: &Error, s: &str) { - let err = err.to_string(); - assert!(err.contains(s), "bad error: {}", err); - } - - fn assert_expired_error(err: &Error) { - verify_failed(err, "CertExpired"); - } - - fn assert_wrong_host(err: &Error) { - verify_failed(err, "CertNotValidForName"); - } - - fn assert_self_signed(err: &Error) { - verify_failed(err, "UnknownIssuer"); - } - - fn assert_untrusted_root(err: &Error) { - verify_failed(err, "UnknownIssuer"); - } - } else if #[cfg(any(feature = "force-openssl", - all(not(target_os = "macos"), - not(target_os = "windows"), - not(target_os = "ios"))))] { - fn verify_failed(err: &Error) { - assert!(format!("{}", err).contains("certificate verify failed")) - } - - use verify_failed as assert_expired_error; - use verify_failed as assert_wrong_host; - use verify_failed as assert_self_signed; - use verify_failed as assert_untrusted_root; - } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { - - fn assert_invalid_cert_chain(err: &Error) { - assert!(format!("{}", err).contains("was not trusted.")) - } - - use crate::assert_invalid_cert_chain as assert_expired_error; - use crate::assert_invalid_cert_chain as assert_wrong_host; - use crate::assert_invalid_cert_chain as assert_self_signed; - use crate::assert_invalid_cert_chain as assert_untrusted_root; - } else { - fn assert_expired_error(err: &Error) { - let s = err.to_string(); - assert!(s.contains("system clock"), "error = {:?}", s); - } - - fn assert_wrong_host(err: &Error) { - let s = err.to_string(); - assert!(s.contains("CN name"), "error = {:?}", s); - } - - fn assert_self_signed(err: &Error) { - let s = err.to_string(); - assert!(s.contains("root certificate which is not trusted"), "error = {:?}", s); - } - - use assert_self_signed as assert_untrusted_root; - } -} - -async fn get_host(host: &'static str) -> Error { - drop(env_logger::try_init()); - - let addr = format!("{}:443", host); - let addr = t!(addr.to_socket_addrs()).next().unwrap(); - - let socket = t!(TcpStream::connect(&addr).await); - let builder = TlsConnector::builder(); - let cx = t!(builder.build()); - let cx = tokio_tls::TlsConnector::from(cx); - let res = cx - .connect(host, socket) - .await - .map_err(|e| Error::new(io::ErrorKind::Other, e)); - - assert!(res.is_err()); - res.err().unwrap() -} - -#[tokio::test] -async fn expired() { - assert_expired_error(&get_host("expired.badssl.com").await) -} - -// TODO: the OSX builders on Travis apparently fail this tests spuriously? -// passes locally though? Seems... bad! -#[tokio::test] -#[cfg_attr(all(target_os = "macos", feature = "force-openssl"), ignore)] -async fn wrong_host() { - assert_wrong_host(&get_host("wrong.host.badssl.com").await) -} - -#[tokio::test] -async fn self_signed() { - assert_self_signed(&get_host("self-signed.badssl.com").await) -} - -#[tokio::test] -async fn untrusted_root() { - assert_untrusted_root(&get_host("untrusted-root.badssl.com").await) -} diff --git a/tokio-tls/tests/google.rs b/tokio-tls/tests/google.rs deleted file mode 100644 index 13b78d31831..00000000000 --- a/tokio-tls/tests/google.rs +++ /dev/null @@ -1,102 +0,0 @@ -#![warn(rust_2018_idioms)] - -use cfg_if::cfg_if; -use env_logger; -use native_tls; -use native_tls::TlsConnector; -use std::io; -use std::net::ToSocketAddrs; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; -use tokio_tls; - -macro_rules! t { - ($e:expr) => { - match $e { - Ok(e) => e, - Err(e) => panic!("{} failed with {:?}", stringify!($e), e), - } - }; -} - -cfg_if! { - if #[cfg(feature = "force-rustls")] { - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.to_string(); - assert!(err.contains("CertNotValidForName"), "bad error: {}", err); - } - } else if #[cfg(any(feature = "force-openssl", - all(not(target_os = "macos"), - not(target_os = "windows"), - not(target_os = "ios"))))] { - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.get_ref().unwrap(); - let err = err.downcast_ref::().unwrap(); - assert!(format!("{}", err).contains("certificate verify failed")); - } - } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.get_ref().unwrap(); - let err = err.downcast_ref::().unwrap(); - assert!(format!("{}", err).contains("was not trusted.")); - } - } else { - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.get_ref().unwrap(); - let err = err.downcast_ref::().unwrap(); - assert!(format!("{}", err).contains("CN name")); - } - } -} - -#[tokio::test] -async fn fetch_google() { - drop(env_logger::try_init()); - - // First up, resolve google.com - let addr = t!("google.com:443".to_socket_addrs()).next().unwrap(); - - let socket = TcpStream::connect(&addr).await.unwrap(); - - // Send off the request by first negotiating an SSL handshake, then writing - // of our request, then flushing, then finally read off the response. - let builder = TlsConnector::builder(); - let connector = t!(builder.build()); - let connector = tokio_tls::TlsConnector::from(connector); - let mut socket = t!(connector.connect("google.com", socket).await); - t!(socket.write_all(b"GET / HTTP/1.0\r\n\r\n").await); - let mut data = Vec::new(); - t!(socket.read_to_end(&mut data).await); - - // any response code is fine - assert!(data.starts_with(b"HTTP/1.0 ")); - - let data = String::from_utf8_lossy(&data); - let data = data.trim_end(); - assert!(data.ends_with("") || data.ends_with("")); -} - -fn native2io(e: native_tls::Error) -> io::Error { - io::Error::new(io::ErrorKind::Other, e) -} - -// see comment in bad.rs for ignore reason -#[cfg_attr(all(target_os = "macos", feature = "force-openssl"), ignore)] -#[tokio::test] -async fn wrong_hostname_error() { - drop(env_logger::try_init()); - - let addr = t!("google.com:443".to_socket_addrs()).next().unwrap(); - - let socket = t!(TcpStream::connect(&addr).await); - let builder = TlsConnector::builder(); - let connector = t!(builder.build()); - let connector = tokio_tls::TlsConnector::from(connector); - let res = connector - .connect("rust-lang.org", socket) - .await - .map_err(native2io); - - assert!(res.is_err()); - assert_bad_hostname_error(&res.err().unwrap()); -} diff --git a/tokio-tls/tests/smoke.rs b/tokio-tls/tests/smoke.rs deleted file mode 100644 index 8788dd6d467..00000000000 --- a/tokio-tls/tests/smoke.rs +++ /dev/null @@ -1,629 +0,0 @@ -#![warn(rust_2018_idioms)] - -use cfg_if::cfg_if; -use env_logger; -use futures::join; -use native_tls; -use native_tls::{Identity, TlsAcceptor, TlsConnector}; -use std::io::Write; -use std::marker::Unpin; -use std::process::Command; -use std::ptr; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, Error, ErrorKind}; -use tokio::net::{TcpListener, TcpStream}; -use tokio::stream::StreamExt; -use tokio_tls; - -macro_rules! t { - ($e:expr) => { - match $e { - Ok(e) => e, - Err(e) => panic!("{} failed with {:?}", stringify!($e), e), - } - }; -} - -#[allow(dead_code)] -struct Keys { - cert_der: Vec, - pkey_der: Vec, - pkcs12_der: Vec, -} - -#[allow(dead_code)] -fn openssl_keys() -> &'static Keys { - static INIT: Once = Once::new(); - static mut KEYS: *mut Keys = ptr::null_mut(); - - INIT.call_once(|| { - let path = t!(env::current_exe()); - let path = path.parent().unwrap(); - let keyfile = path.join("test.key"); - let certfile = path.join("test.crt"); - let config = path.join("openssl.config"); - - File::create(&config) - .unwrap() - .write_all( - b"\ - [req]\n\ - distinguished_name=dn\n\ - [ dn ]\n\ - CN=localhost\n\ - [ ext ]\n\ - basicConstraints=CA:FALSE,pathlen:0\n\ - subjectAltName = @alt_names - extendedKeyUsage=serverAuth,clientAuth - [alt_names] - DNS.1 = localhost - ", - ) - .unwrap(); - - let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost"; - let output = t!(Command::new("openssl") - .arg("req") - .arg("-nodes") - .arg("-x509") - .arg("-newkey") - .arg("rsa:2048") - .arg("-config") - .arg(&config) - .arg("-extensions") - .arg("ext") - .arg("-subj") - .arg(subj) - .arg("-keyout") - .arg(&keyfile) - .arg("-out") - .arg(&certfile) - .arg("-days") - .arg("1") - .output()); - assert!(output.status.success()); - - let crtout = t!(Command::new("openssl") - .arg("x509") - .arg("-outform") - .arg("der") - .arg("-in") - .arg(&certfile) - .output()); - assert!(crtout.status.success()); - let keyout = t!(Command::new("openssl") - .arg("rsa") - .arg("-outform") - .arg("der") - .arg("-in") - .arg(&keyfile) - .output()); - assert!(keyout.status.success()); - - let pkcs12out = t!(Command::new("openssl") - .arg("pkcs12") - .arg("-export") - .arg("-nodes") - .arg("-inkey") - .arg(&keyfile) - .arg("-in") - .arg(&certfile) - .arg("-password") - .arg("pass:foobar") - .output()); - assert!(pkcs12out.status.success()); - - let keys = Box::new(Keys { - cert_der: crtout.stdout, - pkey_der: keyout.stdout, - pkcs12_der: pkcs12out.stdout, - }); - unsafe { - KEYS = Box::into_raw(keys); - } - }); - unsafe { &*KEYS } -} - -cfg_if! { - if #[cfg(feature = "rustls")] { - use webpki; - use untrusted; - use std::env; - use std::fs::File; - use std::process::Command; - use std::sync::Once; - - use untrusted::Input; - use webpki::trust_anchor_util; - - fn server_cx() -> io::Result { - let mut cx = ServerContext::new(); - - let (cert, key) = keys(); - cx.config_mut() - .set_single_cert(vec![cert.to_vec()], key.to_vec()); - - Ok(cx) - } - - fn configure_client(cx: &mut ClientContext) { - let (cert, _key) = keys(); - let cert = Input::from(cert); - let anchor = trust_anchor_util::cert_der_as_trust_anchor(cert).unwrap(); - cx.config_mut().root_store.add_trust_anchors(&[anchor]); - } - - // Like OpenSSL we generate certificates on the fly, but for OSX we - // also have to put them into a specific keychain. We put both the - // certificates and the keychain next to our binary. - // - // Right now I don't know of a way to programmatically create a - // self-signed certificate, so we just fork out to the `openssl` binary. - fn keys() -> (&'static [u8], &'static [u8]) { - static INIT: Once = Once::new(); - static mut KEYS: *mut (Vec, Vec) = ptr::null_mut(); - - INIT.call_once(|| { - let (key, cert) = openssl_keys(); - let path = t!(env::current_exe()); - let path = path.parent().unwrap(); - let keyfile = path.join("test.key"); - let certfile = path.join("test.crt"); - let config = path.join("openssl.config"); - - File::create(&config).unwrap().write_all(b"\ - [req]\n\ - distinguished_name=dn\n\ - [ dn ]\n\ - CN=localhost\n\ - [ ext ]\n\ - basicConstraints=CA:FALSE,pathlen:0\n\ - subjectAltName = @alt_names - [alt_names] - DNS.1 = localhost - ").unwrap(); - - let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost"; - let output = t!(Command::new("openssl") - .arg("req") - .arg("-nodes") - .arg("-x509") - .arg("-newkey").arg("rsa:2048") - .arg("-config").arg(&config) - .arg("-extensions").arg("ext") - .arg("-subj").arg(subj) - .arg("-keyout").arg(&keyfile) - .arg("-out").arg(&certfile) - .arg("-days").arg("1") - .output()); - assert!(output.status.success()); - - let crtout = t!(Command::new("openssl") - .arg("x509") - .arg("-outform").arg("der") - .arg("-in").arg(&certfile) - .output()); - assert!(crtout.status.success()); - let keyout = t!(Command::new("openssl") - .arg("rsa") - .arg("-outform").arg("der") - .arg("-in").arg(&keyfile) - .output()); - assert!(keyout.status.success()); - - let cert = crtout.stdout; - let key = keyout.stdout; - unsafe { - KEYS = Box::into_raw(Box::new((cert, key))); - } - }); - unsafe { - (&(*KEYS).0, &(*KEYS).1) - } - } - } else if #[cfg(any(feature = "force-openssl", - all(not(target_os = "macos"), - not(target_os = "windows"), - not(target_os = "ios"))))] { - use std::fs::File; - use std::env; - use std::sync::Once; - - fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) { - let keys = openssl_keys(); - - let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar")); - let srv = TlsAcceptor::builder(pkcs12); - - let cert = t!(native_tls::Certificate::from_der(&keys.cert_der)); - - let mut client = TlsConnector::builder(); - t!(client.add_root_certificate(cert).build()); - - (t!(srv.build()).into(), t!(client.build()).into()) - } - } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { - use std::env; - use std::fs::File; - use std::sync::Once; - - fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) { - let keys = openssl_keys(); - - let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar")); - let srv = TlsAcceptor::builder(pkcs12); - - let cert = native_tls::Certificate::from_der(&keys.cert_der).unwrap(); - let mut client = TlsConnector::builder(); - client.add_root_certificate(cert); - - (t!(srv.build()).into(), t!(client.build()).into()) - } - } else { - use schannel; - use winapi; - - use std::env; - use std::fs::File; - use std::io; - use std::mem; - use std::sync::Once; - - use schannel::cert_context::CertContext; - use schannel::cert_store::{CertStore, CertAdd, Memory}; - use winapi::shared::basetsd::*; - use winapi::shared::lmcons::*; - use winapi::shared::minwindef::*; - use winapi::shared::ntdef::WCHAR; - use winapi::um::minwinbase::*; - use winapi::um::sysinfoapi::*; - use winapi::um::timezoneapi::*; - use winapi::um::wincrypt::*; - - const FRIENDLY_NAME: &str = "tokio-tls localhost testing cert"; - - fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) { - let cert = localhost_cert(); - let mut store = t!(Memory::new()).into_store(); - t!(store.add_cert(&cert, CertAdd::Always)); - let pkcs12_der = t!(store.export_pkcs12("foobar")); - let pkcs12 = t!(Identity::from_pkcs12(&pkcs12_der, "foobar")); - - let srv = TlsAcceptor::builder(pkcs12); - let client = TlsConnector::builder(); - (t!(srv.build()).into(), t!(client.build()).into()) - } - - // ==================================================================== - // Magic! - // - // Lots of magic is happening here to wrangle certificates for running - // these tests on Windows. For more information see the test suite - // in the schannel-rs crate as this is just coyping that. - // - // The general gist of this though is that the only way to add custom - // trusted certificates is to add it to the system store of trust. To - // do that we go through the whole rigamarole here to generate a new - // self-signed certificate and then insert that into the system store. - // - // This generates some dialogs, so we print what we're doing sometimes, - // and otherwise we just manage the ephemeral certificates. Because - // they're in the system store we always ensure that they're only valid - // for a small period of time (e.g. 1 day). - - fn localhost_cert() -> CertContext { - static INIT: Once = Once::new(); - INIT.call_once(|| { - for cert in local_root_store().certs() { - let name = match cert.friendly_name() { - Ok(name) => name, - Err(_) => continue, - }; - if name != FRIENDLY_NAME { - continue - } - if !cert.is_time_valid().unwrap() { - io::stdout().write_all(br#" - -The tokio-tls test suite is about to delete an old copy of one of its -certificates from your root trust store. This certificate was only valid for one -day and it is no longer needed. The host should be "localhost" and the -description should mention "tokio-tls". - - "#).unwrap(); - cert.delete().unwrap(); - } else { - return - } - } - - install_certificate().unwrap(); - }); - - for cert in local_root_store().certs() { - let name = match cert.friendly_name() { - Ok(name) => name, - Err(_) => continue, - }; - if name == FRIENDLY_NAME { - return cert - } - } - - panic!("couldn't find a cert"); - } - - fn local_root_store() -> CertStore { - if env::var("CI").is_ok() { - CertStore::open_local_machine("Root").unwrap() - } else { - CertStore::open_current_user("Root").unwrap() - } - } - - fn install_certificate() -> io::Result { - unsafe { - let mut provider = 0; - let mut hkey = 0; - - let mut buffer = "tokio-tls test suite".encode_utf16() - .chain(Some(0)) - .collect::>(); - let res = CryptAcquireContextW(&mut provider, - buffer.as_ptr(), - ptr::null_mut(), - PROV_RSA_FULL, - CRYPT_MACHINE_KEYSET); - if res != TRUE { - // create a new key container (since it does not exist) - let res = CryptAcquireContextW(&mut provider, - buffer.as_ptr(), - ptr::null_mut(), - PROV_RSA_FULL, - CRYPT_NEWKEYSET | CRYPT_MACHINE_KEYSET); - if res != TRUE { - return Err(Error::last_os_error()) - } - } - - // create a new keypair (RSA-2048) - let res = CryptGenKey(provider, - AT_SIGNATURE, - 0x0800<<16 | CRYPT_EXPORTABLE, - &mut hkey); - if res != TRUE { - return Err(Error::last_os_error()); - } - - // start creating the certificate - let name = "CN=localhost,O=tokio-tls,OU=tokio-tls,\ - G=tokio_tls".encode_utf16() - .chain(Some(0)) - .collect::>(); - let mut cname_buffer: [WCHAR; UNLEN as usize + 1] = mem::zeroed(); - let mut cname_len = cname_buffer.len() as DWORD; - let res = CertStrToNameW(X509_ASN_ENCODING, - name.as_ptr(), - CERT_X500_NAME_STR, - ptr::null_mut(), - cname_buffer.as_mut_ptr() as *mut u8, - &mut cname_len, - ptr::null_mut()); - if res != TRUE { - return Err(Error::last_os_error()); - } - - let mut subject_issuer = CERT_NAME_BLOB { - cbData: cname_len, - pbData: cname_buffer.as_ptr() as *mut u8, - }; - let mut key_provider = CRYPT_KEY_PROV_INFO { - pwszContainerName: buffer.as_mut_ptr(), - pwszProvName: ptr::null_mut(), - dwProvType: PROV_RSA_FULL, - dwFlags: CRYPT_MACHINE_KEYSET, - cProvParam: 0, - rgProvParam: ptr::null_mut(), - dwKeySpec: AT_SIGNATURE, - }; - let mut sig_algorithm = CRYPT_ALGORITHM_IDENTIFIER { - pszObjId: szOID_RSA_SHA256RSA.as_ptr() as *mut _, - Parameters: mem::zeroed(), - }; - let mut expiration_date: SYSTEMTIME = mem::zeroed(); - GetSystemTime(&mut expiration_date); - let mut file_time: FILETIME = mem::zeroed(); - let res = SystemTimeToFileTime(&expiration_date, - &mut file_time); - if res != TRUE { - return Err(Error::last_os_error()); - } - let mut timestamp: u64 = file_time.dwLowDateTime as u64 | - (file_time.dwHighDateTime as u64) << 32; - // one day, timestamp unit is in 100 nanosecond intervals - timestamp += (1E9 as u64) / 100 * (60 * 60 * 24); - file_time.dwLowDateTime = timestamp as u32; - file_time.dwHighDateTime = (timestamp >> 32) as u32; - let res = FileTimeToSystemTime(&file_time, - &mut expiration_date); - if res != TRUE { - return Err(Error::last_os_error()); - } - - // create a self signed certificate - let cert_context = CertCreateSelfSignCertificate( - 0 as ULONG_PTR, - &mut subject_issuer, - 0, - &mut key_provider, - &mut sig_algorithm, - ptr::null_mut(), - &mut expiration_date, - ptr::null_mut()); - if cert_context.is_null() { - return Err(Error::last_os_error()); - } - - // TODO: this is.. a terrible hack. Right now `schannel` - // doesn't provide a public method to go from a raw - // cert context pointer to the `CertContext` structure it - // has, so we just fake it here with a transmute. This'll - // probably break at some point, but hopefully by then - // it'll have a method to do this! - struct MyCertContext(T); - impl Drop for MyCertContext { - fn drop(&mut self) {} - } - - let cert_context = MyCertContext(cert_context); - let cert_context: CertContext = mem::transmute(cert_context); - - cert_context.set_friendly_name(FRIENDLY_NAME)?; - - // install the certificate to the machine's local store - io::stdout().write_all(br#" - -The tokio-tls test suite is about to add a certificate to your set of root -and trusted certificates. This certificate should be for the domain "localhost" -with the description related to "tokio-tls". This certificate is only valid -for one day and will be automatically deleted if you re-run the tokio-tls -test suite later. - - "#).unwrap(); - local_root_store().add_cert(&cert_context, - CertAdd::ReplaceExisting)?; - Ok(cert_context) - } - } - } -} - -const AMT: usize = 128 * 1024; - -async fn copy_data(mut w: W) -> Result { - let mut data = vec![9; AMT as usize]; - let mut amt = 0; - while !data.is_empty() { - let written = w.write(&data).await?; - if written <= data.len() { - amt += written; - data.resize(data.len() - written, 0); - } else { - w.write_all(&data).await?; - amt += data.len(); - break; - } - - println!("remaining: {}", data.len()); - } - Ok(amt) -} - -#[tokio::test] -async fn client_to_server() { - drop(env_logger::try_init()); - - // Create a server listening on a port, then figure out what that port is - let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); - let addr = t!(srv.local_addr()); - - let (server_cx, client_cx) = contexts(); - - // Create a future to accept one socket, connect the ssl stream, and then - // read all the data from it. - let server = async move { - let mut incoming = srv.incoming(); - let socket = t!(incoming.next().await.unwrap()); - let mut socket = t!(server_cx.accept(socket).await); - let mut data = Vec::new(); - t!(socket.read_to_end(&mut data).await); - data - }; - - // Create a future to connect to our server, connect the ssl stream, and - // then write a bunch of data to it. - let client = async move { - let socket = t!(TcpStream::connect(&addr).await); - let socket = t!(client_cx.connect("localhost", socket).await); - copy_data(socket).await - }; - - // Finally, run everything! - let (data, _) = join!(server, client); - // assert_eq!(amt, AMT); - assert!(data == vec![9; AMT]); -} - -#[tokio::test] -async fn server_to_client() { - drop(env_logger::try_init()); - - // Create a server listening on a port, then figure out what that port is - let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); - let addr = t!(srv.local_addr()); - - let (server_cx, client_cx) = contexts(); - - let server = async move { - let mut incoming = srv.incoming(); - let socket = t!(incoming.next().await.unwrap()); - let socket = t!(server_cx.accept(socket).await); - copy_data(socket).await - }; - - let client = async move { - let socket = t!(TcpStream::connect(&addr).await); - let mut socket = t!(client_cx.connect("localhost", socket).await); - let mut data = Vec::new(); - t!(socket.read_to_end(&mut data).await); - data - }; - - // Finally, run everything! - let (_, data) = join!(server, client); - // assert_eq!(amt, AMT); - assert!(data == vec![9; AMT]); -} - -#[tokio::test] -async fn one_byte_at_a_time() { - const AMT: usize = 1024; - drop(env_logger::try_init()); - - let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); - let addr = t!(srv.local_addr()); - - let (server_cx, client_cx) = contexts(); - - let server = async move { - let mut incoming = srv.incoming(); - let socket = t!(incoming.next().await.unwrap()); - let mut socket = t!(server_cx.accept(socket).await); - let mut amt = 0; - for b in std::iter::repeat(9).take(AMT) { - let data = [b as u8]; - t!(socket.write_all(&data).await); - amt += 1; - } - amt - }; - - let client = async move { - let socket = t!(TcpStream::connect(&addr).await); - let mut socket = t!(client_cx.connect("localhost", socket).await); - let mut data = Vec::new(); - loop { - let mut buf = [0; 1]; - match socket.read_exact(&mut buf).await { - Ok(_) => data.extend_from_slice(&buf), - Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => break, - Err(err) => panic!(err), - } - } - data - }; - - let (amt, data) = join!(server, client); - assert_eq!(amt, AMT); - assert!(data == vec![9; AMT as usize]); -} diff --git a/tokio-util/src/codec/framed.rs b/tokio-util/src/codec/framed.rs index d2e7659eda2..36370da2694 100644 --- a/tokio-util/src/codec/framed.rs +++ b/tokio-util/src/codec/framed.rs @@ -1,10 +1,9 @@ use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; -use crate::codec::framed_read::{framed_read2, framed_read2_with_buffer, FramedRead2}; -use crate::codec::framed_write::{framed_write2, framed_write2_with_buffer, FramedWrite2}; +use crate::codec::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; use tokio::{ - io::{AsyncBufRead, AsyncRead, AsyncWrite}, + io::{AsyncRead, AsyncWrite}, stream::Stream, }; @@ -12,8 +11,7 @@ use bytes::BytesMut; use futures_sink::Sink; use pin_project_lite::pin_project; use std::fmt; -use std::io::{self, BufRead, Read, Write}; -use std::mem::MaybeUninit; +use std::io; use std::pin::Pin; use std::task::{Context, Poll}; @@ -30,37 +28,7 @@ pin_project! { /// [`Decoder::framed`]: crate::codec::Decoder::framed() pub struct Framed { #[pin] - inner: FramedRead2>>, - } -} - -pin_project! { - pub(crate) struct Fuse { - #[pin] - pub(crate) io: T, - pub(crate) codec: U, - } -} - -/// Abstracts over `FramedRead2` being either `FramedRead2>>` or -/// `FramedRead2>` and lets the io and codec parts be extracted in either case. -pub(crate) trait ProjectFuse { - type Io; - type Codec; - - fn project(self: Pin<&mut Self>) -> Fuse, &mut Self::Codec>; -} - -impl ProjectFuse for Fuse { - type Io = T; - type Codec = U; - - fn project(self: Pin<&mut Self>) -> Fuse, &mut Self::Codec> { - let self_ = self.project(); - Fuse { - io: self_.io, - codec: self_.codec, - } + inner: FramedImpl } } @@ -93,7 +61,11 @@ where /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split pub fn new(inner: T, codec: U) -> Framed { Framed { - inner: framed_read2(framed_write2(Fuse { io: inner, codec })), + inner: FramedImpl { + inner, + codec, + state: Default::default(), + }, } } @@ -123,10 +95,18 @@ where /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Framed { Framed { - inner: framed_read2_with_buffer( - framed_write2(Fuse { io: inner, codec }), - BytesMut::with_capacity(capacity), - ), + inner: FramedImpl { + inner, + codec, + state: RWFrames { + read: ReadFrame { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(capacity), + }, + write: WriteFrame::default(), + }, + }, } } } @@ -161,16 +141,14 @@ impl Framed { /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split pub fn from_parts(parts: FramedParts) -> Framed { Framed { - inner: framed_read2_with_buffer( - framed_write2_with_buffer( - Fuse { - io: parts.io, - codec: parts.codec, - }, - parts.write_buf, - ), - parts.read_buf, - ), + inner: FramedImpl { + inner: parts.io, + codec: parts.codec, + state: RWFrames { + read: parts.read_buf.into(), + write: parts.write_buf.into(), + }, + }, } } @@ -181,7 +159,7 @@ impl Framed { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_ref(&self) -> &T { - &self.inner.get_ref().get_ref().io + &self.inner.inner } /// Returns a mutable reference to the underlying I/O stream wrapped by @@ -191,7 +169,7 @@ impl Framed { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_mut(&mut self) -> &mut T { - &mut self.inner.get_mut().get_mut().io + &mut self.inner.inner } /// Returns a reference to the underlying codec wrapped by @@ -200,7 +178,7 @@ impl Framed { /// Note that care should be taken to not tamper with the underlying codec /// as it may corrupt the stream of frames otherwise being worked with. pub fn codec(&self) -> &U { - &self.inner.get_ref().get_ref().codec + &self.inner.codec } /// Returns a mutable reference to the underlying codec wrapped by @@ -209,12 +187,17 @@ impl Framed { /// Note that care should be taken to not tamper with the underlying codec /// as it may corrupt the stream of frames otherwise being worked with. pub fn codec_mut(&mut self) -> &mut U { - &mut self.inner.get_mut().get_mut().codec + &mut self.inner.codec } /// Returns a reference to the read buffer. pub fn read_buffer(&self) -> &BytesMut { - self.inner.buffer() + &self.inner.state.read.buffer + } + + /// Returns a mutable reference to the read buffer. + pub fn read_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.read.buffer } /// Consumes the `Framed`, returning its underlying I/O stream. @@ -223,7 +206,7 @@ impl Framed { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_inner(self) -> T { - self.inner.into_inner().into_inner().io + self.inner.inner } /// Consumes the `Framed`, returning its underlying I/O stream, the buffer @@ -233,19 +216,17 @@ impl Framed { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_parts(self) -> FramedParts { - let (inner, read_buf) = self.inner.into_parts(); - let (inner, write_buf) = inner.into_parts(); - FramedParts { - io: inner.io, - codec: inner.codec, - read_buf, - write_buf, + io: self.inner.inner, + codec: self.inner.codec, + read_buf: self.inner.state.read.buffer, + write_buf: self.inner.state.write.buffer, _priv: (), } } } +// This impl just defers to the underlying FramedImpl impl Stream for Framed where T: AsyncRead, @@ -258,6 +239,7 @@ where } } +// This impl just defers to the underlying FramedImpl impl Sink for Framed where T: AsyncWrite, @@ -267,19 +249,19 @@ where type Error = U::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.get_pin_mut().poll_ready(cx) + self.project().inner.poll_ready(cx) } fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { - self.project().inner.get_pin_mut().start_send(item) + self.project().inner.start_send(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.get_pin_mut().poll_flush(cx) + self.project().inner.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.get_pin_mut().poll_close(cx) + self.project().inner.poll_close(cx) } } @@ -290,109 +272,19 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Framed") - .field("io", &self.inner.get_ref().get_ref().io) - .field("codec", &self.inner.get_ref().get_ref().codec) + .field("io", self.get_ref()) + .field("codec", self.codec()) .finish() } } -// ===== impl Fuse ===== - -impl Read for Fuse { - fn read(&mut self, dst: &mut [u8]) -> io::Result { - self.io.read(dst) - } -} - -impl BufRead for Fuse { - fn fill_buf(&mut self) -> io::Result<&[u8]> { - self.io.fill_buf() - } - - fn consume(&mut self, amt: usize) { - self.io.consume(amt) - } -} - -impl AsyncRead for Fuse { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - self.io.prepare_uninitialized_buffer(buf) - } - - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.project().io.poll_read(cx, buf) - } -} - -impl AsyncBufRead for Fuse { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().io.poll_fill_buf(cx) - } - - fn consume(self: Pin<&mut Self>, amt: usize) { - self.project().io.consume(amt) - } -} - -impl Write for Fuse { - fn write(&mut self, src: &[u8]) -> io::Result { - self.io.write(src) - } - - fn flush(&mut self) -> io::Result<()> { - self.io.flush() - } -} - -impl AsyncWrite for Fuse { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().io.poll_write(cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().io.poll_flush(cx) - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().io.poll_shutdown(cx) - } -} - -impl Decoder for Fuse { - type Item = U::Item; - type Error = U::Error; - - fn decode(&mut self, buffer: &mut BytesMut) -> Result, Self::Error> { - self.codec.decode(buffer) - } - - fn decode_eof(&mut self, buffer: &mut BytesMut) -> Result, Self::Error> { - self.codec.decode_eof(buffer) - } -} - -impl> Encoder for Fuse { - type Error = U::Error; - - fn encode(&mut self, item: I, dst: &mut BytesMut) -> Result<(), Self::Error> { - self.codec.encode(item, dst) - } -} - /// `FramedParts` contains an export of the data of a Framed transport. /// It can be used to construct a new [`Framed`] with a different codec. /// It contains all current buffers and the inner transport. /// /// [`Framed`]: crate::codec::Framed #[derive(Debug)] +#[allow(clippy::manual_non_exhaustive)] pub struct FramedParts { /// The inner transport used to read bytes to and write bytes to pub io: T, diff --git a/tokio-util/src/codec/framed_impl.rs b/tokio-util/src/codec/framed_impl.rs new file mode 100644 index 00000000000..eb2e0d38c6d --- /dev/null +++ b/tokio-util/src/codec/framed_impl.rs @@ -0,0 +1,225 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use tokio::{ + io::{AsyncRead, AsyncWrite}, + stream::Stream, +}; + +use bytes::{Buf, BytesMut}; +use futures_core::ready; +use futures_sink::Sink; +use log::trace; +use pin_project_lite::pin_project; +use std::borrow::{Borrow, BorrowMut}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + #[derive(Debug)] + pub(crate) struct FramedImpl { + #[pin] + pub(crate) inner: T, + pub(crate) state: State, + pub(crate) codec: U, + } +} + +const INITIAL_CAPACITY: usize = 8 * 1024; +const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; + +pub(crate) struct ReadFrame { + pub(crate) eof: bool, + pub(crate) is_readable: bool, + pub(crate) buffer: BytesMut, +} + +pub(crate) struct WriteFrame { + pub(crate) buffer: BytesMut, +} + +#[derive(Default)] +pub(crate) struct RWFrames { + pub(crate) read: ReadFrame, + pub(crate) write: WriteFrame, +} + +impl Default for ReadFrame { + fn default() -> Self { + Self { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + } + } +} + +impl Default for WriteFrame { + fn default() -> Self { + Self { + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + } + } +} + +impl From for ReadFrame { + fn from(mut buffer: BytesMut) -> Self { + let size = buffer.capacity(); + if size < INITIAL_CAPACITY { + buffer.reserve(INITIAL_CAPACITY - size); + } + + Self { + buffer, + is_readable: size > 0, + eof: false, + } + } +} + +impl From for WriteFrame { + fn from(mut buffer: BytesMut) -> Self { + let size = buffer.capacity(); + if size < INITIAL_CAPACITY { + buffer.reserve(INITIAL_CAPACITY - size); + } + + Self { buffer } + } +} + +impl Borrow for RWFrames { + fn borrow(&self) -> &ReadFrame { + &self.read + } +} +impl BorrowMut for RWFrames { + fn borrow_mut(&mut self) -> &mut ReadFrame { + &mut self.read + } +} +impl Borrow for RWFrames { + fn borrow(&self) -> &WriteFrame { + &self.write + } +} +impl BorrowMut for RWFrames { + fn borrow_mut(&mut self) -> &mut WriteFrame { + &mut self.write + } +} +impl Stream for FramedImpl +where + T: AsyncRead, + U: Decoder, + R: BorrowMut, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut pinned = self.project(); + let state: &mut ReadFrame = pinned.state.borrow_mut(); + loop { + // Repeatedly call `decode` or `decode_eof` as long as it is + // "readable". Readable is defined as not having returned `None`. If + // the upstream has returned EOF, and the decoder is no longer + // readable, it can be assumed that the decoder will never become + // readable again, at which point the stream is terminated. + if state.is_readable { + if state.eof { + let frame = pinned.codec.decode_eof(&mut state.buffer)?; + return Poll::Ready(frame.map(Ok)); + } + + trace!("attempting to decode a frame"); + + if let Some(frame) = pinned.codec.decode(&mut state.buffer)? { + trace!("frame decoded from buffer"); + return Poll::Ready(Some(Ok(frame))); + } + + state.is_readable = false; + } + + assert!(!state.eof); + + // Otherwise, try to read more data and try again. Make sure we've + // got room for at least one byte to read to ensure that we don't + // get a spurious 0 that looks like EOF + state.buffer.reserve(1); + let bytect = match pinned.inner.as_mut().poll_read_buf(cx, &mut state.buffer)? { + Poll::Ready(ct) => ct, + Poll::Pending => return Poll::Pending, + }; + if bytect == 0 { + state.eof = true; + } + + state.is_readable = true; + } + } +} + +impl Sink for FramedImpl +where + T: AsyncWrite, + U: Encoder, + U::Error: From, + W: BorrowMut, +{ + type Error = U::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.state.borrow().buffer.len() >= BACKPRESSURE_BOUNDARY { + self.as_mut().poll_flush(cx) + } else { + Poll::Ready(Ok(())) + } + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + let pinned = self.project(); + pinned + .codec + .encode(item, &mut pinned.state.borrow_mut().buffer)?; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + trace!("flushing framed transport"); + let mut pinned = self.project(); + + while !pinned.state.borrow_mut().buffer.is_empty() { + let WriteFrame { buffer } = pinned.state.borrow_mut(); + trace!("writing; remaining={}", buffer.len()); + + let buf = &buffer; + let n = ready!(pinned.inner.as_mut().poll_write(cx, &buf))?; + + if n == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to \ + write frame to transport", + ) + .into())); + } + + pinned.state.borrow_mut().buffer.advance(n); + } + + // Try flushing the underlying IO + ready!(pinned.inner.poll_flush(cx))?; + + trace!("framed transport flushed"); + Poll::Ready(Ok(())) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + ready!(self.project().inner.poll_shutdown(cx))?; + + Poll::Ready(Ok(())) + } +} diff --git a/tokio-util/src/codec/framed_read.rs b/tokio-util/src/codec/framed_read.rs index e7798c327ee..a6844b73557 100644 --- a/tokio-util/src/codec/framed_read.rs +++ b/tokio-util/src/codec/framed_read.rs @@ -1,11 +1,10 @@ -use crate::codec::framed::{Fuse, ProjectFuse}; +use crate::codec::framed_impl::{FramedImpl, ReadFrame}; use crate::codec::Decoder; use tokio::{io::AsyncRead, stream::Stream}; use bytes::BytesMut; use futures_sink::Sink; -use log::trace; use pin_project_lite::pin_project; use std::fmt; use std::pin::Pin; @@ -18,22 +17,10 @@ pin_project! { /// [`AsyncRead`]: tokio::io::AsyncRead pub struct FramedRead { #[pin] - inner: FramedRead2>, + inner: FramedImpl, } } -pin_project! { - pub(crate) struct FramedRead2 { - #[pin] - inner: T, - eof: bool, - is_readable: bool, - buffer: BytesMut, - } -} - -const INITIAL_CAPACITY: usize = 8 * 1024; - // ===== impl FramedRead ===== impl FramedRead @@ -44,10 +31,11 @@ where /// Creates a new `FramedRead` with the given `decoder`. pub fn new(inner: T, decoder: D) -> FramedRead { FramedRead { - inner: framed_read2(Fuse { - io: inner, + inner: FramedImpl { + inner, codec: decoder, - }), + state: Default::default(), + }, } } @@ -55,13 +43,15 @@ where /// initial size. pub fn with_capacity(inner: T, decoder: D, capacity: usize) -> FramedRead { FramedRead { - inner: framed_read2_with_buffer( - Fuse { - io: inner, - codec: decoder, + inner: FramedImpl { + inner, + codec: decoder, + state: ReadFrame { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(capacity), }, - BytesMut::with_capacity(capacity), - ), + }, } } } @@ -74,7 +64,7 @@ impl FramedRead { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_ref(&self) -> &T { - &self.inner.inner.io + &self.inner.inner } /// Returns a mutable reference to the underlying I/O stream wrapped by @@ -84,7 +74,7 @@ impl FramedRead { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_mut(&mut self) -> &mut T { - &mut self.inner.inner.io + &mut self.inner.inner } /// Consumes the `FramedRead`, returning its underlying I/O stream. @@ -93,25 +83,26 @@ impl FramedRead { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_inner(self) -> T { - self.inner.inner.io + self.inner.inner } /// Returns a reference to the underlying decoder. pub fn decoder(&self) -> &D { - &self.inner.inner.codec + &self.inner.codec } /// Returns a mutable reference to the underlying decoder. pub fn decoder_mut(&mut self) -> &mut D { - &mut self.inner.inner.codec + &mut self.inner.codec } /// Returns a reference to the read buffer. pub fn read_buffer(&self) -> &BytesMut { - &self.inner.buffer + &self.inner.state.buffer } } +// This impl just defers to the underlying FramedImpl impl Stream for FramedRead where T: AsyncRead, @@ -132,43 +123,19 @@ where type Error = T::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .inner - .project() - .inner - .project() - .io - .poll_ready(cx) + self.project().inner.project().inner.poll_ready(cx) } fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { - self.project() - .inner - .project() - .inner - .project() - .io - .start_send(item) + self.project().inner.project().inner.start_send(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .inner - .project() - .inner - .project() - .io - .poll_flush(cx) + self.project().inner.project().inner.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .inner - .project() - .inner - .project() - .io - .poll_close(cx) + self.project().inner.project().inner.poll_close(cx) } } @@ -179,126 +146,11 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FramedRead") - .field("inner", &self.inner.inner.io) - .field("decoder", &self.inner.inner.codec) - .field("eof", &self.inner.eof) - .field("is_readable", &self.inner.is_readable) - .field("buffer", &self.inner.buffer) + .field("inner", &self.get_ref()) + .field("decoder", &self.decoder()) + .field("eof", &self.inner.state.eof) + .field("is_readable", &self.inner.state.is_readable) + .field("buffer", &self.read_buffer()) .finish() } } - -// ===== impl FramedRead2 ===== - -pub(crate) fn framed_read2(inner: T) -> FramedRead2 { - FramedRead2 { - inner, - eof: false, - is_readable: false, - buffer: BytesMut::with_capacity(INITIAL_CAPACITY), - } -} - -pub(crate) fn framed_read2_with_buffer(inner: T, mut buf: BytesMut) -> FramedRead2 { - if buf.capacity() < INITIAL_CAPACITY { - let bytes_to_reserve = INITIAL_CAPACITY - buf.capacity(); - buf.reserve(bytes_to_reserve); - } - FramedRead2 { - inner, - eof: false, - is_readable: !buf.is_empty(), - buffer: buf, - } -} - -impl FramedRead2 { - pub(crate) fn get_ref(&self) -> &T { - &self.inner - } - - pub(crate) fn into_inner(self) -> T { - self.inner - } - - pub(crate) fn into_parts(self) -> (T, BytesMut) { - (self.inner, self.buffer) - } - - pub(crate) fn get_mut(&mut self) -> &mut T { - &mut self.inner - } - - pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { - self.project().inner - } - - pub(crate) fn buffer(&self) -> &BytesMut { - &self.buffer - } -} - -impl Stream for FramedRead2 -where - T: ProjectFuse + AsyncRead, - T::Codec: Decoder, -{ - type Item = Result<::Item, ::Error>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut pinned = self.project(); - loop { - // Repeatedly call `decode` or `decode_eof` as long as it is - // "readable". Readable is defined as not having returned `None`. If - // the upstream has returned EOF, and the decoder is no longer - // readable, it can be assumed that the decoder will never become - // readable again, at which point the stream is terminated. - if *pinned.is_readable { - if *pinned.eof { - let frame = pinned - .inner - .as_mut() - .project() - .codec - .decode_eof(&mut pinned.buffer)?; - return Poll::Ready(frame.map(Ok)); - } - - trace!("attempting to decode a frame"); - - if let Some(frame) = pinned - .inner - .as_mut() - .project() - .codec - .decode(&mut pinned.buffer)? - { - trace!("frame decoded from buffer"); - return Poll::Ready(Some(Ok(frame))); - } - - *pinned.is_readable = false; - } - - assert!(!*pinned.eof); - - // Otherwise, try to read more data and try again. Make sure we've - // got room for at least one byte to read to ensure that we don't - // get a spurious 0 that looks like EOF - pinned.buffer.reserve(1); - let bytect = match pinned - .inner - .as_mut() - .poll_read_buf(cx, &mut pinned.buffer)? - { - Poll::Ready(ct) => ct, - Poll::Pending => return Poll::Pending, - }; - if bytect == 0 { - *pinned.eof = true; - } - - *pinned.is_readable = true; - } - } -} diff --git a/tokio-util/src/codec/framed_write.rs b/tokio-util/src/codec/framed_write.rs index c0049b2d04a..834eb6ed0f8 100644 --- a/tokio-util/src/codec/framed_write.rs +++ b/tokio-util/src/codec/framed_write.rs @@ -1,20 +1,12 @@ -use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; -use crate::codec::framed::{Fuse, ProjectFuse}; +use crate::codec::framed_impl::{FramedImpl, WriteFrame}; -use tokio::{ - io::{AsyncBufRead, AsyncRead, AsyncWrite}, - stream::Stream, -}; +use tokio::{io::AsyncWrite, stream::Stream}; -use bytes::{Buf, BytesMut}; -use futures_core::ready; use futures_sink::Sink; -use log::trace; use pin_project_lite::pin_project; use std::fmt; -use std::io::{self, BufRead, Read}; -use std::mem::MaybeUninit; +use std::io; use std::pin::Pin; use std::task::{Context, Poll}; @@ -24,21 +16,10 @@ pin_project! { /// [`Sink`]: futures_sink::Sink pub struct FramedWrite { #[pin] - inner: FramedWrite2>, + inner: FramedImpl, } } -pin_project! { - pub(crate) struct FramedWrite2 { - #[pin] - inner: T, - buffer: BytesMut, - } -} - -const INITIAL_CAPACITY: usize = 8 * 1024; -const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; - impl FramedWrite where T: AsyncWrite, @@ -46,10 +27,11 @@ where /// Creates a new `FramedWrite` with the given `encoder`. pub fn new(inner: T, encoder: E) -> FramedWrite { FramedWrite { - inner: framed_write2(Fuse { - io: inner, + inner: FramedImpl { + inner, codec: encoder, - }), + state: WriteFrame::default(), + }, } } } @@ -62,7 +44,7 @@ impl FramedWrite { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_ref(&self) -> &T { - &self.inner.inner.io + &self.inner.inner } /// Returns a mutable reference to the underlying I/O stream wrapped by @@ -72,7 +54,7 @@ impl FramedWrite { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_mut(&mut self) -> &mut T { - &mut self.inner.inner.io + &mut self.inner.inner } /// Consumes the `FramedWrite`, returning its underlying I/O stream. @@ -81,21 +63,21 @@ impl FramedWrite { /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_inner(self) -> T { - self.inner.inner.io + self.inner.inner } - /// Returns a reference to the underlying decoder. + /// Returns a reference to the underlying encoder. pub fn encoder(&self) -> &E { - &self.inner.inner.codec + &self.inner.codec } - /// Returns a mutable reference to the underlying decoder. + /// Returns a mutable reference to the underlying encoder. pub fn encoder_mut(&mut self) -> &mut E { - &mut self.inner.inner.codec + &mut self.inner.codec } } -// This impl just defers to the underlying FramedWrite2 +// This impl just defers to the underlying FramedImpl impl Sink for FramedWrite where T: AsyncWrite, @@ -121,6 +103,7 @@ where } } +// This impl just defers to the underlying T: Stream impl Stream for FramedWrite where T: Stream, @@ -128,13 +111,7 @@ where type Item = T::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .inner - .project() - .inner - .project() - .io - .poll_next(cx) + self.project().inner.project().inner.poll_next(cx) } } @@ -145,180 +122,9 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FramedWrite") - .field("inner", &self.inner.get_ref().io) - .field("encoder", &self.inner.get_ref().codec) - .field("buffer", &self.inner.buffer) + .field("inner", &self.get_ref()) + .field("encoder", &self.encoder()) + .field("buffer", &self.inner.state.buffer) .finish() } } - -// ===== impl FramedWrite2 ===== - -pub(crate) fn framed_write2(inner: T) -> FramedWrite2 { - FramedWrite2 { - inner, - buffer: BytesMut::with_capacity(INITIAL_CAPACITY), - } -} - -pub(crate) fn framed_write2_with_buffer(inner: T, mut buf: BytesMut) -> FramedWrite2 { - if buf.capacity() < INITIAL_CAPACITY { - let bytes_to_reserve = INITIAL_CAPACITY - buf.capacity(); - buf.reserve(bytes_to_reserve); - } - FramedWrite2 { inner, buffer: buf } -} - -impl FramedWrite2 { - pub(crate) fn get_ref(&self) -> &T { - &self.inner - } - - pub(crate) fn into_inner(self) -> T { - self.inner - } - - pub(crate) fn into_parts(self) -> (T, BytesMut) { - (self.inner, self.buffer) - } - - pub(crate) fn get_mut(&mut self) -> &mut T { - &mut self.inner - } -} - -impl Sink for FramedWrite2 -where - T: ProjectFuse + AsyncWrite, - T::Codec: Encoder, -{ - type Error = >::Error; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // If the buffer is already over 8KiB, then attempt to flush it. If after flushing it's - // *still* over 8KiB, then apply backpressure (reject the send). - if self.buffer.len() >= BACKPRESSURE_BOUNDARY { - match self.as_mut().poll_flush(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(())) => (), - }; - - if self.buffer.len() >= BACKPRESSURE_BOUNDARY { - return Poll::Pending; - } - } - Poll::Ready(Ok(())) - } - - fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { - let mut pinned = self.project(); - pinned - .inner - .project() - .codec - .encode(item, &mut pinned.buffer)?; - Ok(()) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - trace!("flushing framed transport"); - let mut pinned = self.project(); - - while !pinned.buffer.is_empty() { - trace!("writing; remaining={}", pinned.buffer.len()); - - let buf = &pinned.buffer; - let n = ready!(pinned.inner.as_mut().poll_write(cx, &buf))?; - - if n == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to \ - write frame to transport", - ) - .into())); - } - - pinned.buffer.advance(n); - } - - // Try flushing the underlying IO - ready!(pinned.inner.poll_flush(cx))?; - - trace!("framed transport flushed"); - Poll::Ready(Ok(())) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().poll_flush(cx))?; - ready!(self.project().inner.poll_shutdown(cx))?; - - Poll::Ready(Ok(())) - } -} - -impl Decoder for FramedWrite2 { - type Item = T::Item; - type Error = T::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, T::Error> { - self.inner.decode(src) - } - - fn decode_eof(&mut self, src: &mut BytesMut) -> Result, T::Error> { - self.inner.decode_eof(src) - } -} - -impl Read for FramedWrite2 { - fn read(&mut self, dst: &mut [u8]) -> io::Result { - self.inner.read(dst) - } -} - -impl BufRead for FramedWrite2 { - fn fill_buf(&mut self) -> io::Result<&[u8]> { - self.inner.fill_buf() - } - - fn consume(&mut self, amt: usize) { - self.inner.consume(amt) - } -} - -impl AsyncRead for FramedWrite2 { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.project().inner.poll_read(cx, buf) - } -} - -impl AsyncBufRead for FramedWrite2 { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_fill_buf(cx) - } - - fn consume(self: Pin<&mut Self>, amt: usize) { - self.project().inner.consume(amt) - } -} - -impl ProjectFuse for FramedWrite2 -where - T: ProjectFuse, -{ - type Io = T::Io; - type Codec = T::Codec; - - fn project(self: Pin<&mut Self>) -> Fuse, &mut Self::Codec> { - self.project().inner.project() - } -} diff --git a/tokio-util/src/codec/length_delimited.rs b/tokio-util/src/codec/length_delimited.rs index 90684d73428..2426b771ae5 100644 --- a/tokio-util/src/codec/length_delimited.rs +++ b/tokio-util/src/codec/length_delimited.rs @@ -370,7 +370,7 @@ //! [`AsyncRead`]: trait@tokio::io::AsyncRead //! [`AsyncWrite`]: trait@tokio::io::AsyncWrite //! [`Encoder`]: trait@Encoder -//! [`BytesMut`]: https://docs.rs/bytes/0.4/bytes/struct.BytesMut.html +//! [`BytesMut`]: bytes::BytesMut use crate::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite}; diff --git a/tokio-util/src/codec/mod.rs b/tokio-util/src/codec/mod.rs index ec76a6419f0..e89aa7c9ac8 100644 --- a/tokio-util/src/codec/mod.rs +++ b/tokio-util/src/codec/mod.rs @@ -1,8 +1,13 @@ -//! Utilities for encoding and decoding frames. +//! Adaptors from AsyncRead/AsyncWrite to Stream/Sink //! -//! Contains adapters to go from streams of bytes, [`AsyncRead`] and -//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`]. -//! Framed streams are also known as transports. +//! Raw I/O objects work with byte sequences, but higher-level code +//! usually wants to batch these into meaningful chunks, called +//! "frames". +//! +//! This module contains adapters to go from streams of bytes, +//! [`AsyncRead`] and [`AsyncWrite`], to framed streams implementing +//! [`Sink`] and [`Stream`]. Framed streams are also known as +//! transports. //! //! [`AsyncRead`]: tokio::io::AsyncRead //! [`AsyncWrite`]: tokio::io::AsyncWrite @@ -18,6 +23,10 @@ pub use self::decoder::Decoder; mod encoder; pub use self::encoder::Encoder; +mod framed_impl; +#[allow(unused_imports)] +pub(crate) use self::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; + mod framed; pub use self::framed::{Framed, FramedParts}; diff --git a/tokio-util/src/udp/frame.rs b/tokio-util/src/udp/frame.rs index 5b098bd49b2..560f35c9cfa 100644 --- a/tokio-util/src/udp/frame.rs +++ b/tokio-util/src/udp/frame.rs @@ -6,6 +6,7 @@ use bytes::{BufMut, BytesMut}; use futures_core::ready; use futures_sink::Sink; use std::io; +use std::mem::MaybeUninit; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -36,6 +37,8 @@ pub struct UdpFramed { wr: BytesMut, out_addr: SocketAddr, flushed: bool, + is_readable: bool, + current_addr: Option, } impl Stream for UdpFramed { @@ -46,27 +49,39 @@ impl Stream for UdpFramed { pin.rd.reserve(INITIAL_RD_CAPACITY); - let (_n, addr) = unsafe { - // Read into the buffer without having to initialize the memory. - // - // safety: we know tokio::net::UdpSocket never reads from the memory - // during a recv - let res = { - let bytes = &mut *(pin.rd.bytes_mut() as *mut _ as *mut [u8]); - ready!(Pin::new(&mut pin.socket).poll_recv_from(cx, bytes)) - }; + loop { + // Are there are still bytes left in the read buffer to decode? + if pin.is_readable { + if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { + let current_addr = pin + .current_addr + .expect("will always be set before this line is called"); - let (n, addr) = res?; - pin.rd.advance_mut(n); - (n, addr) - }; + return Poll::Ready(Some(Ok((frame, current_addr)))); + } + + // if this line has been reached then decode has returned `None`. + pin.is_readable = false; + pin.rd.clear(); + } - let frame_res = pin.codec.decode(&mut pin.rd); - pin.rd.clear(); - let frame = frame_res?; - let result = frame.map(|frame| Ok((frame, addr))); // frame -> (frame, addr) + // We're out of data. Try and fetch more data to decode + let addr = unsafe { + // Convert `&mut [MaybeUnit]` to `&mut [u8]` because we will be + // writing to it via `poll_recv_from` and therefore initializing the memory. + let buf: &mut [u8] = + &mut *(pin.rd.bytes_mut() as *mut [MaybeUninit] as *mut [u8]); - Poll::Ready(result) + let res = ready!(Pin::new(&mut pin.socket).poll_recv_from(cx, buf)); + + let (n, addr) = res?; + pin.rd.advance_mut(n); + addr + }; + + pin.current_addr = Some(addr); + pin.is_readable = true; + } } } @@ -148,6 +163,8 @@ impl UdpFramed { rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), flushed: true, + is_readable: false, + current_addr: None, } } diff --git a/tokio-util/tests/udp.rs b/tokio-util/tests/udp.rs index 0ba0574281c..d0320beb185 100644 --- a/tokio-util/tests/udp.rs +++ b/tokio-util/tests/udp.rs @@ -1,5 +1,5 @@ use tokio::{net::UdpSocket, stream::StreamExt}; -use tokio_util::codec::{Decoder, Encoder}; +use tokio_util::codec::{Decoder, Encoder, LinesCodec}; use tokio_util::udp::UdpFramed; use bytes::{BufMut, BytesMut}; @@ -10,7 +10,7 @@ use std::io; #[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))] #[tokio::test] -async fn send_framed() -> std::io::Result<()> { +async fn send_framed_byte_codec() -> std::io::Result<()> { let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?; let mut b_soc = UdpSocket::bind("127.0.0.1:0").await?; @@ -77,3 +77,24 @@ impl Encoder<&[u8]> for ByteCodec { Ok(()) } } + +#[tokio::test] +async fn send_framed_lines_codec() -> std::io::Result<()> { + let a_soc = UdpSocket::bind("127.0.0.1:0").await?; + let b_soc = UdpSocket::bind("127.0.0.1:0").await?; + + let a_addr = a_soc.local_addr()?; + let b_addr = b_soc.local_addr()?; + + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, LinesCodec::new()); + + let msg = b"1\r\n2\r\n3\r\n".to_vec(); + a.send((&msg, b_addr)).await?; + + assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr)); + + Ok(()) +} diff --git a/tokio/CHANGELOG.md b/tokio/CHANGELOG.md index 26d052c5682..9d9b612d1f8 100644 --- a/tokio/CHANGELOG.md +++ b/tokio/CHANGELOG.md @@ -1,3 +1,52 @@ +# 0.2.22 (July 21, 2020) + +### Fixes +- docs: misc improvements (#2572, #2658, #2663, #2656, #2647, #2630, #2487, #2621, + #2624, #2600, #2623, #2622, #2577, #2569, #2589, #2575, #2540, #2564, #2567, + #2520, #2521, #2493) +- rt: allow calls to `block_on` inside calls to `block_in_place` that are + themselves inside `block_on` (#2645) +- net: fix non-portable behavior when dropping `TcpStream` `OwnedWriteHalf` (#2597) +- io: improve stack usage by allocating large buffers on directly on the heap + (#2634) +- io: fix unsound pin projection in `AsyncReadExt::read_buf` and + `AsyncWriteExt::write_buf` (#2612) +- io: fix unnecessary zeroing for `AsyncRead` implementors (#2525) +- io: Fix `BufReader` not correctly forwarding `poll_write_buf` (#2654) +- io: fix panic in `AsyncReadExt::read_line` (#2541) + +### Changes +- coop: returning `Poll::Pending` no longer decrements the task budget (#2549) + +### Added +- io: little-endian variants of `AsyncReadExt` and `AsyncWriteExt` methods + (#1915) +- task: add [`tracing`] instrumentation to spawned tasks (#2655) +- sync: allow unsized types in `Mutex` and `RwLock` (via `default` constructors) + (#2615) +- net: add `ToSocketAddrs` implementation for `&[SocketAddr]` (#2604) +- fs: add `OpenOptionsExt` for `OpenOptions` (#2515) +- fs: add `DirBuilder` (#2524) + +[`tracing`]: https://crates.io/crates/tracing + +# 0.2.21 (May 13, 2020) + +### Fixes + +- macros: disambiguate built-in `#[test]` attribute in macro expansion (#2503) +- rt: `LocalSet` and task budgeting (#2462). +- rt: task budgeting with `block_in_place` (#2502). +- sync: release `broadcast` channel memory without sending a value (#2509). +- time: notify when resetting a `Delay` to a time in the past (#2290) + +### Added +- io: `get_mut`, `get_ref`, and `into_inner` to `Lines` (#2450). +- io: `mio::Ready` argument to `PollEvented` (#2419). +- os: illumos support (#2486). +- rt: `Handle::spawn_blocking` (#2501). +- sync: `OwnedMutexGuard` for `Arc>` (#2455). + # 0.2.20 (April 28, 2020) ### Fixes diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index 063d208efe0..dbcccdf662d 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -8,12 +8,12 @@ name = "tokio" # - README.md # - Update CHANGELOG.md. # - Create "v0.2.x" git tag. -version = "0.2.20" +version = "0.2.22" edition = "2018" authors = ["Tokio Contributors "] license = "MIT" readme = "README.md" -documentation = "https://docs.rs/tokio/0.2.20/tokio/" +documentation = "https://docs.rs/tokio/0.2.22/tokio/" repository = "https://github.com/tokio-rs/tokio" homepage = "https://tokio.rs" description = """ @@ -106,6 +106,7 @@ iovec = { version = "0.1.4", optional = true } num_cpus = { version = "1.8.0", optional = true } parking_lot = { version = "0.10.0", optional = true } # Not in full slab = { version = "0.4.1", optional = true } # Backs `DelayQueue` +tracing = { version = "0.1.16", default-features = false, features = ["std"], optional = true } # Not in full [target.'cfg(unix)'.dependencies] mio-uds = { version = "0.6.5", optional = true } @@ -123,13 +124,12 @@ optional = true [dev-dependencies] tokio-test = { version = "0.2.0", path = "../tokio-test" } futures = { version = "0.3.0", features = ["async-await"] } +futures-test = "0.3.0" proptest = "0.9.4" tempfile = "3.1.0" -# loom is currently not compiling on windows. -# See: https://github.com/Xudong-Huang/generator-rs/issues/19 -[target.'cfg(not(windows))'.dev-dependencies] -loom = { version = "0.3.1", features = ["futures", "checkpoint"] } +[target.'cfg(loom)'.dev-dependencies] +loom = { version = "0.3.4", features = ["futures", "checkpoint"] } [package.metadata.docs.rs] all-features = true diff --git a/tokio/README.md b/tokio/README.md index 080181f8975..da9078c5824 100644 --- a/tokio/README.md +++ b/tokio/README.md @@ -20,16 +20,17 @@ the Rust programming language. It is: [crates-badge]: https://img.shields.io/crates/v/tokio.svg [crates-url]: https://crates.io/crates/tokio [mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg -[mit-url]: LICENSE +[mit-url]: https://github.com/tokio-rs/tokio/blob/master/LICENSE [azure-badge]: https://dev.azure.com/tokio-rs/Tokio/_apis/build/status/tokio-rs.tokio?branchName=master [azure-url]: https://dev.azure.com/tokio-rs/Tokio/_build/latest?definitionId=1&branchName=master [discord-badge]: https://img.shields.io/discord/500028886025895936.svg?logo=discord&style=flat-square -[discord-url]: https://discord.gg/6yGkFeN +[discord-url]: https://discord.gg/tokio [Website](https://tokio.rs) | -[Guides](https://tokio.rs/docs/) | -[API Docs](https://docs.rs/tokio/0.2/tokio) | -[Chat](https://discord.gg/6yGkFeN) +[Guides](https://tokio.rs/tokio/tutorial) | +[API Docs](https://docs.rs/tokio/latest/tokio) | +[Roadmap](https://github.com/tokio-rs/tokio/blob/master/ROADMAP.md) | +[Chat](https://discord.gg/tokio) ## Overview @@ -45,20 +46,11 @@ level, it provides a few major components: These components provide the runtime components necessary for building an asynchronous application. -[net]: https://docs.rs/tokio/0.2/tokio/net/index.html -[scheduler]: https://docs.rs/tokio/0.2/tokio/runtime/index.html +[net]: https://docs.rs/tokio/latest/tokio/net/index.html +[scheduler]: https://docs.rs/tokio/latest/tokio/runtime/index.html ## Example -To get started, add the following to `Cargo.toml`. - -```toml -tokio = { version = "0.2", features = ["full"] } -``` - -Tokio requires components to be explicitly enabled using feature flags. As a -shorthand, the `full` feature enables all components. - A basic TCP echo server with Tokio: ```rust,no_run @@ -98,19 +90,27 @@ async fn main() -> Result<(), Box> { } ``` -More examples can be found [here](../examples). +More examples can be found [here][examples]. For a larger "real world" example, see the +[mini-redis] repository. + +[examples]: https://github.com/tokio-rs/tokio/tree/master/examples +[mini-redis]: https://github.com/tokio-rs/mini-redis/ + +To see a list of the available features flags that can be enabled, check our +[docs][feature-flag-docs]. ## Getting Help First, see if the answer to your question can be found in the [Guides] or the [API documentation]. If the answer is not there, there is an active community in the [Tokio Discord server][chat]. We would be happy to try to answer your -question. Last, if that doesn't work, try opening an [issue] with the question. +question. You can also ask your question on [the discussions page][discussions]. -[Guides]: https://tokio.rs/docs/ -[API documentation]: https://docs.rs/tokio/0.2 -[chat]: https://discord.gg/6yGkFeN -[issue]: https://github.com/tokio-rs/tokio/issues/new +[Guides]: https://tokio.rs/tokio/tutorial +[API documentation]: https://docs.rs/tokio/latest/tokio +[chat]: https://discord.gg/tokio +[discussions]: https://github.com/tokio-rs/tokio/discussions +[feature-flag-docs]: https://docs.rs/tokio/#feature-flags ## Contributing @@ -118,36 +118,54 @@ question. Last, if that doesn't work, try opening an [issue] with the question. you! We have a [contributing guide][guide] to help you get involved in the Tokio project. -[guide]: CONTRIBUTING.md +[guide]: https://github.com/tokio-rs/tokio/blob/master/CONTRIBUTING.md ## Related Projects In addition to the crates in this repository, the Tokio project also maintains several other libraries, including: +* [`hyper`]: A fast and correct HTTP/1.1 and HTTP/2 implementation for Rust. + +* [`tonic`]: A gRPC over HTTP/2 implementation focused on high performance, interoperability, and flexibility. + +* [`warp`]: A super-easy, composable, web server framework for warp speeds. + +* [`tower`]: A library of modular and reusable components for building robust networking clients and servers. + * [`tracing`] (formerly `tokio-trace`): A framework for application-level tracing and async-aware diagnostics. +* [`rdbc`]: A Rust database connectivity library for MySQL, Postgres and SQLite. + * [`mio`]: A low-level, cross-platform abstraction over OS I/O APIs that powers `tokio`. * [`bytes`]: Utilities for working with bytes, including efficient byte buffers. +* [`loom`]: A testing tool for concurrent Rust code + +[`warp`]: https://github.com/seanmonstar/warp +[`hyper`]: https://github.com/hyperium/hyper +[`tonic`]: https://github.com/hyperium/tonic +[`tower`]: https://github.com/tower-rs/tower +[`loom`]: https://github.com/tokio-rs/loom +[`rdbc`]: https://github.com/tokio-rs/rdbc [`tracing`]: https://github.com/tokio-rs/tracing [`mio`]: https://github.com/tokio-rs/mio [`bytes`]: https://github.com/tokio-rs/bytes ## Supported Rust Versions -Tokio is built against the latest stable, nightly, and beta Rust releases. The -minimum version supported is the stable release from three months before the -current stable release version. For example, if the latest stable Rust is 1.29, -the minimum version supported is 1.26. The current Tokio version is not -guaranteed to build on Rust versions earlier than the minimum supported version. +Tokio is built against the latest stable release. The minimum supported version is 1.39. +The current Tokio version is not guaranteed to build on Rust versions earlier than the +minimum supported version. ## License -This project is licensed under the [MIT license](LICENSE). +This project is licensed under the [MIT license]. + +[MIT license]: https://github.com/tokio-rs/tokio/blob/master/LICENSE ### Contribution diff --git a/tokio/src/coop.rs b/tokio/src/coop.rs index 606ba3a7395..27e969c59d4 100644 --- a/tokio/src/coop.rs +++ b/tokio/src/coop.rs @@ -1,11 +1,12 @@ //! Opt-in yield points for improved cooperative scheduling. //! -//! A single call to [`poll`] on a top-level task may potentially do a lot of work before it -//! returns `Poll::Pending`. If a task runs for a long period of time without yielding back to the -//! executor, it can starve other tasks waiting on that executor to execute them, or drive -//! underlying resources. Since Rust does not have a runtime, it is difficult to forcibly preempt a -//! long-running task. Instead, this module provides an opt-in mechanism for futures to collaborate -//! with the executor to avoid starvation. +//! A single call to [`poll`] on a top-level task may potentially do a lot of +//! work before it returns `Poll::Pending`. If a task runs for a long period of +//! time without yielding back to the executor, it can starve other tasks +//! waiting on that executor to execute them, or drive underlying resources. +//! Since Rust does not have a runtime, it is difficult to forcibly preempt a +//! long-running task. Instead, this module provides an opt-in mechanism for +//! futures to collaborate with the executor to avoid starvation. //! //! Consider a future like this one: //! @@ -16,9 +17,10 @@ //! } //! ``` //! -//! It may look harmless, but consider what happens under heavy load if the input stream is -//! _always_ ready. If we spawn `drop_all`, the task will never yield, and will starve other tasks -//! and resources on the same executor. With opt-in yield points, this problem is alleviated: +//! It may look harmless, but consider what happens under heavy load if the +//! input stream is _always_ ready. If we spawn `drop_all`, the task will never +//! yield, and will starve other tasks and resources on the same executor. With +//! opt-in yield points, this problem is alleviated: //! //! ```ignore //! # use tokio::stream::{Stream, StreamExt}; @@ -29,334 +31,195 @@ //! } //! ``` //! -//! The `proceed` future will coordinate with the executor to make sure that every so often control -//! is yielded back to the executor so it can run other tasks. +//! The `proceed` future will coordinate with the executor to make sure that +//! every so often control is yielded back to the executor so it can run other +//! tasks. //! //! # Placing yield points //! -//! Voluntary yield points should be placed _after_ at least some work has been done. If they are -//! not, a future sufficiently deep in the task hierarchy may end up _never_ getting to run because -//! of the number of yield points that inevitably appear before it is reached. In general, you will -//! want yield points to only appear in "leaf" futures -- those that do not themselves poll other -//! futures. By doing this, you avoid double-counting each iteration of the outer future against -//! the cooperating budget. +//! Voluntary yield points should be placed _after_ at least some work has been +//! done. If they are not, a future sufficiently deep in the task hierarchy may +//! end up _never_ getting to run because of the number of yield points that +//! inevitably appear before it is reached. In general, you will want yield +//! points to only appear in "leaf" futures -- those that do not themselves poll +//! other futures. By doing this, you avoid double-counting each iteration of +//! the outer future against the cooperating budget. //! -//! [`poll`]: https://doc.rust-lang.org/std/future/trait.Future.html#tymethod.poll +//! [`poll`]: method@std::future::Future::poll // NOTE: The doctests in this module are ignored since the whole module is (currently) private. use std::cell::Cell; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; - -/// Constant used to determine how much "work" a task is allowed to do without yielding. -/// -/// The value itself is chosen somewhat arbitrarily. It needs to be high enough to amortize wakeup -/// and scheduling costs, but low enough that we do not starve other tasks for too long. The value -/// also needs to be high enough that particularly deep tasks are able to do at least some useful -/// work at all. -/// -/// Note that as more yield points are added in the ecosystem, this value will probably also have -/// to be raised. -const BUDGET: usize = 128; - -/// Constant used to determine if budgeting has been disabled. -const UNCONSTRAINED: usize = usize::max_value(); thread_local! { - static HITS: Cell = Cell::new(UNCONSTRAINED); + static CURRENT: Cell = Cell::new(Budget::unconstrained()); } -/// Run the given closure with a cooperative task budget. -/// -/// Enabling budgeting when it is already enabled is a no-op. -#[inline(always)] -pub(crate) fn budget(f: F) -> R -where - F: FnOnce() -> R, -{ - HITS.with(move |hits| { - if hits.get() != UNCONSTRAINED { - // We are already being budgeted. - // - // Arguably this should be an error, but it can happen "correctly" - // such as with block_on + LocalSet, so we make it a no-op. - return f(); - } +/// Opaque type tracking the amount of "work" a task may still do before +/// yielding back to the scheduler. +#[derive(Debug, Copy, Clone)] +pub(crate) struct Budget(Option); - hits.set(BUDGET); - let _guard = ResetGuard { - hits, - prev: UNCONSTRAINED, - }; - f() - }) +impl Budget { + /// Budget assigned to a task on each poll. + /// + /// The value itself is chosen somewhat arbitrarily. It needs to be high + /// enough to amortize wakeup and scheduling costs, but low enough that we + /// do not starve other tasks for too long. The value also needs to be high + /// enough that particularly deep tasks are able to do at least some useful + /// work at all. + /// + /// Note that as more yield points are added in the ecosystem, this value + /// will probably also have to be raised. + const fn initial() -> Budget { + Budget(Some(128)) + } + + /// Returns an unconstrained budget. Operations will not be limited. + const fn unconstrained() -> Budget { + Budget(None) + } } cfg_rt_threaded! { - #[inline(always)] - pub(crate) fn has_budget_remaining() -> bool { - HITS.with(|hits| hits.get() > 0) + impl Budget { + fn has_remaining(self) -> bool { + self.0.map(|budget| budget > 0).unwrap_or(true) + } } } -cfg_blocking_impl! { - /// Forcibly remove the budgeting constraints early. - pub(crate) fn stop() { - HITS.with(|hits| { - hits.set(UNCONSTRAINED); - }); - } +/// Run the given closure with a cooperative task budget. When the function +/// returns, the budget is reset to the value prior to calling the function. +#[inline(always)] +pub(crate) fn budget(f: impl FnOnce() -> R) -> R { + with_budget(Budget::initial(), f) } -cfg_rt_core! { - cfg_rt_util! { - /// Run the given closure with a new task budget, resetting the previous - /// budget when the closure finishes. - /// - /// This is intended for internal use by `LocalSet` and (potentially) other - /// similar schedulers which are themselves futures, and need a fresh budget - /// for each of their children. - #[inline(always)] - pub(crate) fn reset(f: F) -> R - where - F: FnOnce() -> R, - { - HITS.with(move |hits| { - let prev = hits.get(); - hits.set(UNCONSTRAINED); - let _guard = ResetGuard { - hits, - prev, - }; - f() - }) - } +cfg_rt_threaded! { + /// Set the current task's budget + #[cfg(feature = "blocking")] + pub(crate) fn set(budget: Budget) { + CURRENT.with(|cell| cell.set(budget)) } } -/// Invoke `f` with a subset of the remaining budget. -/// -/// This is useful if you have sub-futures that you need to poll, but that you want to restrict -/// from using up your entire budget. For example, imagine the following future: -/// -/// ```rust -/// # use std::{future::Future, pin::Pin, task::{Context, Poll}}; -/// use futures::stream::FuturesUnordered; -/// struct MyFuture { -/// big: FuturesUnordered, -/// small: F2, -/// } -/// -/// use tokio::stream::Stream; -/// impl Future for MyFuture -/// where F1: Future, F2: Future -/// # , F1: Unpin, F2: Unpin -/// { -/// type Output = F2::Output; -/// -/// // fn poll(...) -/// # fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { -/// # let this = &mut *self; -/// let mut big = // something to pin self.big -/// # Pin::new(&mut this.big); -/// let small = // something to pin self.small -/// # Pin::new(&mut this.small); -/// -/// // see if any of the big futures have finished -/// while let Some(e) = futures::ready!(big.as_mut().poll_next(cx)) { -/// // do something with e -/// # let _ = e; -/// } -/// -/// // see if the small future has finished -/// small.poll(cx) -/// } -/// # } -/// ``` -/// -/// It could be that every time `poll` gets called, `big` ends up spending the entire budget, and -/// `small` never gets polled. That would be sad. If you want to stick up for the little future, -/// that's what `limit` is for. It lets you portion out a smaller part of the yield budget to a -/// particular segment of your code. In the code above, you would write -/// -/// ```rust,ignore -/// # use std::{future::Future, pin::Pin, task::{Context, Poll}}; -/// # use futures::stream::FuturesUnordered; -/// # struct MyFuture { -/// # big: FuturesUnordered, -/// # small: F2, -/// # } -/// # -/// # use tokio::stream::Stream; -/// # impl Future for MyFuture -/// # where F1: Future, F2: Future -/// # , F1: Unpin, F2: Unpin -/// # { -/// # type Output = F2::Output; -/// # fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { -/// # let this = &mut *self; -/// # let mut big = Pin::new(&mut this.big); -/// # let small = Pin::new(&mut this.small); -/// # -/// // see if any of the big futures have finished -/// while let Some(e) = futures::ready!(tokio::coop::limit(64, || big.as_mut().poll_next(cx))) { -/// # // do something with e -/// # let _ = e; -/// # } -/// # small.poll(cx) -/// # } -/// # } -/// ``` -/// -/// Now, even if `big` spends its entire budget, `small` will likely be left with some budget left -/// to also do useful work. In particular, if the remaining budget was `N` at the start of `poll`, -/// `small` will have at least a budget of `N - 64`. It may be more if `big` did not spend its -/// entire budget. -/// -/// Note that you cannot _increase_ your budget by calling `limit`. The budget provided to the code -/// inside the buget is the _minimum_ of the _current_ budget and the bound. -/// -#[allow(unreachable_pub, dead_code)] -pub fn limit(bound: usize, f: impl FnOnce() -> R) -> R { - HITS.with(|hits| { - let budget = hits.get(); - // with_bound cannot _increase_ the remaining budget - let bound = std::cmp::min(budget, bound); - // When f() exits, how much should we add to what is left? - let floor = budget.saturating_sub(bound); - // Make sure we restore the remaining budget even on panic - struct RestoreBudget<'a>(&'a Cell, usize); - impl<'a> Drop for RestoreBudget<'a> { - fn drop(&mut self) { - let left = self.0.get(); - self.0.set(self.1 + left); - } - } - // Time to restrict! - hits.set(bound); - let _restore = RestoreBudget(&hits, floor); - f() - }) -} +#[inline(always)] +fn with_budget(budget: Budget, f: impl FnOnce() -> R) -> R { + struct ResetGuard<'a> { + cell: &'a Cell, + prev: Budget, + } -/// Returns `Poll::Pending` if the current task has exceeded its budget and should yield. -#[allow(unreachable_pub, dead_code)] -#[inline] -pub fn poll_proceed(cx: &mut Context<'_>) -> Poll<()> { - HITS.with(|hits| { - let n = hits.get(); - if n == UNCONSTRAINED { - // opted out of budgeting - Poll::Ready(()) - } else if n == 0 { - cx.waker().wake_by_ref(); - Poll::Pending - } else { - hits.set(n.saturating_sub(1)); - Poll::Ready(()) + impl<'a> Drop for ResetGuard<'a> { + fn drop(&mut self) { + self.cell.set(self.prev); } + } + + CURRENT.with(move |cell| { + let prev = cell.get(); + + cell.set(budget); + + let _guard = ResetGuard { cell, prev }; + + f() }) } -/// Resolves immediately unless the current task has already exceeded its budget. -/// -/// This should be placed after at least some work has been done. Otherwise a future sufficiently -/// deep in the task hierarchy may end up never getting to run because of the number of yield -/// points that inevitably appear before it is even reached. For example: -/// -/// ```ignore -/// # use tokio::stream::{Stream, StreamExt}; -/// async fn drop_all(mut input: I) { -/// while let Some(_) = input.next().await { -/// tokio::coop::proceed().await; -/// } -/// } -/// ``` -#[allow(unreachable_pub, dead_code)] -#[inline] -pub async fn proceed() { - use crate::future::poll_fn; - poll_fn(|cx| poll_proceed(cx)).await; +cfg_rt_threaded! { + #[inline(always)] + pub(crate) fn has_budget_remaining() -> bool { + CURRENT.with(|cell| cell.get().has_remaining()) + } } -pin_project_lite::pin_project! { - /// A future that cooperatively yields to the task scheduler when polling, - /// if the task's budget is exhausted. - /// - /// Internally, this is simply a future combinator which calls - /// [`poll_proceed`] in its `poll` implementation before polling the wrapped - /// future. - /// - /// # Examples - /// - /// ```rust,ignore - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::coop::CoopFutureExt; - /// - /// async { /* ... */ } - /// .cooperate() - /// .await; - /// # } - /// ``` +cfg_blocking_impl! { + /// Forcibly remove the budgeting constraints early. /// - /// [`poll_proceed`]: fn@poll_proceed - #[derive(Debug)] - #[allow(unreachable_pub, dead_code)] - pub struct CoopFuture { - #[pin] - future: F, + /// Returns the remaining budget + pub(crate) fn stop() -> Budget { + CURRENT.with(|cell| { + let prev = cell.get(); + cell.set(Budget::unconstrained()); + prev + }) } } -struct ResetGuard<'a> { - hits: &'a Cell, - prev: usize, -} +cfg_coop! { + use std::task::{Context, Poll}; -impl Future for CoopFuture { - type Output = F::Output; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - ready!(poll_proceed(cx)); - self.project().future.poll(cx) + #[must_use] + pub(crate) struct RestoreOnPending(Cell); + + impl RestoreOnPending { + pub(crate) fn made_progress(&self) { + self.0.set(Budget::unconstrained()); + } } -} -impl CoopFuture { - /// Returns a new `CoopFuture` wrapping the given future. - /// - #[allow(unreachable_pub, dead_code)] - pub fn new(future: F) -> Self { - Self { future } + impl Drop for RestoreOnPending { + fn drop(&mut self) { + // Don't reset if budget was unconstrained or if we made progress. + // They are both represented as the remembered budget being unconstrained. + let budget = self.0.get(); + if !budget.is_unconstrained() { + CURRENT.with(|cell| { + cell.set(budget); + }); + } + } } -} -// Currently only used by `tokio::sync`; and if we make this combinator public, -// it should probably be on the `FutureExt` trait instead. -cfg_sync! { - /// Extension trait providing `Future::cooperate` extension method. + /// Returns `Poll::Pending` if the current task has exceeded its budget and should yield. /// - /// Note: if/when the co-op API becomes public, this method should probably be - /// provided by `FutureExt`, instead. - pub(crate) trait CoopFutureExt: Future { - /// Wrap `self` to cooperatively yield to the scheduler when polling, if the - /// task's budget is exhausted. - fn cooperate(self) -> CoopFuture - where - Self: Sized, - { - CoopFuture::new(self) - } + /// When you call this method, the current budget is decremented. However, to ensure that + /// progress is made every time a task is polled, the budget is automatically restored to its + /// former value if the returned `RestoreOnPending` is dropped. It is the caller's + /// responsibility to call `RestoreOnPending::made_progress` if it made progress, to ensure + /// that the budget empties appropriately. + /// + /// Note that `RestoreOnPending` restores the budget **as it was before `poll_proceed`**. + /// Therefore, if the budget is _further_ adjusted between when `poll_proceed` returns and + /// `RestRestoreOnPending` is dropped, those adjustments are erased unless the caller indicates + /// that progress was made. + #[inline] + pub(crate) fn poll_proceed(cx: &mut Context<'_>) -> Poll { + CURRENT.with(|cell| { + let mut budget = cell.get(); + + if budget.decrement() { + let restore = RestoreOnPending(Cell::new(cell.get())); + cell.set(budget); + Poll::Ready(restore) + } else { + cx.waker().wake_by_ref(); + Poll::Pending + } + }) } - impl CoopFutureExt for F where F: Future {} -} + impl Budget { + /// Decrement the budget. Returns `true` if successful. Decrementing fails + /// when there is not enough remaining budget. + fn decrement(&mut self) -> bool { + if let Some(num) = &mut self.0 { + if *num > 0 { + *num -= 1; + true + } else { + false + } + } else { + true + } + } -impl<'a> Drop for ResetGuard<'a> { - fn drop(&mut self) { - self.hits.set(self.prev); + fn is_unconstrained(self) -> bool { + self.0.is_none() + } } } @@ -364,49 +227,75 @@ impl<'a> Drop for ResetGuard<'a> { mod test { use super::*; - fn get() -> usize { - HITS.with(|hits| hits.get()) + fn get() -> Budget { + CURRENT.with(|cell| cell.get()) } #[test] fn bugeting() { + use futures::future::poll_fn; use tokio_test::*; - assert_eq!(get(), UNCONSTRAINED); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), UNCONSTRAINED); + assert!(get().0.is_none()); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + + assert!(get().0.is_none()); + drop(coop); + assert!(get().0.is_none()); + budget(|| { - assert_eq!(get(), BUDGET); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), BUDGET - 1); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), BUDGET - 2); + assert_eq!(get().0, Budget::initial().0); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + drop(coop); + // we didn't make progress + assert_eq!(get().0, Budget::initial().0); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + coop.made_progress(); + drop(coop); + // we _did_ make progress + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2); + coop.made_progress(); + drop(coop); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2); + + budget(|| { + assert_eq!(get().0, Budget::initial().0); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + coop.made_progress(); + drop(coop); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + }); + + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2); }); - assert_eq!(get(), UNCONSTRAINED); + + assert!(get().0.is_none()); budget(|| { - limit(3, || { - assert_eq!(get(), 3); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), 2); - limit(4, || { - assert_eq!(get(), 2); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), 1); - }); - assert_eq!(get(), 1); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), 0); - assert_pending!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), 0); - assert_pending!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), 0); - }); - assert_eq!(get(), BUDGET - 3); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); - assert_eq!(get(), BUDGET - 4); - assert_ready!(task::spawn(proceed()).poll()); - assert_eq!(get(), BUDGET - 5); + let n = get().0.unwrap(); + + for _ in 0..n { + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + coop.made_progress(); + } + + let mut task = task::spawn(poll_fn(|cx| { + let coop = ready!(poll_proceed(cx)); + coop.made_progress(); + Poll::Ready(()) + })); + + assert_pending!(task.poll()); }); } } diff --git a/tokio/src/fs/dir_builder.rs b/tokio/src/fs/dir_builder.rs new file mode 100644 index 00000000000..8752a3716aa --- /dev/null +++ b/tokio/src/fs/dir_builder.rs @@ -0,0 +1,117 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// A builder for creating directories in various manners. +/// +/// Additional Unix-specific options are available via importing the +/// [`DirBuilderExt`] trait. +/// +/// This is a specialized version of [`std::fs::DirBuilder`] for usage on +/// the Tokio runtime. +/// +/// [std::fs::DirBuilder]: std::fs::DirBuilder +/// [`DirBuilderExt`]: crate::fs::os::unix::DirBuilderExt +#[derive(Debug, Default)] +pub struct DirBuilder { + /// Indicates whether to create parent directories if they are missing. + recursive: bool, + + /// Set the Unix mode for newly created directories. + #[cfg(unix)] + pub(super) mode: Option, +} + +impl DirBuilder { + /// Creates a new set of options with default mode/security settings for all + /// platforms and also non-recursive. + /// + /// This is an async version of [`std::fs::DirBuilder::new`][std] + /// + /// [std]: std::fs::DirBuilder::new + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// + /// let builder = DirBuilder::new(); + /// ``` + pub fn new() -> Self { + Default::default() + } + + /// Indicates whether to create directories recursively (including all parent directories). + /// Parents that do not exist are created with the same security and permissions settings. + /// + /// This option defaults to `false`. + /// + /// This is an async version of [`std::fs::DirBuilder::recursive`][std] + /// + /// [std]: std::fs::DirBuilder::recursive + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// + /// let mut builder = DirBuilder::new(); + /// builder.recursive(true); + /// ``` + pub fn recursive(&mut self, recursive: bool) -> &mut Self { + self.recursive = recursive; + self + } + + /// Creates the specified directory with the configured options. + /// + /// It is considered an error if the directory already exists unless + /// recursive mode is enabled. + /// + /// This is an async version of [`std::fs::DirBuilder::create`][std] + /// + /// [std]: std::fs::DirBuilder::create + /// + /// # Errors + /// + /// An error will be returned under the following circumstances: + /// + /// * Path already points to an existing file. + /// * Path already points to an existing directory and the mode is + /// non-recursive. + /// * The calling process doesn't have permissions to create the directory + /// or its missing parents. + /// * Other I/O error occurred. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// DirBuilder::new() + /// .recursive(true) + /// .create("/tmp/foo/bar/baz") + /// .await?; + /// + /// Ok(()) + /// } + /// ``` + pub async fn create>(&self, path: P) -> io::Result<()> { + let path = path.as_ref().to_owned(); + let mut builder = std::fs::DirBuilder::new(); + builder.recursive(self.recursive); + + #[cfg(unix)] + { + if let Some(mode) = self.mode { + std::os::unix::fs::DirBuilderExt::mode(&mut builder, mode); + } + } + + asyncify(move || builder.create(path)).await + } +} diff --git a/tokio/src/fs/file.rs b/tokio/src/fs/file.rs index cc4a187d78f..f3bc98546a9 100644 --- a/tokio/src/fs/file.rs +++ b/tokio/src/fs/file.rs @@ -537,6 +537,11 @@ impl File { } impl AsyncRead for File { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/fs.rs#L668 + false + } + fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/tokio/src/fs/mod.rs b/tokio/src/fs/mod.rs index 3eb03764631..a2b062b1a30 100644 --- a/tokio/src/fs/mod.rs +++ b/tokio/src/fs/mod.rs @@ -33,6 +33,9 @@ pub use self::create_dir::create_dir; mod create_dir_all; pub use self::create_dir_all::create_dir_all; +mod dir_builder; +pub use self::dir_builder::DirBuilder; + mod file; pub use self::file::File; diff --git a/tokio/src/fs/open_options.rs b/tokio/src/fs/open_options.rs index 3210f4b7b5d..ba3d9a6cf67 100644 --- a/tokio/src/fs/open_options.rs +++ b/tokio/src/fs/open_options.rs @@ -382,6 +382,12 @@ impl OpenOptions { let std = asyncify(move || opts.open(path)).await?; Ok(File::from_std(std)) } + + /// Returns a mutable reference to the the underlying std::fs::OpenOptions + #[cfg(unix)] + pub(super) fn as_inner_mut(&mut self) -> &mut std::fs::OpenOptions { + &mut self.0 + } } impl From for OpenOptions { diff --git a/tokio/src/fs/os/unix/dir_builder_ext.rs b/tokio/src/fs/os/unix/dir_builder_ext.rs new file mode 100644 index 00000000000..e9a25b959c6 --- /dev/null +++ b/tokio/src/fs/os/unix/dir_builder_ext.rs @@ -0,0 +1,29 @@ +use crate::fs::dir_builder::DirBuilder; + +/// Unix-specific extensions to [`DirBuilder`]. +/// +/// [`DirBuilder`]: crate::fs::DirBuilder +pub trait DirBuilderExt { + /// Sets the mode to create new directories with. + /// + /// This option defaults to 0o777. + /// + /// # Examples + /// + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// use tokio::fs::os::unix::DirBuilderExt; + /// + /// let mut builder = DirBuilder::new(); + /// builder.mode(0o775); + /// ``` + fn mode(&mut self, mode: u32) -> &mut Self; +} + +impl DirBuilderExt for DirBuilder { + fn mode(&mut self, mode: u32) -> &mut Self { + self.mode = Some(mode); + self + } +} diff --git a/tokio/src/fs/os/unix/mod.rs b/tokio/src/fs/os/unix/mod.rs index 3b0bec38bd5..826222ebf23 100644 --- a/tokio/src/fs/os/unix/mod.rs +++ b/tokio/src/fs/os/unix/mod.rs @@ -2,3 +2,9 @@ mod symlink; pub use self::symlink::symlink; + +mod open_options_ext; +pub use self::open_options_ext::OpenOptionsExt; + +mod dir_builder_ext; +pub use self::dir_builder_ext::DirBuilderExt; diff --git a/tokio/src/fs/os/unix/open_options_ext.rs b/tokio/src/fs/os/unix/open_options_ext.rs new file mode 100644 index 00000000000..ff892758804 --- /dev/null +++ b/tokio/src/fs/os/unix/open_options_ext.rs @@ -0,0 +1,79 @@ +use crate::fs::open_options::OpenOptions; +use std::os::unix::fs::OpenOptionsExt as StdOpenOptionsExt; + +/// Unix-specific extensions to [`fs::OpenOptions`]. +/// +/// This mirrors the definition of [`std::os::unix::fs::OpenOptionsExt`]. +/// +/// +/// [`fs::OpenOptions`]: crate::fs::OpenOptions +/// [`std::os::unix::fs::OpenOptionsExt`]: std::os::unix::fs::OpenOptionsExt +pub trait OpenOptionsExt { + /// Sets the mode bits that a new file will be created with. + /// + /// If a new file is created as part of an `OpenOptions::open` call then this + /// specified `mode` will be used as the permission bits for the new file. + /// If no `mode` is set, the default of `0o666` will be used. + /// The operating system masks out bits with the system's `umask`, to produce + /// the final permissions. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use tokio::fs::os::unix::OpenOptionsExt; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut options = OpenOptions::new(); + /// options.mode(0o644); // Give read/write for owner and read for others. + /// let file = options.open("foo.txt").await?; + /// + /// Ok(()) + /// } + /// ``` + fn mode(&mut self, mode: u32) -> &mut Self; + + /// Pass custom flags to the `flags` argument of `open`. + /// + /// The bits that define the access mode are masked out with `O_ACCMODE`, to + /// ensure they do not interfere with the access mode set by Rusts options. + /// + /// Custom flags can only set flags, not remove flags set by Rusts options. + /// This options overwrites any previously set custom flags. + /// + /// # Examples + /// + /// ```no_run + /// use libc; + /// use tokio::fs::OpenOptions; + /// use tokio::fs::os::unix::OpenOptionsExt; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut options = OpenOptions::new(); + /// options.write(true); + /// if cfg!(unix) { + /// options.custom_flags(libc::O_NOFOLLOW); + /// } + /// let file = options.open("foo.txt").await?; + /// + /// Ok(()) + /// } + /// ``` + fn custom_flags(&mut self, flags: i32) -> &mut Self; +} + +impl OpenOptionsExt for OpenOptions { + fn mode(&mut self, mode: u32) -> &mut OpenOptions { + self.as_inner_mut().mode(mode); + self + } + + fn custom_flags(&mut self, flags: i32) -> &mut OpenOptions { + self.as_inner_mut().custom_flags(flags); + self + } +} diff --git a/tokio/src/fs/remove_dir_all.rs b/tokio/src/fs/remove_dir_all.rs index 3b2b2e0453e..0a237550f9c 100644 --- a/tokio/src/fs/remove_dir_all.rs +++ b/tokio/src/fs/remove_dir_all.rs @@ -7,7 +7,7 @@ use std::path::Path; /// /// This is an async version of [`std::fs::remove_dir_all`][std] /// -/// [std]: https://doc.rust-lang.org/std/fs/fn.remove_dir_all.html +/// [std]: fn@std::fs::remove_dir_all pub async fn remove_dir_all(path: impl AsRef) -> io::Result<()> { let path = path.as_ref().to_owned(); asyncify(move || std::fs::remove_dir_all(path)).await diff --git a/tokio/src/fs/set_permissions.rs b/tokio/src/fs/set_permissions.rs index b6249d13f00..09be02ea013 100644 --- a/tokio/src/fs/set_permissions.rs +++ b/tokio/src/fs/set_permissions.rs @@ -8,7 +8,7 @@ use std::path::Path; /// /// This is an async version of [`std::fs::set_permissions`][std] /// -/// [std]: https://doc.rust-lang.org/std/fs/fn.set_permissions.html +/// [std]: fn@std::fs::set_permissions pub async fn set_permissions(path: impl AsRef, perm: Permissions) -> io::Result<()> { let path = path.as_ref().to_owned(); asyncify(|| std::fs::set_permissions(path, perm)).await diff --git a/tokio/src/fs/symlink_metadata.rs b/tokio/src/fs/symlink_metadata.rs index 682b43a70eb..1d0df125760 100644 --- a/tokio/src/fs/symlink_metadata.rs +++ b/tokio/src/fs/symlink_metadata.rs @@ -8,7 +8,7 @@ use std::path::Path; /// /// This is an async version of [`std::fs::symlink_metadata`][std] /// -/// [std]: https://doc.rust-lang.org/std/fs/fn.symlink_metadata.html +/// [std]: fn@std::fs::symlink_metadata pub async fn symlink_metadata(path: impl AsRef) -> io::Result { let path = path.as_ref().to_owned(); asyncify(|| std::fs::symlink_metadata(path)).await diff --git a/tokio/src/future/ready.rs b/tokio/src/future/ready.rs index d74f999e5de..de2d60c13a2 100644 --- a/tokio/src/future/ready.rs +++ b/tokio/src/future/ready.rs @@ -2,7 +2,7 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -/// Future for the [`ready`](ready()) function. +/// Future for the [`ok`](ok()) function. /// /// `pub` in order to use the future as an associated type in a sealed trait. #[derive(Debug)] diff --git a/tokio/src/io/async_buf_read.rs b/tokio/src/io/async_buf_read.rs index 1c1fae36588..ecaafba4c27 100644 --- a/tokio/src/io/async_buf_read.rs +++ b/tokio/src/io/async_buf_read.rs @@ -7,14 +7,18 @@ use std::task::{Context, Poll}; /// Reads bytes asynchronously. /// -/// This trait inherits from [`std::io::BufRead`] and indicates that an I/O object is -/// **non-blocking**. All non-blocking I/O objects must return an error when -/// bytes are unavailable instead of blocking the current thread. +/// This trait is analogous to [`std::io::BufRead`], but integrates with +/// the asynchronous task system. In particular, the [`poll_fill_buf`] method, +/// unlike [`BufRead::fill_buf`], will automatically queue the current task for wakeup +/// and return if data is not yet available, rather than blocking the calling +/// thread. /// /// Utilities for working with `AsyncBufRead` values are provided by /// [`AsyncBufReadExt`]. /// /// [`std::io::BufRead`]: std::io::BufRead +/// [`poll_fill_buf`]: AsyncBufRead::poll_fill_buf +/// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf /// [`AsyncBufReadExt`]: crate::io::AsyncBufReadExt pub trait AsyncBufRead: AsyncRead { /// Attempts to return the contents of the internal buffer, filling it with more data diff --git a/tokio/src/io/async_read.rs b/tokio/src/io/async_read.rs index cc9091c99c1..1aef4150166 100644 --- a/tokio/src/io/async_read.rs +++ b/tokio/src/io/async_read.rs @@ -73,10 +73,10 @@ pub trait AsyncRead { /// that they did not write to. /// /// [`io::Read`]: std::io::Read - /// [`poll_read_buf`]: #method.poll_read_buf + /// [`poll_read_buf`]: method@Self::poll_read_buf unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { for x in buf { - *x.as_mut_ptr() = 0; + *x = MaybeUninit::new(0); } true diff --git a/tokio/src/io/driver/mod.rs b/tokio/src/io/driver/mod.rs index d8535d9ab27..dbfb6e16e3c 100644 --- a/tokio/src/io/driver/mod.rs +++ b/tokio/src/io/driver/mod.rs @@ -237,10 +237,14 @@ impl fmt::Debug for Handle { // ===== impl Inner ===== impl Inner { - /// Registers an I/O resource with the reactor. + /// Registers an I/O resource with the reactor for a given `mio::Ready` state. /// /// The registration token is returned. - pub(super) fn add_source(&self, source: &dyn Evented) -> io::Result
{ + pub(super) fn add_source( + &self, + source: &dyn Evented, + ready: mio::Ready, + ) -> io::Result
{ let address = self.io_dispatch.alloc().ok_or_else(|| { io::Error::new( io::ErrorKind::Other, @@ -253,7 +257,7 @@ impl Inner { self.io.register( source, mio::Token(address.to_usize()), - mio::Ready::all(), + ready, mio::PollOpt::edge(), )?; @@ -339,12 +343,12 @@ mod tests { let inner = reactor.inner; let inner2 = inner.clone(); - let token_1 = inner.add_source(&NotEvented).unwrap(); + let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); let thread = thread::spawn(move || { inner2.drop_source(token_1); }); - let token_2 = inner.add_source(&NotEvented).unwrap(); + let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); thread.join().unwrap(); assert!(token_1 != token_2); @@ -360,15 +364,15 @@ mod tests { // add sources to fill up the first page so that the dropped index // may be reused. for _ in 0..31 { - inner.add_source(&NotEvented).unwrap(); + inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); } - let token_1 = inner.add_source(&NotEvented).unwrap(); + let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); let thread = thread::spawn(move || { inner2.drop_source(token_1); }); - let token_2 = inner.add_source(&NotEvented).unwrap(); + let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); thread.join().unwrap(); assert!(token_1 != token_2); @@ -383,11 +387,11 @@ mod tests { let inner2 = inner.clone(); let thread = thread::spawn(move || { - let token_2 = inner2.add_source(&NotEvented).unwrap(); + let token_2 = inner2.add_source(&NotEvented, mio::Ready::all()).unwrap(); token_2 }); - let token_1 = inner.add_source(&NotEvented).unwrap(); + let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); let token_2 = thread.join().unwrap(); assert!(token_1 != token_2); diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index 29d8bc5554e..9e0e063195c 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -15,19 +15,19 @@ //! type will _yield_ to the Tokio scheduler when IO is not ready, rather than //! blocking. This allows other tasks to run while waiting on IO. //! -//! Another difference is that [`AsyncRead`] and [`AsyncWrite`] only contain +//! Another difference is that `AsyncRead` and `AsyncWrite` only contain //! core methods needed to provide asynchronous reading and writing //! functionality. Instead, utility methods are defined in the [`AsyncReadExt`] //! and [`AsyncWriteExt`] extension traits. These traits are automatically -//! implemented for all values that implement [`AsyncRead`] and [`AsyncWrite`] +//! implemented for all values that implement `AsyncRead` and `AsyncWrite` //! respectively. //! -//! End users will rarely interact directly with [`AsyncRead`] and -//! [`AsyncWrite`]. Instead, they will use the async functions defined in the -//! extension traits. Library authors are expected to implement [`AsyncRead`] -//! and [`AsyncWrite`] in order to provide types that behave like byte streams. +//! End users will rarely interact directly with `AsyncRead` and +//! `AsyncWrite`. Instead, they will use the async functions defined in the +//! extension traits. Library authors are expected to implement `AsyncRead` +//! and `AsyncWrite` in order to provide types that behave like byte streams. //! -//! Even with these differences, Tokio's [`AsyncRead`] and [`AsyncWrite`] traits +//! Even with these differences, Tokio's `AsyncRead` and `AsyncWrite` traits //! can be used in almost exactly the same manner as the standard library's //! `Read` and `Write`. Most types in the standard library that implement `Read` //! and `Write` have asynchronous equivalents in `tokio` that implement @@ -57,7 +57,7 @@ //! [`File`]: crate::fs::File //! [`TcpStream`]: crate::net::TcpStream //! [`std::fs::File`]: std::fs::File -//! [std_example]: https://doc.rust-lang.org/std/io/index.html#read-and-write +//! [std_example]: std::io#read-and-write //! //! ## Buffered Readers and Writers //! @@ -93,7 +93,8 @@ //! ``` //! //! [`BufWriter`] doesn't add any new ways of writing; it just buffers every call -//! to [`write`](crate::io::AsyncWriteExt::write): +//! to [`write`](crate::io::AsyncWriteExt::write). However, you **must** flush +//! [`BufWriter`] to ensure that any buffered data is written. //! //! ```no_run //! use tokio::io::{self, BufWriter, AsyncWriteExt}; @@ -105,16 +106,19 @@ //! { //! let mut writer = BufWriter::new(f); //! -//! // write a byte to the buffer +//! // Write a byte to the buffer. //! writer.write(&[42u8]).await?; //! -//! } // the buffer is flushed once writer goes out of scope +//! // Flush the buffer before it goes out of scope. +//! writer.flush().await?; +//! +//! } // Unless flushed or shut down, the contents of the buffer is discarded on drop. //! //! Ok(()) //! } //! ``` //! -//! [stdbuf]: https://doc.rust-lang.org/std/io/index.html#bufreader-and-bufwriter +//! [stdbuf]: std::io#bufreader-and-bufwriter //! [`std::io::BufRead`]: std::io::BufRead //! [`AsyncBufRead`]: crate::io::AsyncBufRead //! [`BufReader`]: crate::io::BufReader @@ -122,12 +126,26 @@ //! //! ## Implementing AsyncRead and AsyncWrite //! -//! Because they are traits, we can implement `AsyncRead` and `AsyncWrite` for +//! Because they are traits, we can implement [`AsyncRead`] and [`AsyncWrite`] for //! our own types, as well. Note that these traits must only be implemented for //! non-blocking I/O types that integrate with the futures type system. In //! other words, these types must never block the thread, and instead the //! current task is notified when the I/O resource is ready. //! +//! ## Conversion to and from Sink/Stream +//! +//! It is often convenient to encapsulate the reading and writing of +//! bytes and instead work with a [`Sink`] or [`Stream`] of some data +//! type that is encoded as bytes and/or decoded from bytes. Tokio +//! provides some utility traits in the [tokio-util] crate that +//! abstract the asynchronous buffering that is required and allows +//! you to write [`Encoder`] and [`Decoder`] functions working with a +//! buffer of bytes, and then use that ["codec"] to transform anything +//! that implements [`AsyncRead`] and [`AsyncWrite`] into a `Sink`/`Stream` of +//! your structured data. +//! +//! [tokio-util]: https://docs.rs/tokio-util/0.3/tokio_util/codec/index.html +//! //! # Standard input and output //! //! Tokio provides asynchronous APIs to standard [input], [output], and [error]. @@ -144,15 +162,23 @@ //! //! # `std` re-exports //! -//! Additionally, [`Error`], [`ErrorKind`], and [`Result`] are re-exported -//! from `std::io` for ease of use. +//! Additionally, [`Error`], [`ErrorKind`], [`Result`], and [`SeekFrom`] are +//! re-exported from `std::io` for ease of use. //! //! [`AsyncRead`]: trait@AsyncRead //! [`AsyncWrite`]: trait@AsyncWrite +//! [`AsyncReadExt`]: trait@AsyncReadExt +//! [`AsyncWriteExt`]: trait@AsyncWriteExt +//! ["codec"]: https://docs.rs/tokio-util/0.3/tokio_util/codec/index.html +//! [`Encoder`]: https://docs.rs/tokio-util/0.3/tokio_util/codec/trait.Encoder.html +//! [`Decoder`]: https://docs.rs/tokio-util/0.3/tokio_util/codec/trait.Decoder.html //! [`Error`]: struct@Error //! [`ErrorKind`]: enum@ErrorKind //! [`Result`]: type@Result //! [`Read`]: std::io::Read +//! [`SeekFrom`]: enum@SeekFrom +//! [`Sink`]: https://docs.rs/futures/0.3/futures/sink/trait.Sink.html +//! [`Stream`]: crate::stream::Stream //! [`Write`]: std::io::Write cfg_io_blocking! { pub(crate) mod blocking; @@ -170,6 +196,10 @@ pub use self::async_seek::AsyncSeek; mod async_write; pub use self::async_write::AsyncWrite; +// Re-export some types from `std::io` so that users don't have to deal +// with conflicts when `use`ing `tokio::io` and `std::io`. +pub use std::io::{Error, ErrorKind, Result, SeekFrom}; + cfg_io_driver! { pub(crate) mod driver; @@ -200,17 +230,13 @@ cfg_io_util! { pub(crate) mod util; pub use util::{ - copy, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, - BufReader, BufStream, BufWriter, Copy, Empty, Lines, Repeat, Sink, Split, Take, + copy, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, + BufReader, BufStream, BufWriter, DuplexStream, Copy, Empty, Lines, Repeat, Sink, Split, Take, }; cfg_stream! { pub use util::{stream_reader, StreamReader}; } - - // Re-export io::Error so that users don't have to deal with conflicts when - // `use`ing `tokio::io` and `std::io`. - pub use std::io::{Error, ErrorKind, Result}; } cfg_not_io_util! { diff --git a/tokio/src/io/poll_evented.rs b/tokio/src/io/poll_evented.rs index 3ca30a82bf9..5295bd71ad8 100644 --- a/tokio/src/io/poll_evented.rs +++ b/tokio/src/io/poll_evented.rs @@ -90,17 +90,17 @@ cfg_io_driver! { /// These events are included as part of the read readiness event stream. The /// write readiness event stream is only for `Ready::writable()` events. /// - /// [`std::io::Read`]: https://doc.rust-lang.org/std/io/trait.Read.html - /// [`std::io::Write`]: https://doc.rust-lang.org/std/io/trait.Write.html + /// [`std::io::Read`]: trait@std::io::Read + /// [`std::io::Write`]: trait@std::io::Write /// [`AsyncRead`]: trait@AsyncRead /// [`AsyncWrite`]: trait@AsyncWrite - /// [`mio::Evented`]: https://docs.rs/mio/0.6/mio/trait.Evented.html + /// [`mio::Evented`]: trait@mio::Evented /// [`Registration`]: struct@Registration /// [`TcpListener`]: struct@crate::net::TcpListener - /// [`clear_read_ready`]: #method.clear_read_ready - /// [`clear_write_ready`]: #method.clear_write_ready - /// [`poll_read_ready`]: #method.poll_read_ready - /// [`poll_write_ready`]: #method.poll_write_ready + /// [`clear_read_ready`]: method@Self::clear_read_ready + /// [`clear_write_ready`]: method@Self::clear_write_ready + /// [`poll_read_ready`]: method@Self::poll_read_ready + /// [`poll_write_ready`]: method@Self::poll_write_ready pub struct PollEvented { io: Option, inner: Inner, @@ -175,7 +175,35 @@ where /// from a future driven by a tokio runtime, otherwise runtime can be set /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. pub fn new(io: E) -> io::Result { - let registration = Registration::new(&io)?; + PollEvented::new_with_ready(io, mio::Ready::all()) + } + + /// Creates a new `PollEvented` associated with the default reactor, for specific `mio::Ready` + /// state. `new_with_ready` should be used over `new` when you need control over the readiness + /// state, such as when a file descriptor only allows reads. This does not add `hup` or `error` + /// so if you are interested in those states, you will need to add them to the readiness state + /// passed to this function. + /// + /// An example to listen to read only + /// + /// ```rust + /// ##[cfg(unix)] + /// mio::Ready::from_usize( + /// mio::Ready::readable().as_usize() + /// | mio::unix::UnixReady::error().as_usize() + /// | mio::unix::UnixReady::hup().as_usize() + /// ); + /// ``` + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. + pub fn new_with_ready(io: E, ready: mio::Ready) -> io::Result { + let registration = Registration::new_with_ready(&io, ready)?; Ok(Self { io: Some(io), inner: Inner { @@ -225,7 +253,7 @@ where /// The I/O resource will remain in a read-ready state until readiness is /// cleared by calling [`clear_read_ready`]. /// - /// [`clear_read_ready`]: #method.clear_read_ready + /// [`clear_read_ready`]: method@Self::clear_read_ready /// /// # Panics /// @@ -233,6 +261,11 @@ where /// /// * `ready` includes writable. /// * called from outside of a task context. + /// + /// # Warning + /// + /// This method may not be called concurrently. It takes `&self` to allow + /// calling it concurrently with `poll_write_ready`. pub fn poll_read_ready( &self, cx: &mut Context<'_>, @@ -291,7 +324,7 @@ where /// The I/O resource will remain in a write-ready state until readiness is /// cleared by calling [`clear_write_ready`]. /// - /// [`clear_write_ready`]: #method.clear_write_ready + /// [`clear_write_ready`]: method@Self::clear_write_ready /// /// # Panics /// @@ -299,6 +332,11 @@ where /// /// * `ready` contains bits besides `writable` and `hup`. /// * called from outside of a task context. + /// + /// # Warning + /// + /// This method may not be called concurrently. It takes `&self` to allow + /// calling it concurrently with `poll_read_ready`. pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { poll_ready!( self, diff --git a/tokio/src/io/registration.rs b/tokio/src/io/registration.rs index 4df11999f5e..77fe6dbc723 100644 --- a/tokio/src/io/registration.rs +++ b/tokio/src/io/registration.rs @@ -34,9 +34,9 @@ cfg_io_driver! { /// stream. The write readiness event stream is only for `Ready::writable()` /// events. /// - /// [`new`]: #method.new - /// [`poll_read_ready`]: #method.poll_read_ready`] - /// [`poll_write_ready`]: #method.poll_write_ready`] + /// [`new`]: method@Self::new + /// [`poll_read_ready`]: method@Self::poll_read_ready` + /// [`poll_write_ready`]: method@Self::poll_write_ready` #[derive(Debug)] pub struct Registration { handle: Handle, @@ -63,12 +63,49 @@ impl Registration { /// from a future driven by a tokio runtime, otherwise runtime can be set /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. pub fn new(io: &T) -> io::Result + where + T: Evented, + { + Registration::new_with_ready(io, mio::Ready::all()) + } + + /// Registers the I/O resource with the default reactor, for a specific `mio::Ready` state. + /// `new_with_ready` should be used over `new` when you need control over the readiness state, + /// such as when a file descriptor only allows reads. This does not add `hup` or `error` so if + /// you are interested in those states, you will need to add them to the readiness state passed + /// to this function. + /// + /// An example to listen to read only + /// + /// ```rust + /// ##[cfg(unix)] + /// mio::Ready::from_usize( + /// mio::Ready::readable().as_usize() + /// | mio::unix::UnixReady::error().as_usize() + /// | mio::unix::UnixReady::hup().as_usize() + /// ); + /// ``` + /// + /// # Return + /// + /// - `Ok` if the registration happened successfully + /// - `Err` if an error was encountered during registration + /// + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. + pub fn new_with_ready(io: &T, ready: mio::Ready) -> io::Result where T: Evented, { let handle = Handle::current(); let address = if let Some(inner) = handle.inner() { - inner.add_source(io)? + inner.add_source(io, ready)? } else { return Err(io::Error::new( io::ErrorKind::Other, @@ -116,8 +153,6 @@ impl Registration { /// the function will always return `Ready(HUP)`. This should be treated as /// the end of the readiness stream. /// - /// Ensure that [`register`] has been called first. - /// /// # Return value /// /// There are several possible return values: @@ -129,22 +164,26 @@ impl Registration { /// since the last call to `poll_read_ready`. /// /// * `Poll::Ready(Err(err))` means that the registration has encountered an - /// error. This error either represents a permanent internal error **or** - /// the fact that [`register`] was not called first. + /// error. This could represent a permanent internal error for example. /// - /// [`register`]: #method.register - /// [edge-triggered]: https://docs.rs/mio/0.6/mio/struct.Poll.html#edge-triggered-and-level-triggered + /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered /// /// # Panics /// /// This function will panic if called from outside of a task context. pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); - let v = self.poll_ready(Direction::Read, Some(cx))?; + let v = self.poll_ready(Direction::Read, Some(cx)).map_err(|e| { + coop.made_progress(); + e + })?; match v { - Some(v) => Poll::Ready(Ok(v)), + Some(v) => { + coop.made_progress(); + Poll::Ready(Ok(v)) + } None => Poll::Pending, } } @@ -155,7 +194,7 @@ impl Registration { /// will not notify the current task when a new event is received. As such, /// it is safe to call this function from outside of a task context. /// - /// [`poll_read_ready`]: #method.poll_read_ready + /// [`poll_read_ready`]: method@Self::poll_read_ready pub fn take_read_ready(&self) -> io::Result> { self.poll_ready(Direction::Read, None) } @@ -170,8 +209,6 @@ impl Registration { /// the function will always return `Ready(HUP)`. This should be treated as /// the end of the readiness stream. /// - /// Ensure that [`register`] has been called first. - /// /// # Return value /// /// There are several possible return values: @@ -183,22 +220,26 @@ impl Registration { /// since the last call to `poll_write_ready`. /// /// * `Poll::Ready(Err(err))` means that the registration has encountered an - /// error. This error either represents a permanent internal error **or** - /// the fact that [`register`] was not called first. + /// error. This could represent a permanent internal error for example. /// - /// [`register`]: #method.register - /// [edge-triggered]: https://docs.rs/mio/0.6/mio/struct.Poll.html#edge-triggered-and-level-triggered + /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered /// /// # Panics /// /// This function will panic if called from outside of a task context. pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); - let v = self.poll_ready(Direction::Write, Some(cx))?; + let v = self.poll_ready(Direction::Write, Some(cx)).map_err(|e| { + coop.made_progress(); + e + })?; match v { - Some(v) => Poll::Ready(Ok(v)), + Some(v) => { + coop.made_progress(); + Poll::Ready(Ok(v)) + } None => Poll::Pending, } } @@ -209,7 +250,7 @@ impl Registration { /// will not notify the current task when a new event is received. As such, /// it is safe to call this function from outside of a task context. /// - /// [`poll_write_ready`]: #method.poll_write_ready + /// [`poll_write_ready`]: method@Self::poll_write_ready pub fn take_write_ready(&self) -> io::Result> { self.poll_ready(Direction::Write, None) } diff --git a/tokio/src/io/stdin.rs b/tokio/src/io/stdin.rs index d986d3abeaf..325b8757ec1 100644 --- a/tokio/src/io/stdin.rs +++ b/tokio/src/io/stdin.rs @@ -63,6 +63,11 @@ impl std::os::windows::io::AsRawHandle for Stdin { } impl AsyncRead for Stdin { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/io/stdio.rs#L97 + false + } + fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/tokio/src/io/util/async_read_ext.rs b/tokio/src/io/util/async_read_ext.rs index d4402db621b..dd280a188d7 100644 --- a/tokio/src/io/util/async_read_ext.rs +++ b/tokio/src/io/util/async_read_ext.rs @@ -2,8 +2,12 @@ use crate::io::util::chain::{chain, Chain}; use crate::io::util::read::{read, Read}; use crate::io::util::read_buf::{read_buf, ReadBuf}; use crate::io::util::read_exact::{read_exact, ReadExact}; -use crate::io::util::read_int::{ReadI128, ReadI16, ReadI32, ReadI64, ReadI8}; -use crate::io::util::read_int::{ReadU128, ReadU16, ReadU32, ReadU64, ReadU8}; +use crate::io::util::read_int::{ + ReadI128, ReadI128Le, ReadI16, ReadI16Le, ReadI32, ReadI32Le, ReadI64, ReadI64Le, ReadI8, +}; +use crate::io::util::read_int::{ + ReadU128, ReadU128Le, ReadU16, ReadU16Le, ReadU32, ReadU32Le, ReadU64, ReadU64Le, ReadU8, +}; use crate::io::util::read_to_end::{read_to_end, ReadToEnd}; use crate::io::util::read_to_string::{read_to_string, ReadToString}; use crate::io::util::take::{take, Take}; @@ -221,7 +225,7 @@ cfg_io_util! { /// ``` fn read_buf<'a, B>(&'a mut self, buf: &'a mut B) -> ReadBuf<'a, Self, B> where - Self: Sized, + Self: Sized + Unpin, B: BufMut, { read_buf(self, buf) @@ -238,12 +242,6 @@ cfg_io_util! { /// This function reads as many bytes as necessary to completely fill /// the specified buffer `buf`. /// - /// No guarantees are provided about the contents of `buf` when this - /// function is called, implementations cannot rely on any property of - /// the contents of `buf` being `true`. It is recommended that - /// implementations only write data to `buf` instead of reading its - /// contents. - /// /// # Errors /// /// If the operation encounters an "end of file" before completely @@ -669,6 +667,313 @@ cfg_io_util! { /// } /// ``` fn read_i128(&mut self) -> ReadI128; + + /// Reads an unsigned 16-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u16_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 16 bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![2, 5, 3, 0]); + /// + /// assert_eq!(1282, reader.read_u16_le().await?); + /// assert_eq!(3, reader.read_u16_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u16_le(&mut self) -> ReadU16Le; + + /// Reads a signed 16-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i16_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 16 bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x00, 0xc1, 0xff, 0x7c]); + /// + /// assert_eq!(-16128, reader.read_i16_le().await?); + /// assert_eq!(31999, reader.read_i16_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i16_le(&mut self) -> ReadI16Le; + + /// Reads an unsigned 32-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u32_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 32-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x00, 0x00, 0x01, 0x0b]); + /// + /// assert_eq!(184614912, reader.read_u32_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u32_le(&mut self) -> ReadU32Le; + + /// Reads a signed 32-bit integer in little-endian order from the + /// underlying reader. + /// + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i32_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 32-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0xff, 0xff, 0x7a, 0x33]); + /// + /// assert_eq!(863698943, reader.read_i32_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i32_le(&mut self) -> ReadI32Le; + + /// Reads an unsigned 64-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u64_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 64-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83 + /// ]); + /// + /// assert_eq!(9477368352180732672, reader.read_u64_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u64_le(&mut self) -> ReadU64Le; + + /// Reads an signed 64-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i64_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 64-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x80, 0, 0, 0, 0, 0, 0, 0]); + /// + /// assert_eq!(128, reader.read_i64_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i64_le(&mut self) -> ReadI64Le; + + /// Reads an unsigned 128-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u128_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 128-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83, + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83 + /// ]); + /// + /// assert_eq!(174826588484952389081207917399662330624, reader.read_u128_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u128_le(&mut self) -> ReadU128Le; + + /// Reads an signed 128-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i128_le(&mut self) -> io::Result; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 128-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x80, 0, 0, 0, 0, 0, 0, 0, + /// 0, 0, 0, 0, 0, 0, 0, 0 + /// ]); + /// + /// assert_eq!(128, reader.read_i128_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i128_le(&mut self) -> ReadI128Le; } /// Reads all bytes until EOF in this source, placing them into `buf`. @@ -681,10 +986,12 @@ cfg_io_util! { /// /// All bytes read from this source will be appended to the specified /// buffer `buf`. This function will continuously call [`read()`] to - /// append more data to `buf` until [`read()`][read] returns `Ok(0)`. + /// append more data to `buf` until [`read()`] returns `Ok(0)`. /// /// If successful, the total number of bytes read is returned. /// + /// [`read()`]: AsyncReadExt::read + /// /// # Errors /// /// If a read error is encountered then the `read_to_end` operation @@ -713,7 +1020,7 @@ cfg_io_util! { /// (See also the [`tokio::fs::read`] convenience function for reading from a /// file.) /// - /// [`tokio::fs::read`]: crate::fs::read::read + /// [`tokio::fs::read`]: fn@crate::fs::read fn read_to_end<'a>(&'a mut self, buf: &'a mut Vec) -> ReadToEnd<'a, Self> where Self: Unpin, @@ -773,7 +1080,11 @@ cfg_io_util! { /// This function returns a new instance of `AsyncRead` which will read /// at most `limit` bytes, after which it will always return EOF /// (`Ok(0)`). Any read errors will not count towards the number of - /// bytes read and future calls to [`read()`][read] may succeed. + /// bytes read and future calls to [`read()`] may succeed. + /// + /// [`read()`]: fn@crate::io::AsyncReadExt::read + /// + /// [read]: AsyncReadExt::read /// /// # Examples /// diff --git a/tokio/src/io/util/async_write_ext.rs b/tokio/src/io/util/async_write_ext.rs index 377f4ecaf80..321301e2897 100644 --- a/tokio/src/io/util/async_write_ext.rs +++ b/tokio/src/io/util/async_write_ext.rs @@ -3,8 +3,14 @@ use crate::io::util::shutdown::{shutdown, Shutdown}; use crate::io::util::write::{write, Write}; use crate::io::util::write_all::{write_all, WriteAll}; use crate::io::util::write_buf::{write_buf, WriteBuf}; -use crate::io::util::write_int::{WriteI128, WriteI16, WriteI32, WriteI64, WriteI8}; -use crate::io::util::write_int::{WriteU128, WriteU16, WriteU32, WriteU64, WriteU8}; +use crate::io::util::write_int::{ + WriteI128, WriteI128Le, WriteI16, WriteI16Le, WriteI32, WriteI32Le, WriteI64, WriteI64Le, + WriteI8, +}; +use crate::io::util::write_int::{ + WriteU128, WriteU128Le, WriteU16, WriteU16Le, WriteU32, WriteU32Le, WriteU64, WriteU64Le, + WriteU8, +}; use crate::io::AsyncWrite; use bytes::Buf; @@ -180,7 +186,7 @@ cfg_io_util! { /// ``` fn write_buf<'a, B>(&'a mut self, src: &'a mut B) -> WriteBuf<'a, Self, B> where - Self: Sized, + Self: Sized + Unpin, B: Buf, { write_buf(self, src) @@ -608,6 +614,315 @@ cfg_io_util! { /// } /// ``` fn write_i128(&mut self, n: i128) -> WriteI128; + + + /// Writes an unsigned 16-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u16_le(&mut self, n: u16) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 16-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u16_le(517).await?; + /// writer.write_u16_le(768).await?; + /// + /// assert_eq!(writer, b"\x05\x02\x00\x03"); + /// Ok(()) + /// } + /// ``` + fn write_u16_le(&mut self, n: u16) -> WriteU16Le; + + /// Writes a signed 16-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i16_le(&mut self, n: i16) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 16-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i16_le(193).await?; + /// writer.write_i16_le(-132).await?; + /// + /// assert_eq!(writer, b"\xc1\x00\x7c\xff"); + /// Ok(()) + /// } + /// ``` + fn write_i16_le(&mut self, n: i16) -> WriteI16Le; + + /// Writes an unsigned 32-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u32_le(&mut self, n: u32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 32-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u32_le(267).await?; + /// writer.write_u32_le(1205419366).await?; + /// + /// assert_eq!(writer, b"\x0b\x01\x00\x00\x66\x3d\xd9\x47"); + /// Ok(()) + /// } + /// ``` + fn write_u32_le(&mut self, n: u32) -> WriteU32Le; + + /// Writes a signed 32-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i32_le(&mut self, n: i32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 32-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i32_le(267).await?; + /// writer.write_i32_le(1205419366).await?; + /// + /// assert_eq!(writer, b"\x0b\x01\x00\x00\x66\x3d\xd9\x47"); + /// Ok(()) + /// } + /// ``` + fn write_i32_le(&mut self, n: i32) -> WriteI32Le; + + /// Writes an unsigned 64-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u64_le(&mut self, n: u64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 64-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u64_le(918733457491587).await?; + /// writer.write_u64_le(143).await?; + /// + /// assert_eq!(writer, b"\x83\x86\x60\x4d\x95\x43\x03\x00\x8f\x00\x00\x00\x00\x00\x00\x00"); + /// Ok(()) + /// } + /// ``` + fn write_u64_le(&mut self, n: u64) -> WriteU64Le; + + /// Writes an signed 64-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i64_le(&mut self, n: i64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 64-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i64_le(i64::min_value()).await?; + /// writer.write_i64_le(i64::max_value()).await?; + /// + /// assert_eq!(writer, b"\x00\x00\x00\x00\x00\x00\x00\x80\xff\xff\xff\xff\xff\xff\xff\x7f"); + /// Ok(()) + /// } + /// ``` + fn write_i64_le(&mut self, n: i64) -> WriteI64Le; + + /// Writes an unsigned 128-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u128_le(&mut self, n: u128) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 128-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u128_le(16947640962301618749969007319746179).await?; + /// + /// assert_eq!(writer, vec![ + /// 0x83, 0x86, 0x60, 0x4d, 0x95, 0x43, 0x03, 0x00, + /// 0x83, 0x86, 0x60, 0x4d, 0x95, 0x43, 0x03, 0x00, + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_u128_le(&mut self, n: u128) -> WriteU128Le; + + /// Writes an signed 128-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i128_le(&mut self, n: i128) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 128-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i128_le(i128::min_value()).await?; + /// + /// assert_eq!(writer, vec![ + /// 0, 0, 0, 0, 0, 0, 0, + /// 0, 0, 0, 0, 0, 0, 0, 0, 0x80 + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_i128_le(&mut self, n: i128) -> WriteI128Le; } /// Flushes this output stream, ensuring that all intermediately buffered @@ -661,6 +976,8 @@ cfg_io_util! { /// no longer attempt to write to the stream. For example, the /// `TcpStream` implementation will issue a `shutdown(Write)` sys call. /// + /// [`flush`]: fn@crate::io::AsyncWriteExt::flush + /// /// # Examples /// /// ```no_run diff --git a/tokio/src/io/util/buf_reader.rs b/tokio/src/io/util/buf_reader.rs index 0177c0e344a..a1c5990a644 100644 --- a/tokio/src/io/util/buf_reader.rs +++ b/tokio/src/io/util/buf_reader.rs @@ -1,6 +1,7 @@ use crate::io::util::DEFAULT_BUF_SIZE; use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use bytes::Buf; use pin_project_lite::pin_project; use std::io::{self, Read}; use std::mem::MaybeUninit; @@ -82,7 +83,7 @@ impl BufReader { self.project().inner } - /// Consumes this `BufWriter`, returning the underlying reader. + /// Consumes this `BufReader`, returning the underlying reader. /// /// Note that any leftover data in the internal buffer is lost. pub fn into_inner(self) -> R { @@ -162,6 +163,14 @@ impl AsyncWrite for BufReader { self.get_pin_mut().poll_write(cx, buf) } + fn poll_write_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> { + self.get_pin_mut().poll_write_buf(cx, buf) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.get_pin_mut().poll_flush(cx) } diff --git a/tokio/src/io/util/chain.rs b/tokio/src/io/util/chain.rs index bc76af341da..8ba9194f5de 100644 --- a/tokio/src/io/util/chain.rs +++ b/tokio/src/io/util/chain.rs @@ -84,6 +84,15 @@ where T: AsyncRead, U: AsyncRead, { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { + if self.first.prepare_uninitialized_buffer(buf) { + return true; + } + if self.second.prepare_uninitialized_buffer(buf) { + return true; + } + false + } fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/tokio/src/io/util/copy.rs b/tokio/src/io/util/copy.rs index 8e0058c1c2d..7bfe296941e 100644 --- a/tokio/src/io/util/copy.rs +++ b/tokio/src/io/util/copy.rs @@ -70,7 +70,7 @@ cfg_io_util! { amt: 0, pos: 0, cap: 0, - buf: Box::new([0; 2048]), + buf: vec![0; 2048].into_boxed_slice(), } } } diff --git a/tokio/src/io/util/empty.rs b/tokio/src/io/util/empty.rs index 121102c78f2..576058d52d1 100644 --- a/tokio/src/io/util/empty.rs +++ b/tokio/src/io/util/empty.rs @@ -47,6 +47,9 @@ cfg_io_util! { } impl AsyncRead for Empty { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + false + } #[inline] fn poll_read( self: Pin<&mut Self>, diff --git a/tokio/src/io/util/flush.rs b/tokio/src/io/util/flush.rs index 1465f304486..534a5160c1a 100644 --- a/tokio/src/io/util/flush.rs +++ b/tokio/src/io/util/flush.rs @@ -8,7 +8,8 @@ use std::task::{Context, Poll}; cfg_io_util! { /// A future used to fully flush an I/O object. /// - /// Created by the [`AsyncWriteExt::flush`] function. + /// Created by the [`AsyncWriteExt::flush`][flush] function. + /// [flush]: crate::io::AsyncWriteExt::flush #[derive(Debug)] pub struct Flush<'a, A: ?Sized> { a: &'a mut A, diff --git a/tokio/src/io/util/lines.rs b/tokio/src/io/util/lines.rs index be4e86648d8..ee27400c9de 100644 --- a/tokio/src/io/util/lines.rs +++ b/tokio/src/io/util/lines.rs @@ -91,6 +91,7 @@ where let me = self.project(); let n = ready!(read_line_internal(me.reader, cx, me.buf, me.bytes, me.read))?; + debug_assert_eq!(*me.read, 0); if n == 0 && me.buf.is_empty() { return Poll::Ready(Ok(None)); diff --git a/tokio/src/io/util/mem.rs b/tokio/src/io/util/mem.rs new file mode 100644 index 00000000000..02ba6aa7e91 --- /dev/null +++ b/tokio/src/io/util/mem.rs @@ -0,0 +1,222 @@ +//! In-process memory IO types. + +use crate::io::{AsyncRead, AsyncWrite}; +use crate::loom::sync::Mutex; + +use bytes::{Buf, BytesMut}; +use std::{ + pin::Pin, + sync::Arc, + task::{self, Poll, Waker}, +}; + +/// A bidirectional pipe to read and write bytes in memory. +/// +/// A pair of `DuplexStream`s are created together, and they act as a "channel" +/// that can be used as in-memory IO types. Writing to one of the pairs will +/// allow that data to be read from the other, and vice versa. +/// +/// # Example +/// +/// ``` +/// # async fn ex() -> std::io::Result<()> { +/// # use tokio::io::{AsyncReadExt, AsyncWriteExt}; +/// let (mut client, mut server) = tokio::io::duplex(64); +/// +/// client.write_all(b"ping").await?; +/// +/// let mut buf = [0u8; 4]; +/// server.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"ping"); +/// +/// server.write_all(b"pong").await?; +/// +/// client.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"pong"); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct DuplexStream { + read: Arc>, + write: Arc>, +} + +/// A unidirectional IO over a piece of memory. +/// +/// Data can be written to the pipe, and reading will return that data. +#[derive(Debug)] +struct Pipe { + /// The buffer storing the bytes written, also read from. + /// + /// Using a `BytesMut` because it has efficient `Buf` and `BufMut` + /// functionality already. Additionally, it can try to copy data in the + /// same buffer if there read index has advanced far enough. + buffer: BytesMut, + /// Determines if the write side has been closed. + is_closed: bool, + /// The maximum amount of bytes that can be written before returning + /// `Poll::Pending`. + max_buf_size: usize, + /// If the `read` side has been polled and is pending, this is the waker + /// for that parked task. + read_waker: Option, + /// If the `write` side has filled the `max_buf_size` and returned + /// `Poll::Pending`, this is the waker for that parked task. + write_waker: Option, +} + +// ===== impl DuplexStream ===== + +/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets. +/// +/// The `max_buf_size` argument is the maximum amount of bytes that can be +/// written to a side before the write returns `Poll::Pending`. +pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) { + let one = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + let two = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + + ( + DuplexStream { + read: one.clone(), + write: two.clone(), + }, + DuplexStream { + read: two, + write: one, + }, + ) +} + +impl AsyncRead for DuplexStream { + // Previous rustc required this `self` to be `mut`, even though newer + // versions recognize it isn't needed to call `lock()`. So for + // compatibility, we include the `mut` and `allow` the lint. + // + // See https://github.com/rust-lang/rust/issues/73592 + #[allow(unused_mut)] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut *self.read.lock().unwrap()).poll_read(cx, buf) + } +} + +impl AsyncWrite for DuplexStream { + #[allow(unused_mut)] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut *self.write.lock().unwrap()).poll_write(cx, buf) + } + + #[allow(unused_mut)] + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll> { + Pin::new(&mut *self.write.lock().unwrap()).poll_flush(cx) + } + + #[allow(unused_mut)] + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll> { + Pin::new(&mut *self.write.lock().unwrap()).poll_shutdown(cx) + } +} + +impl Drop for DuplexStream { + fn drop(&mut self) { + // notify the other side of the closure + self.write.lock().unwrap().close(); + } +} + +// ===== impl Pipe ===== + +impl Pipe { + fn new(max_buf_size: usize) -> Self { + Pipe { + buffer: BytesMut::new(), + is_closed: false, + max_buf_size, + read_waker: None, + write_waker: None, + } + } + + fn close(&mut self) { + self.is_closed = true; + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + } +} + +impl AsyncRead for Pipe { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + if self.buffer.has_remaining() { + let max = self.buffer.remaining().min(buf.len()); + self.buffer.copy_to_slice(&mut buf[..max]); + if max > 0 { + // The passed `buf` might have been empty, don't wake up if + // no bytes have been moved. + if let Some(waker) = self.write_waker.take() { + waker.wake(); + } + } + Poll::Ready(Ok(max)) + } else if self.is_closed { + Poll::Ready(Ok(0)) + } else { + self.read_waker = Some(cx.waker().clone()); + Poll::Pending + } + } +} + +impl AsyncWrite for Pipe { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + if self.is_closed { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + let avail = self.max_buf_size - self.buffer.len(); + if avail == 0 { + self.write_waker = Some(cx.waker().clone()); + return Poll::Pending; + } + + let len = buf.len().min(avail); + self.buffer.extend_from_slice(&buf[..len]); + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + Poll::Ready(Ok(len)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _: &mut task::Context<'_>, + ) -> Poll> { + self.close(); + Poll::Ready(Ok(())) + } +} diff --git a/tokio/src/io/util/mod.rs b/tokio/src/io/util/mod.rs index c4754abf051..609ff2386a6 100644 --- a/tokio/src/io/util/mod.rs +++ b/tokio/src/io/util/mod.rs @@ -35,6 +35,9 @@ cfg_io_util! { mod lines; pub use lines::Lines; + mod mem; + pub use mem::{duplex, DuplexStream}; + mod read; mod read_buf; mod read_exact; diff --git a/tokio/src/io/util/read_buf.rs b/tokio/src/io/util/read_buf.rs index 550499b9334..6ee3d249f82 100644 --- a/tokio/src/io/util/read_buf.rs +++ b/tokio/src/io/util/read_buf.rs @@ -8,14 +8,14 @@ use std::task::{Context, Poll}; pub(crate) fn read_buf<'a, R, B>(reader: &'a mut R, buf: &'a mut B) -> ReadBuf<'a, R, B> where - R: AsyncRead, + R: AsyncRead + Unpin, B: BufMut, { ReadBuf { reader, buf } } cfg_io_util! { - /// Future returned by [`read_buf`](AsyncReadExt::read_buf). + /// Future returned by [`read_buf`](crate::io::AsyncReadExt::read_buf). #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadBuf<'a, R, B> { @@ -26,16 +26,13 @@ cfg_io_util! { impl Future for ReadBuf<'_, R, B> where - R: AsyncRead, + R: AsyncRead + Unpin, B: BufMut, { type Output = io::Result; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // safety: no data is moved from self - unsafe { - let me = self.get_unchecked_mut(); - Pin::new_unchecked(&mut *me.reader).poll_read_buf(cx, &mut me.buf) - } + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = &mut *self; + Pin::new(&mut *me.reader).poll_read_buf(cx, me.buf) } } diff --git a/tokio/src/io/util/read_exact.rs b/tokio/src/io/util/read_exact.rs index d6983c99530..86b8412954b 100644 --- a/tokio/src/io/util/read_exact.rs +++ b/tokio/src/io/util/read_exact.rs @@ -9,7 +9,8 @@ use std::task::{Context, Poll}; /// A future which can be used to easily read exactly enough bytes to fill /// a buffer. /// -/// Created by the [`AsyncRead::read_exact`]. +/// Created by the [`AsyncReadExt::read_exact`][read_exact]. +/// [read_exact]: [crate::io::AsyncReadExt::read_exact] pub(crate) fn read_exact<'a, A>(reader: &'a mut A, buf: &'a mut [u8]) -> ReadExact<'a, A> where A: AsyncRead + Unpin + ?Sized, diff --git a/tokio/src/io/util/read_int.rs b/tokio/src/io/util/read_int.rs index 9dc4402f88f..9d37dc7a400 100644 --- a/tokio/src/io/util/read_int.rs +++ b/tokio/src/io/util/read_int.rs @@ -121,3 +121,13 @@ reader!(ReadI16, i16, get_i16); reader!(ReadI32, i32, get_i32); reader!(ReadI64, i64, get_i64); reader!(ReadI128, i128, get_i128); + +reader!(ReadU16Le, u16, get_u16_le); +reader!(ReadU32Le, u32, get_u32_le); +reader!(ReadU64Le, u64, get_u64_le); +reader!(ReadU128Le, u128, get_u128_le); + +reader!(ReadI16Le, i16, get_i16_le); +reader!(ReadI32Le, i32, get_i32_le); +reader!(ReadI64Le, i64, get_i64_le); +reader!(ReadI128Le, i128, get_i128_le); diff --git a/tokio/src/io/util/read_line.rs b/tokio/src/io/util/read_line.rs index c5ee597486f..d625a76b80a 100644 --- a/tokio/src/io/util/read_line.rs +++ b/tokio/src/io/util/read_line.rs @@ -5,7 +5,6 @@ use std::future::Future; use std::io; use std::mem; use std::pin::Pin; -use std::str; use std::task::{Context, Poll}; cfg_io_util! { @@ -14,45 +13,72 @@ cfg_io_util! { #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadLine<'a, R: ?Sized> { reader: &'a mut R, - buf: &'a mut String, - bytes: Vec, + /// This is the buffer we were provided. It will be replaced with an empty string + /// while reading to postpone utf-8 handling until after reading. + output: &'a mut String, + /// The actual allocation of the string is moved into a vector instead. + buf: Vec, + /// The number of bytes appended to buf. This can be less than buf.len() if + /// the buffer was not empty when the operation was started. read: usize, } } -pub(crate) fn read_line<'a, R>(reader: &'a mut R, buf: &'a mut String) -> ReadLine<'a, R> +pub(crate) fn read_line<'a, R>(reader: &'a mut R, string: &'a mut String) -> ReadLine<'a, R> where R: AsyncBufRead + ?Sized + Unpin, { ReadLine { reader, - bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) }, - buf, + buf: mem::replace(string, String::new()).into_bytes(), + output: string, read: 0, } } +fn put_back_original_data(output: &mut String, mut vector: Vec, num_bytes_read: usize) { + let original_len = vector.len() - num_bytes_read; + vector.truncate(original_len); + *output = String::from_utf8(vector).expect("The original data must be valid utf-8."); +} + pub(super) fn read_line_internal( reader: Pin<&mut R>, cx: &mut Context<'_>, - buf: &mut String, - bytes: &mut Vec, + output: &mut String, + buf: &mut Vec, read: &mut usize, ) -> Poll> { - let ret = ready!(read_until_internal(reader, cx, b'\n', bytes, read)); - if str::from_utf8(&bytes).is_err() { - Poll::Ready(ret.and_then(|_| { - Err(io::Error::new( + let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read)); + let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); + + // At this point both buf and output are empty. The allocation is in utf8_res. + + debug_assert!(buf.is_empty()); + match (io_res, utf8_res) { + (Ok(num_bytes), Ok(string)) => { + debug_assert_eq!(*read, 0); + *output = string; + Poll::Ready(Ok(num_bytes)) + } + (Err(io_err), Ok(string)) => { + *output = string; + Poll::Ready(Err(io_err)) + } + (Ok(num_bytes), Err(utf8_err)) => { + debug_assert_eq!(*read, 0); + put_back_original_data(output, utf8_err.into_bytes(), num_bytes); + + Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, "stream did not contain valid UTF-8", - )) - })) - } else { - debug_assert!(buf.is_empty()); - debug_assert_eq!(*read, 0); - // Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`. - mem::swap(unsafe { buf.as_mut_vec() }, bytes); - Poll::Ready(ret) + ))) + } + (Err(io_err), Err(utf8_err)) => { + put_back_original_data(output, utf8_err.into_bytes(), *read); + + Poll::Ready(Err(io_err)) + } } } @@ -62,11 +88,12 @@ impl Future for ReadLine<'_, R> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let Self { reader, + output, buf, - bytes, read, } = &mut *self; - read_line_internal(Pin::new(reader), cx, buf, bytes, read) + + read_line_internal(Pin::new(reader), cx, output, buf, read) } } diff --git a/tokio/src/io/util/read_until.rs b/tokio/src/io/util/read_until.rs index 1adeda66f05..78dac8c2a14 100644 --- a/tokio/src/io/util/read_until.rs +++ b/tokio/src/io/util/read_until.rs @@ -8,19 +8,22 @@ use std::task::{Context, Poll}; cfg_io_util! { /// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method. + /// The delimeter is included in the resulting vector. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadUntil<'a, R: ?Sized> { reader: &'a mut R, - byte: u8, + delimeter: u8, buf: &'a mut Vec, + /// The number of bytes appended to buf. This can be less than buf.len() if + /// the buffer was not empty when the operation was started. read: usize, } } pub(crate) fn read_until<'a, R>( reader: &'a mut R, - byte: u8, + delimeter: u8, buf: &'a mut Vec, ) -> ReadUntil<'a, R> where @@ -28,7 +31,7 @@ where { ReadUntil { reader, - byte, + delimeter, buf, read: 0, } @@ -37,14 +40,14 @@ where pub(super) fn read_until_internal( mut reader: Pin<&mut R>, cx: &mut Context<'_>, - byte: u8, + delimeter: u8, buf: &mut Vec, read: &mut usize, ) -> Poll> { loop { let (done, used) = { let available = ready!(reader.as_mut().poll_fill_buf(cx))?; - if let Some(i) = memchr::memchr(byte, available) { + if let Some(i) = memchr::memchr(delimeter, available) { buf.extend_from_slice(&available[..=i]); (true, i + 1) } else { @@ -66,11 +69,11 @@ impl Future for ReadUntil<'_, R> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let Self { reader, - byte, + delimeter, buf, read, } = &mut *self; - read_until_internal(Pin::new(reader), cx, *byte, buf, read) + read_until_internal(Pin::new(reader), cx, *delimeter, buf, read) } } diff --git a/tokio/src/io/util/repeat.rs b/tokio/src/io/util/repeat.rs index 6b9067e8534..eeef7cc187b 100644 --- a/tokio/src/io/util/repeat.rs +++ b/tokio/src/io/util/repeat.rs @@ -47,6 +47,9 @@ cfg_io_util! { } impl AsyncRead for Repeat { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + false + } #[inline] fn poll_read( self: Pin<&mut Self>, diff --git a/tokio/src/io/util/shutdown.rs b/tokio/src/io/util/shutdown.rs index f24e2885414..33ac0ac0db7 100644 --- a/tokio/src/io/util/shutdown.rs +++ b/tokio/src/io/util/shutdown.rs @@ -8,7 +8,8 @@ use std::task::{Context, Poll}; cfg_io_util! { /// A future used to shutdown an I/O object. /// - /// Created by the [`AsyncWriteExt::shutdown`] function. + /// Created by the [`AsyncWriteExt::shutdown`][shutdown] function. + /// [shutdown]: crate::io::AsyncWriteExt::shutdown #[derive(Debug)] pub struct Shutdown<'a, A: ?Sized> { a: &'a mut A, diff --git a/tokio/src/io/util/split.rs b/tokio/src/io/util/split.rs index f1ed2fd89d3..f552ed503d1 100644 --- a/tokio/src/io/util/split.rs +++ b/tokio/src/io/util/split.rs @@ -75,6 +75,8 @@ where let n = ready!(read_until_internal( me.reader, cx, *me.delim, me.buf, me.read, ))?; + // read_until_internal resets me.read to zero once it finds the delimeter + debug_assert_eq!(*me.read, 0); if n == 0 && me.buf.is_empty() { return Poll::Ready(Ok(None)); diff --git a/tokio/src/io/util/write_buf.rs b/tokio/src/io/util/write_buf.rs index e49282fe0c4..cedfde64e6e 100644 --- a/tokio/src/io/util/write_buf.rs +++ b/tokio/src/io/util/write_buf.rs @@ -20,7 +20,7 @@ cfg_io_util! { /// asynchronous manner, returning a future. pub(crate) fn write_buf<'a, W, B>(writer: &'a mut W, buf: &'a mut B) -> WriteBuf<'a, W, B> where - W: AsyncWrite, + W: AsyncWrite + Unpin, B: Buf, { WriteBuf { writer, buf } @@ -28,16 +28,13 @@ where impl Future for WriteBuf<'_, W, B> where - W: AsyncWrite, + W: AsyncWrite + Unpin, B: Buf, { type Output = io::Result; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // safety: no data is moved from self - unsafe { - let me = self.get_unchecked_mut(); - Pin::new_unchecked(&mut *me.writer).poll_write_buf(cx, &mut me.buf) - } + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = &mut *self; + Pin::new(&mut *me.writer).poll_write_buf(cx, me.buf) } } diff --git a/tokio/src/io/util/write_int.rs b/tokio/src/io/util/write_int.rs index 672c35f0768..ee992de1832 100644 --- a/tokio/src/io/util/write_int.rs +++ b/tokio/src/io/util/write_int.rs @@ -120,3 +120,13 @@ writer!(WriteI16, i16, put_i16); writer!(WriteI32, i32, put_i32); writer!(WriteI64, i64, put_i64); writer!(WriteI128, i128, put_i128); + +writer!(WriteU16Le, u16, put_u16_le); +writer!(WriteU32Le, u32, put_u32_le); +writer!(WriteU64Le, u64, put_u64_le); +writer!(WriteU128Le, u128, put_u128_le); + +writer!(WriteI16Le, i16, put_i16_le); +writer!(WriteI32Le, i32, put_i32_le); +writer!(WriteI64Le, i64, put_i64_le); +writer!(WriteI128Le, i128, put_i128_le); diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs index 04258f3d6b6..88707c4d1c6 100644 --- a/tokio/src/lib.rs +++ b/tokio/src/lib.rs @@ -1,4 +1,4 @@ -#![doc(html_root_url = "https://docs.rs/tokio/0.2.20")] +#![doc(html_root_url = "https://docs.rs/tokio/0.2.22")] #![allow( clippy::cognitive_complexity, clippy::large_enum_variant, @@ -16,6 +16,7 @@ attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) ))] #![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(docsrs, feature(doc_alias))] //! A runtime for writing reliable, asynchronous, and slim applications. //! @@ -44,7 +45,7 @@ //! [signal]: crate::signal //! [fs]: crate::fs //! [runtime]: crate::runtime -//! [website]: https://tokio.rs/docs/overview/ +//! [website]: https://tokio.rs/tokio/tutorial //! //! # A Tour of Tokio //! @@ -228,7 +229,7 @@ //! on the number of blocking threads is very large. These limits can be //! configured on the [`Builder`]. //! -//! Two spawn a blocking task, you should use the [`spawn_blocking`] function. +//! To spawn a blocking task, you should use the [`spawn_blocking`] function. //! //! [`Builder`]: crate::runtime::Builder //! [`spawn_blocking`]: crate::task::spawn_blocking() diff --git a/tokio/src/loom/std/atomic_ptr.rs b/tokio/src/loom/std/atomic_ptr.rs index eb8e47557a2..f7fd56cc69b 100644 --- a/tokio/src/loom/std/atomic_ptr.rs +++ b/tokio/src/loom/std/atomic_ptr.rs @@ -11,10 +11,6 @@ impl AtomicPtr { let inner = std::sync::atomic::AtomicPtr::new(ptr); AtomicPtr { inner } } - - pub(crate) fn with_mut(&mut self, f: impl FnOnce(&mut *mut T) -> R) -> R { - f(self.inner.get_mut()) - } } impl Deref for AtomicPtr { diff --git a/tokio/src/loom/std/mod.rs b/tokio/src/loom/std/mod.rs index 595bdf60ed7..60ee56ad202 100644 --- a/tokio/src/loom/std/mod.rs +++ b/tokio/src/loom/std/mod.rs @@ -6,6 +6,8 @@ mod atomic_u32; mod atomic_u64; mod atomic_u8; mod atomic_usize; +#[cfg(feature = "parking_lot")] +mod parking_lot; mod unsafe_cell; pub(crate) mod cell { @@ -41,24 +43,21 @@ pub(crate) mod rand { pub(crate) mod sync { pub(crate) use std::sync::Arc; - #[cfg(feature = "parking_lot")] - mod pl_wrappers; - // Below, make sure all the feature-influenced types are exported for // internal use. Note however that some are not _currently_ named by // consuming code. #[cfg(feature = "parking_lot")] #[allow(unused_imports)] - pub(crate) use pl_wrappers::{Condvar, Mutex}; - - #[cfg(feature = "parking_lot")] - #[allow(unused_imports)] - pub(crate) use parking_lot::{MutexGuard, WaitTimeoutResult}; + pub(crate) use crate::loom::std::parking_lot::{ + Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult, + }; #[cfg(not(feature = "parking_lot"))] #[allow(unused_imports)] - pub(crate) use std::sync::{Condvar, Mutex, MutexGuard, WaitTimeoutResult}; + pub(crate) use std::sync::{ + Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult, + }; pub(crate) mod atomic { pub(crate) use crate::loom::std::atomic_ptr::AtomicPtr; diff --git a/tokio/src/loom/std/sync/pl_wrappers.rs b/tokio/src/loom/std/parking_lot.rs similarity index 59% rename from tokio/src/loom/std/sync/pl_wrappers.rs rename to tokio/src/loom/std/parking_lot.rs index 3be8ba1c108..25d94af44f5 100644 --- a/tokio/src/loom/std/sync/pl_wrappers.rs +++ b/tokio/src/loom/std/parking_lot.rs @@ -6,25 +6,33 @@ use std::sync::{LockResult, TryLockError, TryLockResult}; use std::time::Duration; -use parking_lot as pl; +// Types that do not need wrapping +pub(crate) use parking_lot::{MutexGuard, RwLockReadGuard, RwLockWriteGuard, WaitTimeoutResult}; /// Adapter for `parking_lot::Mutex` to the `std::sync::Mutex` interface. #[derive(Debug)] -pub(crate) struct Mutex(pl::Mutex); +pub(crate) struct Mutex(parking_lot::Mutex); + +#[derive(Debug)] +pub(crate) struct RwLock(parking_lot::RwLock); + +/// Adapter for `parking_lot::Condvar` to the `std::sync::Condvar` interface. +#[derive(Debug)] +pub(crate) struct Condvar(parking_lot::Condvar); impl Mutex { #[inline] pub(crate) fn new(t: T) -> Mutex { - Mutex(pl::Mutex::new(t)) + Mutex(parking_lot::Mutex::new(t)) } #[inline] - pub(crate) fn lock(&self) -> LockResult> { + pub(crate) fn lock(&self) -> LockResult> { Ok(self.0.lock()) } #[inline] - pub(crate) fn try_lock(&self) -> TryLockResult> { + pub(crate) fn try_lock(&self) -> TryLockResult> { match self.0.try_lock() { Some(guard) => Ok(guard), None => Err(TryLockError::WouldBlock), @@ -35,14 +43,24 @@ impl Mutex { // provided here as needed. } -/// Adapter for `parking_lot::Condvar` to the `std::sync::Condvar` interface. -#[derive(Debug)] -pub(crate) struct Condvar(pl::Condvar); +impl RwLock { + pub(crate) fn new(t: T) -> RwLock { + RwLock(parking_lot::RwLock::new(t)) + } + + pub(crate) fn read(&self) -> LockResult> { + Ok(self.0.read()) + } + + pub(crate) fn write(&self) -> LockResult> { + Ok(self.0.write()) + } +} impl Condvar { #[inline] pub(crate) fn new() -> Condvar { - Condvar(pl::Condvar::new()) + Condvar(parking_lot::Condvar::new()) } #[inline] @@ -58,8 +76,8 @@ impl Condvar { #[inline] pub(crate) fn wait<'a, T>( &self, - mut guard: pl::MutexGuard<'a, T>, - ) -> LockResult> { + mut guard: MutexGuard<'a, T>, + ) -> LockResult> { self.0.wait(&mut guard); Ok(guard) } @@ -67,9 +85,9 @@ impl Condvar { #[inline] pub(crate) fn wait_timeout<'a, T>( &self, - mut guard: pl::MutexGuard<'a, T>, + mut guard: MutexGuard<'a, T>, timeout: Duration, - ) -> LockResult<(pl::MutexGuard<'a, T>, pl::WaitTimeoutResult)> { + ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)> { let wtr = self.0.wait_for(&mut guard, timeout); Ok((guard, wtr)) } diff --git a/tokio/src/macros/cfg.rs b/tokio/src/macros/cfg.rs index 85f95cbd3d4..4b77544eb5c 100644 --- a/tokio/src/macros/cfg.rs +++ b/tokio/src/macros/cfg.rs @@ -353,3 +353,52 @@ macro_rules! cfg_uds { )* } } + +macro_rules! cfg_unstable { + ($($item:item)*) => { + $( + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + $item + )* + } +} + +macro_rules! cfg_trace { + ($($item:item)*) => { + $( + #[cfg(feature = "tracing")] + #[cfg_attr(docsrs, doc(cfg(feature = "tracing")))] + $item + )* + } +} + +macro_rules! cfg_not_trace { + ($($item:item)*) => { + $( + #[cfg(not(feature = "tracing"))] + $item + )* + } +} + +macro_rules! cfg_coop { + ($($item:item)*) => { + $( + #[cfg(any( + feature = "blocking", + feature = "dns", + feature = "fs", + feature = "io-driver", + feature = "io-std", + feature = "process", + feature = "rt-core", + feature = "sync", + feature = "stream", + feature = "time" + ))] + $item + )* + } +} diff --git a/tokio/src/macros/pin.rs b/tokio/src/macros/pin.rs index 33d8499e10d..ed844ef7d11 100644 --- a/tokio/src/macros/pin.rs +++ b/tokio/src/macros/pin.rs @@ -43,7 +43,7 @@ /// Pinning is useful when using `select!` and stream operators that require `T: /// Stream + Unpin`. /// -/// [`Future`]: https://doc.rust-lang.org/std/future/trait.Future.html +/// [`Future`]: trait@std::future::Future /// [`Box::pin`]: # /// /// # Usage diff --git a/tokio/src/macros/select.rs b/tokio/src/macros/select.rs index 51b6fcd608c..52c8fdd3404 100644 --- a/tokio/src/macros/select.rs +++ b/tokio/src/macros/select.rs @@ -30,7 +30,7 @@ /// /// The complete lifecycle of a `select!` expression is as follows: /// -/// 1. Evaluate all provded `` expressions. If the precondition +/// 1. Evaluate all provided `` expressions. If the precondition /// returns `false`, disable the branch for the remainder of the current call /// to `select!`. Re-entering `select!` due to a loop clears the "disabled" /// state. @@ -359,8 +359,11 @@ macro_rules! select { let start = $crate::macros::support::thread_rng_n(BRANCHES); for i in 0..BRANCHES { - let branch = (start + i) % BRANCHES; - + let branch; + #[allow(clippy::modulo_one)] + { + branch = (start + i) % BRANCHES; + } match branch { $( $crate::count!( $($skip)* ) => { diff --git a/tokio/src/net/addr.rs b/tokio/src/net/addr.rs index 343d4e21ff2..5ba898a15a7 100644 --- a/tokio/src/net/addr.rs +++ b/tokio/src/net/addr.rs @@ -18,7 +18,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV /// conversion directly, use [`lookup_host()`](super::lookup_host()). /// /// This trait is sealed and is intended to be opaque. The details of the trait -/// will change. Stabilization is pending enhancements to the Rust langague. +/// will change. Stabilization is pending enhancements to the Rust language. pub trait ToSocketAddrs: sealed::ToSocketAddrsPriv {} type ReadyFuture = future::Ready>; @@ -121,6 +121,20 @@ impl sealed::ToSocketAddrsPriv for (Ipv6Addr, u16) { } } +// ===== impl &[SocketAddr] ===== + +impl ToSocketAddrs for &[SocketAddr] {} + +impl sealed::ToSocketAddrsPriv for &[SocketAddr] { + type Iter = std::vec::IntoIter; + type Future = ReadyFuture; + + fn to_socket_addrs(&self) -> Self::Future { + let iter = self.to_vec().into_iter(); + future::ok(iter) + } +} + cfg_dns! { // ===== impl str ===== diff --git a/tokio/src/net/mod.rs b/tokio/src/net/mod.rs index eb24ac0ba57..da6ad1fc4a3 100644 --- a/tokio/src/net/mod.rs +++ b/tokio/src/net/mod.rs @@ -43,7 +43,7 @@ cfg_udp! { cfg_uds! { pub mod unix; - pub use unix::datagram::UnixDatagram; + pub use unix::datagram::socket::UnixDatagram; pub use unix::listener::UnixListener; pub use unix::stream::UnixStream; } diff --git a/tokio/src/net/tcp/listener.rs b/tokio/src/net/tcp/listener.rs index 262e0e1d51f..fd79b259b92 100644 --- a/tokio/src/net/tcp/listener.rs +++ b/tokio/src/net/tcp/listener.rs @@ -72,7 +72,7 @@ cfg_tcp! { } impl TcpListener { - /// Creates a new TcpListener which will be bound to the specified address. + /// Creates a new TcpListener, which will be bound to the specified address. /// /// The returned listener is ready for accepting connections. /// @@ -80,7 +80,9 @@ impl TcpListener { /// to this listener. The port allocated can be queried via the `local_addr` /// method. /// - /// The address type can be any implementor of `ToSocketAddrs` trait. + /// The address type can be any implementor of the [`ToSocketAddrs`] trait. + /// Note that strings only implement this trait when the **`dns`** feature + /// is enabled, as strings may contain domain names that need to be resolved. /// /// If `addr` yields multiple addresses, bind will be attempted with each of /// the addresses until one succeeds and returns the listener. If none of @@ -89,6 +91,8 @@ impl TcpListener { /// /// This function sets the `SO_REUSEADDR` option on the socket. /// + /// [`ToSocketAddrs`]: trait@crate::net::ToSocketAddrs + /// /// # Examples /// /// ```no_run @@ -98,7 +102,26 @@ impl TcpListener { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let listener = TcpListener::bind("127.0.0.1:0").await?; + /// let listener = TcpListener::bind("127.0.0.1:2345").await?; + /// + /// // use the listener + /// + /// # let _ = listener; + /// Ok(()) + /// } + /// ``` + /// + /// Without the `dns` feature: + /// + /// ```no_run + /// use tokio::net::TcpListener; + /// use std::net::Ipv4Addr; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let listener = TcpListener::bind((Ipv4Addr::new(127, 0, 0, 1), 2345)).await?; /// /// // use the listener /// @@ -322,7 +345,7 @@ impl TcpListener { /// /// For more information about this option, see [`set_ttl`]. /// - /// [`set_ttl`]: #method.set_ttl + /// [`set_ttl`]: method@Self::set_ttl /// /// # Examples /// diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs index 469056acc5b..0c1e359f72d 100644 --- a/tokio/src/net/tcp/split.rs +++ b/tokio/src/net/tcp/split.rs @@ -19,7 +19,7 @@ use std::net::Shutdown; use std::pin::Pin; use std::task::{Context, Poll}; -/// Read half of a [`TcpStream`], created by [`split`]. +/// Borrowed read half of a [`TcpStream`], created by [`split`]. /// /// Reading from a `ReadHalf` is usually done using the convenience methods found on the /// [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. @@ -31,12 +31,12 @@ use std::task::{Context, Poll}; #[derive(Debug)] pub struct ReadHalf<'a>(&'a TcpStream); -/// Write half of a [`TcpStream`], created by [`split`]. +/// Borrowed write half of a [`TcpStream`], created by [`split`]. /// /// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will /// shut down the TCP stream in the write direction. /// -/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found +/// Writing to an `WriteHalf` is usually done using the convenience methods found /// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude]. /// /// [`TcpStream`]: TcpStream diff --git a/tokio/src/net/tcp/split_owned.rs b/tokio/src/net/tcp/split_owned.rs index ff82f6eda2b..6c2b9e6977e 100644 --- a/tokio/src/net/tcp/split_owned.rs +++ b/tokio/src/net/tcp/split_owned.rs @@ -37,10 +37,9 @@ pub struct OwnedReadHalf { /// Owned write half of a [`TcpStream`], created by [`into_split`]. /// -/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will -/// shut down the TCP stream in the write direction. -/// -/// Dropping the write half will close the TCP stream in both directions. +/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will +/// shut down the TCP stream in the write direction. Dropping the write half +/// will also shut down the write half of the TCP stream. /// /// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found /// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude]. @@ -77,13 +76,13 @@ pub(crate) fn reunite( write.forget(); // This unwrap cannot fail as the api does not allow creating more than two Arcs, // and we just dropped the other half. - Ok(Arc::try_unwrap(read.inner).expect("Too many handles to Arc")) + Ok(Arc::try_unwrap(read.inner).expect("TcpStream: try_unwrap failed in reunite")) } else { Err(ReuniteError(read, write)) } } -/// Error indicating two halves were not from the same socket, and thus could +/// Error indicating that two halves were not from the same socket, and thus could /// not be reunited. #[derive(Debug)] pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf); @@ -209,9 +208,10 @@ impl OwnedWriteHalf { pub fn reunite(self, other: OwnedReadHalf) -> Result { reunite(other, self) } - /// Destroy the write half, but don't close the stream until the read half - /// is dropped. If the read half has already been dropped, this closes the - /// stream. + + /// Destroy the write half, but don't close the write half of the stream + /// until the read half is dropped. If the read half has already been + /// dropped, this closes the stream. pub fn forget(mut self) { self.shutdown_on_drop = false; drop(self); @@ -221,7 +221,7 @@ impl OwnedWriteHalf { impl Drop for OwnedWriteHalf { fn drop(&mut self) { if self.shutdown_on_drop { - let _ = self.inner.shutdown(Shutdown::Both); + let _ = self.inner.shutdown(Shutdown::Write); } } } @@ -251,7 +251,11 @@ impl AsyncWrite for OwnedWriteHalf { // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - self.inner.shutdown(Shutdown::Write).into() + let res = self.inner.shutdown(Shutdown::Write); + if res.is_ok() { + Pin::into_inner(self).shutdown_on_drop = false; + } + res.into() } } diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index ee44f810c37..02b5262723e 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -63,14 +63,18 @@ cfg_tcp! { impl TcpStream { /// Opens a TCP connection to a remote host. /// - /// `addr` is an address of the remote host. Anything which implements - /// `ToSocketAddrs` trait can be supplied for the address. + /// `addr` is an address of the remote host. Anything which implements the + /// [`ToSocketAddrs`] trait can be supplied as the address. Note that + /// strings only implement this trait when the **`dns`** feature is enabled, + /// as strings may contain domain names that need to be resolved. /// /// If `addr` yields multiple addresses, connect will be attempted with each /// of the addresses until a connection is successful. If none of the /// addresses result in a successful connection, the error returned from the /// last connection attempt (the last address) is returned. /// + /// [`ToSocketAddrs`]: trait@crate::net::ToSocketAddrs + /// /// # Examples /// /// ```no_run @@ -90,6 +94,26 @@ impl TcpStream { /// } /// ``` /// + /// Without the `dns` feature: + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use tokio::prelude::*; + /// use std::error::Error; + /// use std::net::Ipv4Addr; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// // Connect to a peer + /// let mut stream = TcpStream::connect((Ipv4Addr::new(127, 0, 0, 1), 8080)).await?; + /// + /// // Write some data. + /// stream.write_all(b"hello world!").await?; + /// + /// Ok(()) + /// } + /// ``` + /// /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. /// /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all @@ -635,6 +659,9 @@ impl TcpStream { self.io.get_ref().set_linger(dur) } + // These lifetime markers also appear in the generated documentation, and make + // it more clear that this is a *borrowed* split. + #[allow(clippy::needless_lifetimes)] /// Splits a `TcpStream` into a read half and a write half, which can be used /// to read and write the stream concurrently. /// @@ -642,7 +669,7 @@ impl TcpStream { /// moved into independently spawned tasks. /// /// [`into_split`]: TcpStream::into_split() - pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) { + pub fn split<'a>(&'a mut self) -> (ReadHalf<'a>, WriteHalf<'a>) { split(self) } @@ -652,7 +679,11 @@ impl TcpStream { /// Unlike [`split`], the owned halves can be moved to separate tasks, however /// this comes at the cost of a heap allocation. /// + /// **Note:** Dropping the write half will shut down the write half of the TCP + /// stream. This is equivalent to calling [`shutdown(Write)`] on the `TcpStream`. + /// /// [`split`]: TcpStream::split() + /// [`shutdown(Write)`]: fn@crate::net::TcpStream::shutdown pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { split_owned(self) } diff --git a/tokio/src/net/udp/socket.rs b/tokio/src/net/udp/socket.rs index faf1dca615a..97090a206d3 100644 --- a/tokio/src/net/udp/socket.rs +++ b/tokio/src/net/udp/socket.rs @@ -111,7 +111,7 @@ impl UdpSocket { /// The [`connect`] method will connect this socket to a remote address. The future /// will resolve to an error if the socket is not connected. /// - /// [`connect`]: #method.connect + /// [`connect`]: method@Self::connect pub async fn send(&mut self, buf: &[u8]) -> io::Result { poll_fn(|cx| self.poll_send(cx, buf)).await } @@ -150,7 +150,7 @@ impl UdpSocket { /// The [`connect`] method will connect this socket to a remote address. The future /// will fail if the socket is not connected. /// - /// [`connect`]: #method.connect + /// [`connect`]: method@Self::connect pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { poll_fn(|cx| self.poll_recv(cx, buf)).await } @@ -235,7 +235,7 @@ impl UdpSocket { /// /// For more information about this option, see [`set_broadcast`]. /// - /// [`set_broadcast`]: #method.set_broadcast + /// [`set_broadcast`]: method@Self::set_broadcast pub fn broadcast(&self) -> io::Result { self.io.get_ref().broadcast() } @@ -252,7 +252,7 @@ impl UdpSocket { /// /// For more information about this option, see [`set_multicast_loop_v4`]. /// - /// [`set_multicast_loop_v4`]: #method.set_multicast_loop_v4 + /// [`set_multicast_loop_v4`]: method@Self::set_multicast_loop_v4 pub fn multicast_loop_v4(&self) -> io::Result { self.io.get_ref().multicast_loop_v4() } @@ -272,7 +272,7 @@ impl UdpSocket { /// /// For more information about this option, see [`set_multicast_ttl_v4`]. /// - /// [`set_multicast_ttl_v4`]: #method.set_multicast_ttl_v4 + /// [`set_multicast_ttl_v4`]: method@Self::set_multicast_ttl_v4 pub fn multicast_ttl_v4(&self) -> io::Result { self.io.get_ref().multicast_ttl_v4() } @@ -294,7 +294,7 @@ impl UdpSocket { /// /// For more information about this option, see [`set_multicast_loop_v6`]. /// - /// [`set_multicast_loop_v6`]: #method.set_multicast_loop_v6 + /// [`set_multicast_loop_v6`]: method@Self::set_multicast_loop_v6 pub fn multicast_loop_v6(&self) -> io::Result { self.io.get_ref().multicast_loop_v6() } @@ -314,7 +314,7 @@ impl UdpSocket { /// /// For more information about this option, see [`set_ttl`]. /// - /// [`set_ttl`]: #method.set_ttl + /// [`set_ttl`]: method@Self::set_ttl pub fn ttl(&self) -> io::Result { self.io.get_ref().ttl() } @@ -351,7 +351,7 @@ impl UdpSocket { /// /// For more information about this option, see [`join_multicast_v4`]. /// - /// [`join_multicast_v4`]: #method.join_multicast_v4 + /// [`join_multicast_v4`]: method@Self::join_multicast_v4 pub fn leave_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> { self.io.get_ref().leave_multicast_v4(&multiaddr, &interface) } @@ -360,7 +360,7 @@ impl UdpSocket { /// /// For more information about this option, see [`join_multicast_v6`]. /// - /// [`join_multicast_v6`]: #method.join_multicast_v6 + /// [`join_multicast_v6`]: method@Self::join_multicast_v6 pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { self.io.get_ref().leave_multicast_v6(multiaddr, interface) } diff --git a/tokio/src/net/udp/split.rs b/tokio/src/net/udp/split.rs index 55542cb6316..8d87f1c7c67 100644 --- a/tokio/src/net/udp/split.rs +++ b/tokio/src/net/udp/split.rs @@ -1,6 +1,6 @@ -//! [`UdpSocket`](../struct.UdpSocket.html) split support. +//! [`UdpSocket`](crate::net::UdpSocket) split support. //! -//! The [`split`](../struct.UdpSocket.html#method.split) method splits a +//! The [`split`](method@crate::net::UdpSocket::split) method splits a //! `UdpSocket` into a receive half and a send half, which can be used to //! receive and send datagrams concurrently, even from two different tasks. //! @@ -23,14 +23,14 @@ use std::sync::Arc; /// The send half after [`split`](super::UdpSocket::split). /// -/// Use [`send_to`](#method.send_to) or [`send`](#method.send) to send +/// Use [`send_to`](method@Self::send_to) or [`send`](method@Self::send) to send /// datagrams. #[derive(Debug)] pub struct SendHalf(Arc); /// The recv half after [`split`](super::UdpSocket::split). /// -/// Use [`recv_from`](#method.recv_from) or [`recv`](#method.recv) to receive +/// Use [`recv_from`](method@Self::recv_from) or [`recv`](method@Self::recv) to receive /// datagrams. #[derive(Debug)] pub struct RecvHalf(Arc); @@ -42,7 +42,7 @@ pub(crate) fn split(socket: UdpSocket) -> (RecvHalf, SendHalf) { (RecvHalf(recv), SendHalf(send)) } -/// Error indicating two halves were not from the same socket, and thus could +/// Error indicating that two halves were not from the same socket, and thus could /// not be `reunite`d. #[derive(Debug)] pub struct ReuniteError(pub SendHalf, pub RecvHalf); diff --git a/tokio/src/net/unix/datagram/mod.rs b/tokio/src/net/unix/datagram/mod.rs new file mode 100644 index 00000000000..f484ae34a34 --- /dev/null +++ b/tokio/src/net/unix/datagram/mod.rs @@ -0,0 +1,8 @@ +//! Unix datagram types. + +pub(crate) mod socket; +pub(crate) mod split; +pub(crate) mod split_owned; + +pub use split::{RecvHalf, SendHalf}; +pub use split_owned::{OwnedRecvHalf, OwnedSendHalf, ReuniteError}; diff --git a/tokio/src/net/unix/datagram.rs b/tokio/src/net/unix/datagram/socket.rs similarity index 67% rename from tokio/src/net/unix/datagram.rs rename to tokio/src/net/unix/datagram/socket.rs index ff0f4241d5a..a332d2afb45 100644 --- a/tokio/src/net/unix/datagram.rs +++ b/tokio/src/net/unix/datagram/socket.rs @@ -1,5 +1,7 @@ use crate::future::poll_fn; use crate::io::PollEvented; +use crate::net::unix::datagram::split::{split, RecvHalf, SendHalf}; +use crate::net::unix::datagram::split_owned::{split_owned, OwnedRecvHalf, OwnedSendHalf}; use std::convert::TryFrom; use std::fmt; @@ -83,6 +85,73 @@ impl UnixDatagram { poll_fn(|cx| self.poll_send_priv(cx, buf)).await } + /// Try to send a datagram to the peer without waiting. + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// use tokio::net::UnixDatagram; + /// + /// let bytes = b"bytes"; + /// // We use a socket pair so that they are assigned + /// // each other as a peer. + /// let (mut first, mut second) = UnixDatagram::pair()?; + /// + /// let size = first.try_send(bytes)?; + /// assert_eq!(size, bytes.len()); + /// + /// let mut buffer = vec![0u8; 24]; + /// let size = second.try_recv(&mut buffer)?; + /// + /// let dgram = &buffer.as_slice()[..size]; + /// assert_eq!(dgram, bytes); + /// # Ok(()) + /// # } + /// ``` + pub fn try_send(&mut self, buf: &[u8]) -> io::Result { + self.io.get_ref().send(buf) + } + + /// Try to send a datagram to the peer without waiting. + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// use { + /// tokio::net::UnixDatagram, + /// tempfile::tempdir, + /// }; + /// + /// let bytes = b"bytes"; + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir().unwrap(); + /// + /// let server_path = tmp.path().join("server"); + /// let mut server = UnixDatagram::bind(&server_path)?; + /// + /// let client_path = tmp.path().join("client"); + /// let mut client = UnixDatagram::bind(&client_path)?; + /// + /// let size = client.try_send_to(bytes, &server_path)?; + /// assert_eq!(size, bytes.len()); + /// + /// let mut buffer = vec![0u8; 24]; + /// let (size, addr) = server.try_recv_from(&mut buffer)?; + /// + /// let dgram = &buffer.as_slice()[..size]; + /// assert_eq!(dgram, bytes); + /// assert_eq!(addr.as_pathname().unwrap(), &client_path); + /// # Ok(()) + /// # } + /// ``` + pub fn try_send_to

(&mut self, buf: &[u8], target: P) -> io::Result + where + P: AsRef, + { + self.io.get_ref().send_to(buf, target) + } + // Poll IO functions that takes `&self` are provided for the split API. // // They are not public because (taken from the doc of `PollEvented`): @@ -114,6 +183,11 @@ impl UnixDatagram { poll_fn(|cx| self.poll_recv_priv(cx, buf)).await } + /// Try to receive a datagram from the peer without waiting. + pub fn try_recv(&mut self, buf: &mut [u8]) -> io::Result { + self.io.get_ref().recv(buf) + } + pub(crate) fn poll_recv_priv( &self, cx: &mut Context<'_>, @@ -160,6 +234,11 @@ impl UnixDatagram { poll_fn(|cx| self.poll_recv_from_priv(cx, buf)).await } + /// Try to receive data from the socket without waiting. + pub fn try_recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.io.get_ref().recv_from(buf) + } + pub(crate) fn poll_recv_from_priv( &self, cx: &mut Context<'_>, @@ -201,6 +280,35 @@ impl UnixDatagram { pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { self.io.get_ref().shutdown(how) } + + // These lifetime markers also appear in the generated documentation, and make + // it more clear that this is a *borrowed* split. + #[allow(clippy::needless_lifetimes)] + /// Split a `UnixDatagram` into a receive half and a send half, which can be used + /// to receive and send the datagram concurrently. + /// + /// This method is more efficient than [`into_split`], but the halves cannot + /// be moved into independently spawned tasks. + /// + /// [`into_split`]: fn@crate::net::UnixDatagram::into_split + pub fn split<'a>(&'a mut self) -> (RecvHalf<'a>, SendHalf<'a>) { + split(self) + } + + /// Split a `UnixDatagram` into a receive half and a send half, which can be used + /// to receive and send the datagram concurrently. + /// + /// Unlike [`split`], the owned halves can be moved to separate tasks, + /// however this comes at the cost of a heap allocation. + /// + /// **Note:** Dropping the write half will shut down the write half of the + /// datagram. This is equivalent to calling [`shutdown(Write)`]. + /// + /// [`split`]: fn@crate::net::UnixDatagram::split + /// [`shutdown(Write)`]:fn@crate::net::UnixDatagram::shutdown + pub fn into_split(self) -> (OwnedRecvHalf, OwnedSendHalf) { + split_owned(self) + } } impl TryFrom for mio_uds::UnixDatagram { diff --git a/tokio/src/net/unix/datagram/split.rs b/tokio/src/net/unix/datagram/split.rs new file mode 100644 index 00000000000..e42eeda8844 --- /dev/null +++ b/tokio/src/net/unix/datagram/split.rs @@ -0,0 +1,68 @@ +//! `UnixDatagram` split support. +//! +//! A `UnixDatagram` can be split into a `RecvHalf` and a `SendHalf` with the +//! `UnixDatagram::split` method. + +use crate::future::poll_fn; +use crate::net::UnixDatagram; + +use std::io; +use std::os::unix::net::SocketAddr; +use std::path::Path; + +/// Borrowed receive half of a [`UnixDatagram`], created by [`split`]. +/// +/// [`UnixDatagram`]: UnixDatagram +/// [`split`]: crate::net::UnixDatagram::split() +#[derive(Debug)] +pub struct RecvHalf<'a>(&'a UnixDatagram); + +/// Borrowed send half of a [`UnixDatagram`], created by [`split`]. +/// +/// [`UnixDatagram`]: UnixDatagram +/// [`split`]: crate::net::UnixDatagram::split() +#[derive(Debug)] +pub struct SendHalf<'a>(&'a UnixDatagram); + +pub(crate) fn split(stream: &mut UnixDatagram) -> (RecvHalf<'_>, SendHalf<'_>) { + (RecvHalf(&*stream), SendHalf(&*stream)) +} + +impl RecvHalf<'_> { + /// Receives data from the socket. + pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + poll_fn(|cx| self.0.poll_recv_from_priv(cx, buf)).await + } + + /// Receives data from the socket. + pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { + poll_fn(|cx| self.0.poll_recv_priv(cx, buf)).await + } +} + +impl SendHalf<'_> { + /// Sends data on the socket to the specified address. + pub async fn send_to

(&mut self, buf: &[u8], target: P) -> io::Result + where + P: AsRef + Unpin, + { + poll_fn(|cx| self.0.poll_send_to_priv(cx, buf, target.as_ref())).await + } + + /// Sends data on the socket to the socket's peer. + pub async fn send(&mut self, buf: &[u8]) -> io::Result { + poll_fn(|cx| self.0.poll_send_priv(cx, buf)).await + } +} + +impl AsRef for RecvHalf<'_> { + fn as_ref(&self) -> &UnixDatagram { + self.0 + } +} + +impl AsRef for SendHalf<'_> { + fn as_ref(&self) -> &UnixDatagram { + self.0 + } +} diff --git a/tokio/src/net/unix/datagram/split_owned.rs b/tokio/src/net/unix/datagram/split_owned.rs new file mode 100644 index 00000000000..699771f30e6 --- /dev/null +++ b/tokio/src/net/unix/datagram/split_owned.rs @@ -0,0 +1,148 @@ +//! `UnixDatagram` owned split support. +//! +//! A `UnixDatagram` can be split into an `OwnedSendHalf` and a `OwnedRecvHalf` +//! with the `UnixDatagram::into_split` method. + +use crate::future::poll_fn; +use crate::net::UnixDatagram; + +use std::error::Error; +use std::net::Shutdown; +use std::os::unix::net::SocketAddr; +use std::path::Path; +use std::sync::Arc; +use std::{fmt, io}; + +pub(crate) fn split_owned(socket: UnixDatagram) -> (OwnedRecvHalf, OwnedSendHalf) { + let shared = Arc::new(socket); + let send = shared.clone(); + let recv = shared; + ( + OwnedRecvHalf { inner: recv }, + OwnedSendHalf { + inner: send, + shutdown_on_drop: true, + }, + ) +} + +/// Owned send half of a [`UnixDatagram`], created by [`into_split`]. +/// +/// [`UnixDatagram`]: UnixDatagram +/// [`into_split`]: UnixDatagram::into_split() +#[derive(Debug)] +pub struct OwnedSendHalf { + inner: Arc, + shutdown_on_drop: bool, +} + +/// Owned receive half of a [`UnixDatagram`], created by [`into_split`]. +/// +/// [`UnixDatagram`]: UnixDatagram +/// [`into_split`]: UnixDatagram::into_split() +#[derive(Debug)] +pub struct OwnedRecvHalf { + inner: Arc, +} + +/// Error indicating that two halves were not from the same socket, and thus could +/// not be `reunite`d. +#[derive(Debug)] +pub struct ReuniteError(pub OwnedSendHalf, pub OwnedRecvHalf); + +impl fmt::Display for ReuniteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "tried to reunite halves that are not from the same socket" + ) + } +} + +impl Error for ReuniteError {} + +fn reunite(s: OwnedSendHalf, r: OwnedRecvHalf) -> Result { + if Arc::ptr_eq(&s.inner, &r.inner) { + s.forget(); + // Only two instances of the `Arc` are ever created, one for the + // receiver and one for the sender, and those `Arc`s are never exposed + // externally. And so when we drop one here, the other one must be the + // only remaining one. + Ok(Arc::try_unwrap(r.inner).expect("UnixDatagram: try_unwrap failed in reunite")) + } else { + Err(ReuniteError(s, r)) + } +} + +impl OwnedRecvHalf { + /// Attempts to put the two "halves" of a `UnixDatagram` back together and + /// recover the original socket. Succeeds only if the two "halves" + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: UnixDatagram::into_split() + pub fn reunite(self, other: OwnedSendHalf) -> Result { + reunite(other, self) + } + + /// Receives data from the socket. + pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + poll_fn(|cx| self.inner.poll_recv_from_priv(cx, buf)).await + } + + /// Receives data from the socket. + pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { + poll_fn(|cx| self.inner.poll_recv_priv(cx, buf)).await + } +} + +impl OwnedSendHalf { + /// Attempts to put the two "halves" of a `UnixDatagram` back together and + /// recover the original socket. Succeeds only if the two "halves" + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: UnixDatagram::into_split() + pub fn reunite(self, other: OwnedRecvHalf) -> Result { + reunite(self, other) + } + + /// Sends data on the socket to the specified address. + pub async fn send_to

(&mut self, buf: &[u8], target: P) -> io::Result + where + P: AsRef + Unpin, + { + poll_fn(|cx| self.inner.poll_send_to_priv(cx, buf, target.as_ref())).await + } + + /// Sends data on the socket to the socket's peer. + pub async fn send(&mut self, buf: &[u8]) -> io::Result { + poll_fn(|cx| self.inner.poll_send_priv(cx, buf)).await + } + + /// Destroy the send half, but don't close the send half of the stream + /// until the receive half is dropped. If the read half has already been + /// dropped, this closes the stream. + pub fn forget(mut self) { + self.shutdown_on_drop = false; + drop(self); + } +} + +impl Drop for OwnedSendHalf { + fn drop(&mut self) { + if self.shutdown_on_drop { + let _ = self.inner.shutdown(Shutdown::Write); + } + } +} + +impl AsRef for OwnedSendHalf { + fn as_ref(&self) -> &UnixDatagram { + &self.inner + } +} + +impl AsRef for OwnedRecvHalf { + fn as_ref(&self) -> &UnixDatagram { + &self.inner + } +} diff --git a/tokio/src/net/unix/mod.rs b/tokio/src/net/unix/mod.rs index ddba60d10ac..b079fe04d7d 100644 --- a/tokio/src/net/unix/mod.rs +++ b/tokio/src/net/unix/mod.rs @@ -1,6 +1,6 @@ //! Unix domain socket utility types -pub(crate) mod datagram; +pub mod datagram; mod incoming; pub use incoming::Incoming; @@ -11,6 +11,9 @@ pub(crate) use listener::UnixListener; mod split; pub use split::{ReadHalf, WriteHalf}; +mod split_owned; +pub use split_owned::{OwnedReadHalf, OwnedWriteHalf, ReuniteError}; + pub(crate) mod stream; pub(crate) use stream::UnixStream; diff --git a/tokio/src/net/unix/split.rs b/tokio/src/net/unix/split.rs index 9b9fa5ee1dc..4fd85774e9a 100644 --- a/tokio/src/net/unix/split.rs +++ b/tokio/src/net/unix/split.rs @@ -17,11 +17,32 @@ use std::net::Shutdown; use std::pin::Pin; use std::task::{Context, Poll}; -/// Read half of a `UnixStream`. +/// Borrowed read half of a [`UnixStream`], created by [`split`]. +/// +/// Reading from a `ReadHalf` is usually done using the convenience methods found on the +/// [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. +/// +/// [`UnixStream`]: UnixStream +/// [`split`]: UnixStream::split() +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +/// [the prelude]: crate::prelude #[derive(Debug)] pub struct ReadHalf<'a>(&'a UnixStream); -/// Write half of a `UnixStream`. +/// Borrowed write half of a [`UnixStream`], created by [`split`]. +/// +/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will +/// shut down the UnixStream stream in the write direction. +/// +/// Writing to an `WriteHalf` is usually done using the convenience methods found +/// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude]. +/// +/// [`UnixStream`]: UnixStream +/// [`split`]: UnixStream::split() +/// [`AsyncWrite`]: trait@crate::io::AsyncWrite +/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +/// [the prelude]: crate::prelude #[derive(Debug)] pub struct WriteHalf<'a>(&'a UnixStream); diff --git a/tokio/src/net/unix/split_owned.rs b/tokio/src/net/unix/split_owned.rs new file mode 100644 index 00000000000..eb35304bfa2 --- /dev/null +++ b/tokio/src/net/unix/split_owned.rs @@ -0,0 +1,187 @@ +//! `UnixStream` owned split support. +//! +//! A `UnixStream` can be split into an `OwnedReadHalf` and a `OwnedWriteHalf` +//! with the `UnixStream::into_split` method. `OwnedReadHalf` implements +//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`. +//! +//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized +//! split has no associated overhead and enforces all invariants at the type +//! level. + +use crate::io::{AsyncRead, AsyncWrite}; +use crate::net::UnixStream; + +use std::error::Error; +use std::mem::MaybeUninit; +use std::net::Shutdown; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{fmt, io}; + +/// Owned read half of a [`UnixStream`], created by [`into_split`]. +/// +/// Reading from an `OwnedReadHalf` is usually done using the convenience methods found +/// on the [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. +/// +/// [`UnixStream`]: crate::net::UnixStream +/// [`into_split`]: crate::net::UnixStream::into_split() +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +/// [the prelude]: crate::prelude +#[derive(Debug)] +pub struct OwnedReadHalf { + inner: Arc, +} + +/// Owned write half of a [`UnixStream`], created by [`into_split`]. +/// +/// Note that in the [`AsyncWrite`] implementation of this type, +/// [`poll_shutdown`] will shut down the stream in the write direction. +/// Dropping the write half will also shut down the write half of the stream. +/// +/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods +/// found on the [`AsyncWriteExt`] trait. Examples import this trait through +/// [the prelude]. +/// +/// [`UnixStream`]: crate::net::UnixStream +/// [`into_split`]: crate::net::UnixStream::into_split() +/// [`AsyncWrite`]: trait@crate::io::AsyncWrite +/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +/// [the prelude]: crate::prelude +#[derive(Debug)] +pub struct OwnedWriteHalf { + inner: Arc, + shutdown_on_drop: bool, +} + +pub(crate) fn split_owned(stream: UnixStream) -> (OwnedReadHalf, OwnedWriteHalf) { + let arc = Arc::new(stream); + let read = OwnedReadHalf { + inner: Arc::clone(&arc), + }; + let write = OwnedWriteHalf { + inner: arc, + shutdown_on_drop: true, + }; + (read, write) +} + +pub(crate) fn reunite( + read: OwnedReadHalf, + write: OwnedWriteHalf, +) -> Result { + if Arc::ptr_eq(&read.inner, &write.inner) { + write.forget(); + // This unwrap cannot fail as the api does not allow creating more than two Arcs, + // and we just dropped the other half. + Ok(Arc::try_unwrap(read.inner).expect("UnixStream: try_unwrap failed in reunite")) + } else { + Err(ReuniteError(read, write)) + } +} + +/// Error indicating that two halves were not from the same socket, and thus could +/// not be reunited. +#[derive(Debug)] +pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf); + +impl fmt::Display for ReuniteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "tried to reunite halves that are not from the same socket" + ) + } +} + +impl Error for ReuniteError {} + +impl OwnedReadHalf { + /// Attempts to put the two halves of a `UnixStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: crate::net::UnixStream::into_split() + pub fn reunite(self, other: OwnedWriteHalf) -> Result { + reunite(self, other) + } +} + +impl AsyncRead for OwnedReadHalf { + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { + false + } + + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.inner.poll_read_priv(cx, buf) + } +} + +impl OwnedWriteHalf { + /// Attempts to put the two halves of a `UnixStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: crate::net::UnixStream::into_split() + pub fn reunite(self, other: OwnedReadHalf) -> Result { + reunite(other, self) + } + + /// Destroy the write half, but don't close the write half of the stream + /// until the read half is dropped. If the read half has already been + /// dropped, this closes the stream. + pub fn forget(mut self) { + self.shutdown_on_drop = false; + drop(self); + } +} + +impl Drop for OwnedWriteHalf { + fn drop(&mut self) { + if self.shutdown_on_drop { + let _ = self.inner.shutdown(Shutdown::Write); + } + } +} + +impl AsyncWrite for OwnedWriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.inner.poll_write_priv(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + // flush is a no-op + Poll::Ready(Ok(())) + } + + // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + let res = self.inner.shutdown(Shutdown::Write); + if res.is_ok() { + Pin::into_inner(self).shutdown_on_drop = false; + } + res.into() + } +} + +impl AsRef for OwnedReadHalf { + fn as_ref(&self) -> &UnixStream { + &*self.inner + } +} + +impl AsRef for OwnedWriteHalf { + fn as_ref(&self) -> &UnixStream { + &*self.inner + } +} diff --git a/tokio/src/net/unix/stream.rs b/tokio/src/net/unix/stream.rs index beae699962d..5fe242d0887 100644 --- a/tokio/src/net/unix/stream.rs +++ b/tokio/src/net/unix/stream.rs @@ -1,6 +1,7 @@ use crate::future::poll_fn; use crate::io::{AsyncRead, AsyncWrite, PollEvented}; use crate::net::unix::split::{split, ReadHalf, WriteHalf}; +use crate::net::unix::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; use crate::net::unix::ucred::{self, UCred}; use std::convert::TryFrom; @@ -109,11 +110,34 @@ impl UnixStream { self.io.get_ref().shutdown(how) } + // These lifetime markers also appear in the generated documentation, and make + // it more clear that this is a *borrowed* split. + #[allow(clippy::needless_lifetimes)] /// Split a `UnixStream` into a read half and a write half, which can be used /// to read and write the stream concurrently. - pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) { + /// + /// This method is more efficient than [`into_split`], but the halves cannot be + /// moved into independently spawned tasks. + /// + /// [`into_split`]: Self::into_split() + pub fn split<'a>(&'a mut self) -> (ReadHalf<'a>, WriteHalf<'a>) { split(self) } + + /// Splits a `UnixStream` into a read half and a write half, which can be used + /// to read and write the stream concurrently. + /// + /// Unlike [`split`], the owned halves can be moved to separate tasks, however + /// this comes at the cost of a heap allocation. + /// + /// **Note:** Dropping the write half will shut down the write half of the + /// stream. This is equivalent to calling [`shutdown(Write)`] on the `UnixStream`. + /// + /// [`split`]: Self::split() + /// [`shutdown(Write)`]: fn@Self::shutdown + pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { + split_owned(self) + } } impl TryFrom for mio_uds::UnixStream { diff --git a/tokio/src/net/unix/ucred.rs b/tokio/src/net/unix/ucred.rs index cdd77ea4140..466aedc21fe 100644 --- a/tokio/src/net/unix/ucred.rs +++ b/tokio/src/net/unix/ucred.rs @@ -22,7 +22,7 @@ pub(crate) use self::impl_linux::get_peer_cred; ))] pub(crate) use self::impl_macos::get_peer_cred; -#[cfg(any(target_os = "solaris"))] +#[cfg(any(target_os = "solaris", target_os = "illumos"))] pub(crate) use self::impl_solaris::get_peer_cred; #[cfg(any(target_os = "linux", target_os = "android"))] @@ -110,7 +110,7 @@ pub(crate) mod impl_macos { } } -#[cfg(any(target_os = "solaris"))] +#[cfg(any(target_os = "solaris", target_os = "illumos"))] pub(crate) mod impl_solaris { use crate::net::unix::UnixStream; use std::io; diff --git a/tokio/src/process/mod.rs b/tokio/src/process/mod.rs index 7231511235e..e04a43510a9 100644 --- a/tokio/src/process/mod.rs +++ b/tokio/src/process/mod.rs @@ -6,6 +6,8 @@ //! variants) return "future aware" types that interoperate with Tokio. The asynchronous process //! support is provided through signal handling on Unix and system APIs on Windows. //! +//! [`std::process::Command`]: std::process::Command +//! //! # Examples //! //! Here's an example program which will spawn `echo hello world` and then wait @@ -140,6 +142,9 @@ use std::task::Poll; /// [output](Command::output). /// /// `Command` uses asynchronous versions of some `std` types (for example [`Child`]). +/// +/// [`std::process::Command`]: std::process::Command +/// [`Child`]: struct@Child #[derive(Debug)] pub struct Command { std: StdCommand, @@ -171,7 +176,7 @@ impl Command { /// The search path to be used may be controlled by setting the /// `PATH` environment variable on the Command, /// but this has some implementation limitations on Windows - /// (see issue rust-lang/rust#37519). + /// (see issue [rust-lang/rust#37519]). /// /// # Examples /// @@ -181,6 +186,8 @@ impl Command { /// use tokio::process::Command; /// let command = Command::new("sh"); /// ``` + /// + /// [rust-lang/rust#37519]: https://github.com/rust-lang/rust/issues/37519 pub fn new>(program: S) -> Command { Self::from(StdCommand::new(program)) } @@ -204,7 +211,7 @@ impl Command { /// /// To pass multiple arguments see [`args`]. /// - /// [`args`]: #method.args + /// [`args`]: method@Self::args /// /// # Examples /// @@ -226,7 +233,7 @@ impl Command { /// /// To pass a single argument see [`arg`]. /// - /// [`arg`]: #method.arg + /// [`arg`]: method@Self::arg /// /// # Examples /// @@ -701,7 +708,7 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); let ret = Pin::new(&mut self.inner).poll(cx); @@ -710,6 +717,10 @@ where self.kill_on_drop = false; } + if ret.is_ready() { + coop.made_progress(); + } + ret } } @@ -872,6 +883,11 @@ impl AsyncWrite for ChildStdin { } impl AsyncRead for ChildStdout { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/process.rs#L314 + false + } + fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -882,6 +898,11 @@ impl AsyncRead for ChildStdout { } impl AsyncRead for ChildStderr { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/process.rs#L375 + false + } + fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/tokio/src/runtime/blocking/mod.rs b/tokio/src/runtime/blocking/mod.rs index 5c808335cc0..0b36a75f655 100644 --- a/tokio/src/runtime/blocking/mod.rs +++ b/tokio/src/runtime/blocking/mod.rs @@ -9,7 +9,7 @@ cfg_blocking_impl! { mod schedule; mod shutdown; - mod task; + pub(crate) mod task; use crate::runtime::Builder; diff --git a/tokio/src/runtime/blocking/pool.rs b/tokio/src/runtime/blocking/pool.rs index a3b208d1710..40d417b19f5 100644 --- a/tokio/src/runtime/blocking/pool.rs +++ b/tokio/src/runtime/blocking/pool.rs @@ -148,7 +148,7 @@ impl fmt::Debug for BlockingPool { // ===== impl Spawner ===== impl Spawner { - fn spawn(&self, task: Task, rt: &Handle) -> Result<(), ()> { + pub(crate) fn spawn(&self, task: Task, rt: &Handle) -> Result<(), ()> { let shutdown_tx = { let mut shared = self.inner.shared.lock().unwrap(); diff --git a/tokio/src/runtime/blocking/schedule.rs b/tokio/src/runtime/blocking/schedule.rs index e10778d5304..4e044ab2987 100644 --- a/tokio/src/runtime/blocking/schedule.rs +++ b/tokio/src/runtime/blocking/schedule.rs @@ -6,7 +6,7 @@ use crate::runtime::task::{self, Task}; /// /// We avoid storing the task by forgetting it in `bind` and re-materializing it /// in `release. -pub(super) struct NoopSchedule; +pub(crate) struct NoopSchedule; impl task::Schedule for NoopSchedule { fn bind(_task: Task) -> NoopSchedule { diff --git a/tokio/src/runtime/blocking/shutdown.rs b/tokio/src/runtime/blocking/shutdown.rs index f3c60ee301d..e76a7013552 100644 --- a/tokio/src/runtime/blocking/shutdown.rs +++ b/tokio/src/runtime/blocking/shutdown.rs @@ -33,15 +33,25 @@ impl Receiver { /// duration. If `timeout` is `None`, then the thread is blocked until the /// shutdown signal is received. pub(crate) fn wait(&mut self, timeout: Option) { - use crate::runtime::enter::{enter, try_enter}; + use crate::runtime::enter::try_enter; - let mut e = if std::thread::panicking() { - match try_enter(false) { - Some(enter) => enter, - _ => return, + if timeout == Some(Duration::from_nanos(0)) { + return; + } + + let mut e = match try_enter(false) { + Some(enter) => enter, + _ => { + if std::thread::panicking() { + // Don't panic in a panic + return; + } else { + panic!( + "Cannot drop a runtime in a context where blocking is not allowed. \ + This happens when a runtime is dropped from within an asynchronous context." + ); + } } - } else { - enter(false) }; // The oneshot completes with an Err diff --git a/tokio/src/runtime/blocking/task.rs b/tokio/src/runtime/blocking/task.rs index f98b85494cc..a521af4630c 100644 --- a/tokio/src/runtime/blocking/task.rs +++ b/tokio/src/runtime/blocking/task.rs @@ -3,25 +3,28 @@ use std::pin::Pin; use std::task::{Context, Poll}; /// Converts a function to a future that completes on poll -pub(super) struct BlockingTask { +pub(crate) struct BlockingTask { func: Option, } impl BlockingTask { /// Initializes a new blocking task from the given function - pub(super) fn new(func: T) -> BlockingTask { + pub(crate) fn new(func: T) -> BlockingTask { BlockingTask { func: Some(func) } } } +// The closure `F` is never pinned +impl Unpin for BlockingTask {} + impl Future for BlockingTask where T: FnOnce() -> R, { type Output = R; - fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { - let me = unsafe { self.get_unchecked_mut() }; + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + let me = &mut *self; let func = me .func .take() diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index 7c7b2d3f0f0..fad72c7ad94 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -16,8 +16,8 @@ use std::sync::Arc; /// See function level documentation for details on the various configuration /// settings. /// -/// [`build`]: #method.build -/// [`Builder::new`]: #method.new +/// [`build`]: method@Self::build +/// [`Builder::new`]: method@Self::new /// /// # Examples /// diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 4af2df23eb7..1b267f481e2 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -47,9 +47,9 @@ cfg_rt_core! { } } -/// Set this [`ThreadContext`] as the current active [`ThreadContext`]. +/// Set this [`Handle`] as the current active [`Handle`]. /// -/// [`ThreadContext`]: struct@ThreadContext +/// [`Handle`]: Handle pub(crate) fn enter(new: Handle, f: F) -> R where F: FnOnce() -> R, diff --git a/tokio/src/runtime/enter.rs b/tokio/src/runtime/enter.rs index ad5580ccacf..56a7c57b6c6 100644 --- a/tokio/src/runtime/enter.rs +++ b/tokio/src/runtime/enter.rs @@ -142,12 +142,11 @@ cfg_block_on! { impl Enter { /// Blocks the thread on the specified future, returning the value with /// which that future completes. - pub(crate) fn block_on(&mut self, mut f: F) -> Result + pub(crate) fn block_on(&mut self, f: F) -> Result where F: std::future::Future, { use crate::park::{CachedParkThread, Park}; - use std::pin::Pin; use std::task::Context; use std::task::Poll::Ready; @@ -155,9 +154,7 @@ cfg_block_on! { let waker = park.get_unpark()?.into_waker(); let mut cx = Context::from_waker(&waker); - // `block_on` takes ownership of `f`. Once it is pinned here, the original `f` binding can - // no longer be accessed, making the pinning safe. - let mut f = unsafe { Pin::new_unchecked(&mut f) }; + pin!(f); loop { if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { @@ -179,12 +176,11 @@ cfg_blocking_impl! { /// /// If the future completes before `timeout`, the result is returned. If /// `timeout` elapses, then `Err` is returned. - pub(crate) fn block_on_timeout(&mut self, mut f: F, timeout: Duration) -> Result + pub(crate) fn block_on_timeout(&mut self, f: F, timeout: Duration) -> Result where F: std::future::Future, { use crate::park::{CachedParkThread, Park}; - use std::pin::Pin; use std::task::Context; use std::task::Poll::Ready; use std::time::Instant; @@ -193,9 +189,7 @@ cfg_blocking_impl! { let waker = park.get_unpark()?.into_waker(); let mut cx = Context::from_waker(&waker); - // `block_on` takes ownership of `f`. Once it is pinned here, the original `f` binding can - // no longer be accessed, making the pinning safe. - let mut f = unsafe { Pin::new_unchecked(&mut f) }; + pin!(f); let when = Instant::now() + timeout; loop { diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index a5761b6a94c..0716a7fadca 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -1,6 +1,11 @@ use crate::runtime::{blocking, context, io, time, Spawner}; use std::{error, fmt}; +cfg_blocking! { + use crate::runtime::task; + use crate::runtime::blocking::task::BlockingTask; +} + cfg_rt_core! { use crate::task::JoinHandle; @@ -71,11 +76,13 @@ impl Handle { context::enter(self.clone(), f) } - /// Returns a Handle view over the currently running Runtime + /// Returns a `Handle` view over the currently running `Runtime` /// /// # Panic /// - /// This will panic if called outside the context of a Tokio runtime. + /// This will panic if called outside the context of a Tokio runtime. That means that you must + /// call this on one of the threads **being run by the runtime**. Calling this from within a + /// thread created by `std::thread::spawn` (for example) will cause a panic. /// /// # Examples /// @@ -83,6 +90,7 @@ impl Handle { /// block or function running on that runtime. /// /// ``` + /// # use std::thread; /// # use tokio::runtime::Runtime; /// # fn dox() { /// # let rt = Runtime::new().unwrap(); @@ -93,7 +101,16 @@ impl Handle { /// let handle = Handle::current(); /// handle.spawn(async { /// println!("now running in the existing Runtime"); - /// }) + /// }); + /// + /// # let handle = + /// thread::spawn(move || { + /// // Notice that the handle is created outside of this thread and then moved in + /// handle.block_on(async { /* ... */ }) + /// // This next line would cause a panic + /// // let handle2 = Handle::current(); + /// }); + /// # handle.join().unwrap(); /// # }); /// # } /// ``` @@ -263,6 +280,79 @@ cfg_rt_core! { } } +cfg_blocking! { + impl Handle { + /// Runs the provided closure on a thread where blocking is acceptable. + /// + /// In general, issuing a blocking call or performing a lot of compute in a + /// future without yielding is not okay, as it may prevent the executor from + /// driving other futures forward. This function runs the provided closure + /// on a thread dedicated to blocking operations. See the [CPU-bound tasks + /// and blocking code][blocking] section for more information. + /// + /// Tokio will spawn more blocking threads when they are requested through + /// this function until the upper limit configured on the [`Builder`] is + /// reached. This limit is very large by default, because `spawn_blocking` is + /// often used for various kinds of IO operations that cannot be performed + /// asynchronously. When you run CPU-bound code using `spawn_blocking`, you + /// should keep this large upper limit in mind; to run your CPU-bound + /// computations on only a few threads, you should use a separate thread + /// pool such as [rayon] rather than configuring the number of blocking + /// threads. + /// + /// This function is intended for non-async operations that eventually + /// finish on their own. If you want to spawn an ordinary thread, you should + /// use [`thread::spawn`] instead. + /// + /// Closures spawned using `spawn_blocking` cannot be cancelled. When you + /// shut down the executor, it will wait indefinitely for all blocking + /// operations to finish. You can use [`shutdown_timeout`] to stop waiting + /// for them after a certain timeout. Be aware that this will still not + /// cancel the tasks — they are simply allowed to keep running after the + /// method returns. + /// + /// Note that if you are using the [basic scheduler], this function will + /// still spawn additional threads for blocking operations. The basic + /// scheduler's single thread is only used for asynchronous code. + /// + /// [`Builder`]: struct@crate::runtime::Builder + /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code + /// [rayon]: https://docs.rs/rayon + /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler + /// [`thread::spawn`]: fn@std::thread::spawn + /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # async fn docs() -> Result<(), Box>{ + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// let handle = rt.handle(); + /// + /// let res = handle.spawn_blocking(move || { + /// // do some compute-heavy work or call synchronous code + /// "done computing" + /// }).await?; + /// + /// assert_eq!(res, "done computing"); + /// # Ok(()) + /// # } + /// ``` + pub fn spawn_blocking(&self, f: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let (task, handle) = task::joinable(BlockingTask::new(f)); + let _ = self.blocking_spawner.spawn(task, self); + handle + } + } +} + /// Error returned by `try_current` when no Runtime has been started pub struct TryCurrentError(()); diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 9b7d41a3b44..300a14657bf 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -268,7 +268,7 @@ use std::time::Duration; /// /// [timer]: crate::time /// [mod]: index.html -/// [`new`]: #method.new +/// [`new`]: method@Self::new /// [`Builder`]: struct@Builder /// [`tokio::run`]: fn@run #[derive(Debug)] @@ -548,4 +548,34 @@ impl Runtime { } = self; blocking_pool.shutdown(Some(duration)); } + + /// Shutdown the runtime, without waiting for any spawned tasks to shutdown. + /// + /// This can be useful if you want to drop a runtime from within another runtime. + /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks + /// to complete, which would normally not be permitted within an asynchronous context. + /// By calling `shutdown_background()`, you can drop the runtime from such a context. + /// + /// Note however, that because we do not wait for any blocking tasks to complete, this + /// may result in a resource leak (in that any blocking tasks are still running until they + /// return. + /// + /// This function is equivalent to calling `shutdown_timeout(Duration::of_nanos(0))`. + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// fn main() { + /// let mut runtime = Runtime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// let inner_runtime = Runtime::new().unwrap(); + /// // ... + /// inner_runtime.shutdown_background(); + /// }); + /// } + /// ``` + pub fn shutdown_background(self) { + self.shutdown_timeout(Duration::from_nanos(0)) + } } diff --git a/tokio/src/runtime/task/join.rs b/tokio/src/runtime/task/join.rs index 902b4d4f4d2..c2be3fabd7c 100644 --- a/tokio/src/runtime/task/join.rs +++ b/tokio/src/runtime/task/join.rs @@ -277,7 +277,7 @@ impl Future for JoinHandle { let mut ret = Poll::Pending; // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); // Raw should always be set. If it is not, this is due to polling after // completion @@ -302,6 +302,10 @@ impl Future for JoinHandle { raw.try_read_output(&mut ret as *mut _ as *mut (), cx.waker()); } + if ret.is_ready() { + coop.made_progress(); + } + ret } } diff --git a/tokio/src/runtime/thread_pool/worker.rs b/tokio/src/runtime/thread_pool/worker.rs index e31f237cc36..abe20da59cf 100644 --- a/tokio/src/runtime/thread_pool/worker.rs +++ b/tokio/src/runtime/thread_pool/worker.rs @@ -4,6 +4,7 @@ //! "core" is handed off to a new thread allowing the scheduler to continue to //! make progress while the originating thread blocks. +use crate::coop; use crate::loom::rand::seed; use crate::loom::sync::{Arc, Mutex}; use crate::park::{Park, Unpark}; @@ -179,40 +180,39 @@ cfg_blocking! { F: FnOnce() -> R, { // Try to steal the worker core back - struct Reset(bool); + struct Reset(coop::Budget); impl Drop for Reset { fn drop(&mut self) { CURRENT.with(|maybe_cx| { - if !self.0 { - // We were not the ones to give away the core, - // so we do not get to restore it either. - // This is necessary so that with a nested - // block_in_place, the inner block_in_place - // does not restore the core. - return; - } - if let Some(cx) = maybe_cx { let core = cx.worker.core.take(); let mut cx_core = cx.core.borrow_mut(); assert!(cx_core.is_none()); *cx_core = core; + + // Reset the task budget as we are re-entering the + // runtime. + coop::set(self.0); } }); } } let mut had_core = false; + let mut had_entered = false; + CURRENT.with(|maybe_cx| { match (crate::runtime::enter::context(), maybe_cx.is_some()) { (EnterContext::Entered { .. }, true) => { // We are on a thread pool runtime thread, so we just need to set up blocking. + had_entered = true; } (EnterContext::Entered { allow_blocking }, false) => { // We are on an executor, but _not_ on the thread pool. // That is _only_ okay if we are in a thread pool runtime's block_on method: if allow_blocking { + had_entered = true; return; } else { // This probably means we are on the basic_scheduler or in a LocalSet, @@ -231,16 +231,12 @@ cfg_blocking! { return; } } + let cx = maybe_cx.expect("no .is_some() == false cases above should lead here"); // Get the worker core. If none is set, then blocking is fine! let core = match cx.core.borrow_mut().take() { - Some(core) => { - // We are effectively leaving the executor, so we need to - // forcibly end budgeting. - crate::coop::stop(); - core - }, + Some(core) => core, None => return, }; @@ -263,9 +259,13 @@ cfg_blocking! { runtime::spawn_blocking(move || run(worker)); }); - let _reset = Reset(had_core); - if had_core { + // Unset the current task's budget. Blocking sections are not + // constrained by task budgets. + let _reset = Reset(coop::stop()); + + crate::runtime::enter::exit(f) + } else if had_entered { crate::runtime::enter::exit(f) } else { f() @@ -349,7 +349,7 @@ impl Context { *self.core.borrow_mut() = Some(core); // Run the task - crate::coop::budget(|| { + coop::budget(|| { task.run(); // As long as there is budget remaining and a task exists in the @@ -368,7 +368,7 @@ impl Context { None => return Ok(core), }; - if crate::coop::has_budget_remaining() { + if coop::has_budget_remaining() { // Run the LIFO task, then loop *self.core.borrow_mut() = Some(core); task.run(); @@ -575,7 +575,7 @@ impl Core { } // Drain the queue - while let Some(_) = self.next_local_task() {} + while self.next_local_task().is_some() {} } fn drain_pending_drop(&mut self, worker: &Worker) { @@ -797,7 +797,7 @@ impl Shared { } // Drain the injection queue - while let Some(_) = self.inject.pop() {} + while self.inject.pop().is_some() {} } fn ptr_eq(&self, other: &Shared) -> bool { diff --git a/tokio/src/signal/unix.rs b/tokio/src/signal/unix.rs index 06f5cf4eba7..b46b15c99a6 100644 --- a/tokio/src/signal/unix.rs +++ b/tokio/src/signal/unix.rs @@ -401,7 +401,7 @@ pub struct Signal { /// * If the lower-level C functions fail for some reason. /// * If the previous initialization of this specific signal failed. /// * If the signal is one of -/// [`signal_hook::FORBIDDEN`](https://docs.rs/signal-hook/*/signal_hook/fn.register.html#panics) +/// [`signal_hook::FORBIDDEN`](fn@signal_hook_registry::register#panics) pub fn signal(kind: SignalKind) -> io::Result { let signal = kind.0; diff --git a/tokio/src/stream/collect.rs b/tokio/src/stream/collect.rs index f44c72b7b36..46494287cd8 100644 --- a/tokio/src/stream/collect.rs +++ b/tokio/src/stream/collect.rs @@ -32,7 +32,7 @@ pin_project! { /// /// Currently, this trait may not be implemented by third parties. The trait is /// sealed in order to make changes in the future. Stabilization is pending -/// enhancements to the Rust langague. +/// enhancements to the Rust language. pub trait FromStream: sealed::FromStreamPriv {} impl Collect diff --git a/tokio/src/stream/empty.rs b/tokio/src/stream/empty.rs index 6118673e504..2f56ac6cad3 100644 --- a/tokio/src/stream/empty.rs +++ b/tokio/src/stream/empty.rs @@ -4,7 +4,7 @@ use core::marker::PhantomData; use core::pin::Pin; use core::task::{Context, Poll}; -/// Stream for the [`empty`] function. +/// Stream for the [`empty`](fn@empty) function. #[derive(Debug)] #[must_use = "streams do nothing unless polled"] pub struct Empty(PhantomData); diff --git a/tokio/src/stream/iter.rs b/tokio/src/stream/iter.rs index 36eeb5612f7..bc0388a1442 100644 --- a/tokio/src/stream/iter.rs +++ b/tokio/src/stream/iter.rs @@ -3,7 +3,7 @@ use crate::stream::Stream; use core::pin::Pin; use core::task::{Context, Poll}; -/// Stream for the [`iter`] function. +/// Stream for the [`iter`](fn@iter) function. #[derive(Debug)] #[must_use = "streams do nothing unless polled"] pub struct Iter { @@ -45,7 +45,8 @@ where type Item = I::Item; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); + coop.made_progress(); Poll::Ready(self.iter.next()) } diff --git a/tokio/src/stream/mod.rs b/tokio/src/stream/mod.rs index a59bdfcd48f..7b061efeb69 100644 --- a/tokio/src/stream/mod.rs +++ b/tokio/src/stream/mod.rs @@ -192,12 +192,17 @@ pub trait StreamExt: Stream { /// Values are produced from the merged stream in the order they arrive from /// the two source streams. If both source streams provide values /// simultaneously, the merge stream alternates between them. This provides - /// some level of fairness. + /// some level of fairness. You should not chain calls to `merge`, as this + /// will break the fairness of the merging. /// /// The merged stream completes once **both** source streams complete. When /// one source stream completes before the other, the merge stream /// exclusively polls the remaining stream. /// + /// For merging multiple streams, consider using [`StreamMap`] instead. + /// + /// [`StreamMap`]: crate::stream::StreamMap + /// /// # Examples /// /// ``` @@ -303,7 +308,7 @@ pub trait StreamExt: Stream { /// As values of this stream are made available, the provided function will /// be run on them. If the predicate `f` resolves to /// [`Some(item)`](Some) then the stream will yield the value `item`, but if - /// it resolves to [`None`] then the next value will be produced. + /// it resolves to [`None`], then the value will be skipped. /// /// Note that this function consumes the stream passed into it and returns a /// wrapped version of it, similar to [`Iterator::filter_map`] method in the @@ -697,7 +702,7 @@ pub trait StreamExt: Stream { /// # Notes /// /// `FromStream` is currently a sealed trait. Stabilization is pending - /// enhancements to the Rust langague. + /// enhancements to the Rust language. /// /// # Examples /// diff --git a/tokio/src/stream/once.rs b/tokio/src/stream/once.rs index 04a642f309a..7fe204cc127 100644 --- a/tokio/src/stream/once.rs +++ b/tokio/src/stream/once.rs @@ -4,7 +4,7 @@ use core::option; use core::pin::Pin; use core::task::{Context, Poll}; -/// Stream for the [`once`] function. +/// Stream for the [`once`](fn@once) function. #[derive(Debug)] #[must_use = "streams do nothing unless polled"] pub struct Once { diff --git a/tokio/src/stream/pending.rs b/tokio/src/stream/pending.rs index 2e06a1c2612..21224c38596 100644 --- a/tokio/src/stream/pending.rs +++ b/tokio/src/stream/pending.rs @@ -4,7 +4,7 @@ use core::marker::PhantomData; use core::pin::Pin; use core::task::{Context, Poll}; -/// Stream for the [`pending`] function. +/// Stream for the [`pending`](fn@pending) function. #[derive(Debug)] #[must_use = "streams do nothing unless polled"] pub struct Pending(PhantomData); diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs index 436737a6709..070cd2010c4 100644 --- a/tokio/src/sync/batch_semaphore.rs +++ b/tokio/src/sync/batch_semaphore.rs @@ -53,7 +53,7 @@ pub(crate) struct AcquireError(()); pub(crate) struct Acquire<'a> { node: Waiter, semaphore: &'a Semaphore, - num_permits: u16, + num_permits: u32, queued: bool, } @@ -103,6 +103,8 @@ impl Semaphore { const PERMIT_SHIFT: usize = 1; /// Creates a new semaphore with the initial number of permits + /// + /// Maximum number of permits on 32-bit platforms is `1<<29`. pub(crate) fn new(permits: usize) -> Self { assert!( permits <= Self::MAX_PERMITS, @@ -123,7 +125,9 @@ impl Semaphore { self.permits.load(Acquire) >> Self::PERMIT_SHIFT } - /// Adds `n` new permits to the semaphore. + /// Adds `added` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. pub(crate) fn release(&self, added: usize) { if added == 0 { return; @@ -157,9 +161,14 @@ impl Semaphore { } } - pub(crate) fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> { - let mut curr = self.permits.load(Acquire); + pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> { + assert!( + num_permits as usize <= Self::MAX_PERMITS, + "a semaphore may not have more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT; + let mut curr = self.permits.load(Acquire); loop { // Has the semaphore closed?git if curr & Self::CLOSED > 0 { @@ -180,13 +189,13 @@ impl Semaphore { } } - pub(crate) fn acquire(&self, num_permits: u16) -> Acquire<'_> { + pub(crate) fn acquire(&self, num_permits: u32) -> Acquire<'_> { Acquire::new(self, num_permits) } /// Release `rem` permits to the semaphore's wait list, starting from the /// end of the queue. - /// + /// /// If `rem` exceeds the number of permits needed by the wait list, the /// remainder are assigned back to the semaphore. fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) { @@ -245,7 +254,7 @@ impl Semaphore { fn poll_acquire( &self, cx: &mut Context<'_>, - num_permits: u16, + num_permits: u32, node: Pin<&mut Waiter>, queued: bool, ) -> Poll> { @@ -354,7 +363,7 @@ impl fmt::Debug for Semaphore { } impl Waiter { - fn new(num_permits: u16) -> Self { + fn new(num_permits: u32) -> Self { Waiter { waker: UnsafeCell::new(None), state: AtomicUsize::new(num_permits as usize), @@ -386,13 +395,18 @@ impl Future for Acquire<'_> { type Output = Result<(), AcquireError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // First, ensure the current task has enough budget to proceed. + let coop = ready!(crate::coop::poll_proceed(cx)); + let (node, semaphore, needed, queued) = self.project(); + match semaphore.poll_acquire(cx, needed, node, *queued) { Pending => { *queued = true; Pending } Ready(r) => { + coop.made_progress(); r?; *queued = false; Ready(Ok(())) @@ -402,7 +416,7 @@ impl Future for Acquire<'_> { } impl<'a> Acquire<'a> { - fn new(semaphore: &'a Semaphore, num_permits: u16) -> Self { + fn new(semaphore: &'a Semaphore, num_permits: u32) -> Self { Self { node: Waiter::new(num_permits), semaphore, @@ -411,14 +425,14 @@ impl<'a> Acquire<'a> { } } - fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u16, &mut bool) { + fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u32, &mut bool) { fn is_unpin() {} unsafe { // Safety: all fields other than `node` are `Unpin` is_unpin::<&Semaphore>(); is_unpin::<&mut bool>(); - is_unpin::(); + is_unpin::(); let this = self.get_unchecked_mut(); ( @@ -512,8 +526,8 @@ impl TryAcquireError { impl fmt::Display for TryAcquireError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - TryAcquireError::Closed => write!(fmt, "{}", "semaphore closed"), - TryAcquireError::NoPermits => write!(fmt, "{}", "no permits available"), + TryAcquireError::Closed => write!(fmt, "semaphore closed"), + TryAcquireError::NoPermits => write!(fmt, "no permits available"), } } } diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index abc4974a3f4..cd62ffd5ae1 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -109,13 +109,15 @@ //! } use crate::loom::cell::UnsafeCell; -use crate::loom::future::AtomicWaker; -use crate::loom::sync::atomic::{spin_loop_hint, AtomicBool, AtomicPtr, AtomicUsize}; -use crate::loom::sync::{Arc, Condvar, Mutex}; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; +use crate::util::linked_list::{self, LinkedList}; use std::fmt; -use std::mem; -use std::ptr; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; use std::sync::atomic::Ordering::SeqCst; use std::task::{Context, Poll, Waker}; use std::usize; @@ -193,8 +195,8 @@ pub struct Receiver { /// Next position to read from next: u64, - /// Waiter state - wait: Arc, + /// Used to support the deprecated `poll_recv` fn + waiter: Option>>>, } /// Error returned by [`Sender::send`][Sender::send]. @@ -247,20 +249,14 @@ pub enum TryRecvError { /// Data shared between senders and receivers struct Shared { /// slots in the channel - buffer: Box<[Slot]>, + buffer: Box<[RwLock>]>, /// Mask a position -> index mask: usize, - /// Tail of the queue + /// Tail of the queue. Includes the rx wait list. tail: Mutex, - /// Notifies a sender that the slot is unlocked - condvar: Condvar, - - /// Stack of pending waiters - wait_stack: AtomicPtr, - /// Number of outstanding Sender handles num_tx: AtomicUsize, } @@ -275,6 +271,9 @@ struct Tail { /// True if the channel is closed closed: bool, + + /// Receivers waiting for a value + waiters: LinkedList, } /// Slot in the buffer @@ -282,50 +281,79 @@ struct Slot { /// Remaining number of receivers that are expected to see this value. /// /// When this goes to zero, the value is released. + /// + /// An atomic is used as it is mutated concurrently with the slot read lock + /// acquired. rem: AtomicUsize, - /// Used to lock the `write` field. - lock: AtomicUsize, + /// Uniquely identifies the `send` stored in the slot + pos: u64, - /// The value being broadcast + /// True signals the channel is closed. + closed: bool, + + /// The value being broadcast. /// - /// Synchronized by `state` - write: Write, + /// The value is set by `send` when the write lock is held. When a reader + /// drops, `rem` is decremented. When it hits zero, the value is dropped. + val: UnsafeCell>, } -/// A write in the buffer -struct Write { - /// Uniquely identifies this write - pos: UnsafeCell, +/// An entry in the wait queue +struct Waiter { + /// True if queued + queued: bool, - /// The written value - val: UnsafeCell>, + /// Task waiting on the broadcast channel. + waker: Option, + + /// Intrusive linked-list pointers. + pointers: linked_list::Pointers, + + /// Should not be `Unpin`. + _p: PhantomPinned, } -/// Tracks a waiting receiver -#[derive(Debug)] -struct WaitNode { - /// `true` if queued - queued: AtomicBool, +struct RecvGuard<'a, T> { + slot: RwLockReadGuard<'a, Slot>, +} - /// Task to wake when a permit is made available. - waker: AtomicWaker, +/// Receive a value future +struct Recv +where + R: AsMut>, +{ + /// Receiver being waited on + receiver: R, - /// Next pointer in the stack of waiting senders. - next: UnsafeCell<*const WaitNode>, + /// Entry in the waiter `LinkedList` + waiter: UnsafeCell, + + _p: std::marker::PhantomData, } -struct RecvGuard<'a, T> { - slot: &'a Slot, - tail: &'a Mutex, - condvar: &'a Condvar, +/// `AsMut` is not implemented for `T` (coherence). Explicitly implementing +/// `AsMut` for `Receiver` would be included in the public API of the receiver +/// type. Instead, `Borrow` is used internally to bridge the gap. +struct Borrow(T); + +impl AsMut> for Borrow> { + fn as_mut(&mut self) -> &mut Receiver { + &mut self.0 + } +} + +impl<'a, T> AsMut> for Borrow<&'a mut Receiver> { + fn as_mut(&mut self) -> &mut Receiver { + &mut *self.0 + } } +unsafe impl> + Send, T: Send> Send for Recv {} +unsafe impl> + Sync, T: Send> Sync for Recv {} + /// Max number of receivers. Reserve space to lock. const MAX_RECEIVERS: usize = usize::MAX >> 2; -const CLOSED: usize = 1; -const WRITER: usize = 2; -const READER: usize = 4; /// Create a bounded, multi-producer, multi-consumer channel where each sent /// value is broadcasted to all active receivers. @@ -382,14 +410,12 @@ pub fn channel(mut capacity: usize) -> (Sender, Receiver) { let mut buffer = Vec::with_capacity(capacity); for i in 0..capacity { - buffer.push(Slot { + buffer.push(RwLock::new(Slot { rem: AtomicUsize::new(0), - lock: AtomicUsize::new(0), - write: Write { - pos: UnsafeCell::new((i as u64).wrapping_sub(capacity as u64)), - val: UnsafeCell::new(None), - }, - }); + pos: (i as u64).wrapping_sub(capacity as u64), + closed: false, + val: UnsafeCell::new(None), + })); } let shared = Arc::new(Shared { @@ -399,20 +425,15 @@ pub fn channel(mut capacity: usize) -> (Sender, Receiver) { pos: 0, rx_cnt: 1, closed: false, + waiters: LinkedList::new(), }), - condvar: Condvar::new(), - wait_stack: AtomicPtr::new(ptr::null_mut()), num_tx: AtomicUsize::new(1), }); let rx = Receiver { shared: shared.clone(), next: 0, - wait: Arc::new(WaitNode { - queued: AtomicBool::new(false), - waker: AtomicWaker::new(), - next: UnsafeCell::new(ptr::null()), - }), + waiter: None, }; let tx = Sender { shared }; @@ -522,11 +543,7 @@ impl Sender { Receiver { shared, next, - wait: Arc::new(WaitNode { - queued: AtomicBool::new(false), - waker: AtomicWaker::new(), - next: UnsafeCell::new(ptr::null()), - }), + waiter: None, } } @@ -587,71 +604,47 @@ impl Sender { tail.pos = tail.pos.wrapping_add(1); // Get the slot - let slot = &self.shared.buffer[idx]; - - // Acquire the write lock - let mut prev = slot.lock.fetch_or(WRITER, SeqCst); - - while prev & !WRITER != 0 { - // Concurrent readers, we must go to sleep - tail = self.shared.condvar.wait(tail).unwrap(); - - prev = slot.lock.load(SeqCst); - - if prev & WRITER == 0 { - // The writer lock bit was cleared while this thread was - // sleeping. This can only happen if a newer write happened on - // this slot by another thread. Bail early as an optimization, - // there is nothing left to do. - return Ok(rem); - } - } - - if tail.pos.wrapping_sub(pos) > self.shared.buffer.len() as u64 { - // There is a newer pending write to the same slot. - return Ok(rem); - } + let mut slot = self.shared.buffer[idx].write().unwrap(); - // Slot lock acquired - slot.write.pos.with_mut(|ptr| unsafe { *ptr = pos }); + // Track the position + slot.pos = pos; // Set remaining receivers - slot.rem.store(rem, SeqCst); + slot.rem.with_mut(|v| *v = rem); // Set the closed bit if the value is `None`; otherwise write the value if value.is_none() { tail.closed = true; - slot.lock.store(CLOSED, SeqCst); + slot.closed = true; } else { - slot.write.val.with_mut(|ptr| unsafe { *ptr = value }); - slot.lock.store(0, SeqCst); + slot.val.with_mut(|ptr| unsafe { *ptr = value }); } + // Release the slot lock before notifying the receivers. + drop(slot); + + tail.notify_rx(); + // Release the mutex. This must happen after the slot lock is released, // otherwise the writer lock bit could be cleared while another thread // is in the critical section. drop(tail); - // Notify waiting receivers - self.notify_rx(); - Ok(rem) } +} - fn notify_rx(&self) { - let mut curr = self.shared.wait_stack.swap(ptr::null_mut(), SeqCst) as *const WaitNode; - - while !curr.is_null() { - let waiter = unsafe { Arc::from_raw(curr) }; - - // Update `curr` before toggling `queued` and waking - curr = waiter.next.with(|ptr| unsafe { *ptr }); +impl Tail { + fn notify_rx(&mut self) { + while let Some(mut waiter) = self.waiters.pop_back() { + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; - // Unset queued - waiter.queued.store(false, SeqCst); + assert!(waiter.queued); + waiter.queued = false; - // Wake - waiter.waker.wake(); + let waker = waiter.waker.take().unwrap(); + waker.wake(); } } } @@ -675,81 +668,119 @@ impl Drop for Sender { impl Receiver { /// Locks the next value if there is one. - /// - /// The caller is responsible for unlocking - fn recv_ref(&mut self, spin: bool) -> Result, TryRecvError> { + fn recv_ref( + &mut self, + waiter: Option<(&UnsafeCell, &Waker)>, + ) -> Result, TryRecvError> { let idx = (self.next & self.shared.mask as u64) as usize; // The slot holding the next value to read - let slot = &self.shared.buffer[idx]; + let mut slot = self.shared.buffer[idx].read().unwrap(); - // Lock the slot - if !slot.try_rx_lock() { - if spin { - while !slot.try_rx_lock() { - spin_loop_hint(); - } - } else { + if slot.pos != self.next { + let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); + + // The receiver has read all current values in the channel and there + // is no waiter to register + if waiter.is_none() && next_pos == self.next { return Err(TryRecvError::Empty); } - } - let guard = RecvGuard { - slot, - tail: &self.shared.tail, - condvar: &self.shared.condvar, - }; + // Release the `slot` lock before attempting to acquire the `tail` + // lock. This is required because `send2` acquires the tail lock + // first followed by the slot lock. Acquiring the locks in reverse + // order here would result in a potential deadlock: `recv_ref` + // acquires the `slot` lock and attempts to acquire the `tail` lock + // while `send2` acquired the `tail` lock and attempts to acquire + // the slot lock. + drop(slot); + + let mut tail = self.shared.tail.lock().unwrap(); + + // Acquire slot lock again + slot = self.shared.buffer[idx].read().unwrap(); + + // Make sure the position did not change. This could happen in the + // unlikely event that the buffer is wrapped between dropping the + // read lock and acquiring the tail lock. + if slot.pos != self.next { + let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); + + if next_pos == self.next { + // Store the waker + if let Some((waiter, waker)) = waiter { + // Safety: called while locked. + unsafe { + // Only queue if not already queued + waiter.with_mut(|ptr| { + // If there is no waker **or** if the currently + // stored waker references a **different** task, + // track the tasks' waker to be notified on + // receipt of a new value. + match (*ptr).waker { + Some(ref w) if w.will_wake(waker) => {} + _ => { + (*ptr).waker = Some(waker.clone()); + } + } + + if !(*ptr).queued { + (*ptr).queued = true; + tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr)); + } + }); + } + } + + return Err(TryRecvError::Empty); + } - if guard.pos() != self.next { - let pos = guard.pos(); + // At this point, the receiver has lagged behind the sender by + // more than the channel capacity. The receiver will attempt to + // catch up by skipping dropped messages and setting the + // internal cursor to the **oldest** message stored by the + // channel. + // + // However, finding the oldest position is a bit more + // complicated than `tail-position - buffer-size`. When + // the channel is closed, the tail position is incremented to + // signal a new `None` message, but `None` is not stored in the + // channel itself (see issue #2425 for why). + // + // To account for this, if the channel is closed, the tail + // position is decremented by `buffer-size + 1`. + let mut adjust = 0; + if tail.closed { + adjust = 1 + } + let next = tail + .pos + .wrapping_sub(self.shared.buffer.len() as u64 + adjust); - // The receiver has read all current values in the channel - if pos.wrapping_add(self.shared.buffer.len() as u64) == self.next { - guard.drop_no_rem_dec(); - return Err(TryRecvError::Empty); - } + let missed = next.wrapping_sub(self.next); - let tail = self.shared.tail.lock().unwrap(); + drop(tail); - // `tail.pos` points to the slot that the **next** send writes to. If - // the channel is closed, the previous slot is the oldest value. - let mut adjust = 0; - if tail.closed { - adjust = 1 - } - let next = tail - .pos - .wrapping_sub(self.shared.buffer.len() as u64 + adjust); + // The receiver is slow but no values have been missed + if missed == 0 { + self.next = self.next.wrapping_add(1); - let missed = next.wrapping_sub(self.next); + return Ok(RecvGuard { slot }); + } - drop(tail); + self.next = next; - // The receiver is slow but no values have been missed - if missed == 0 { - self.next = self.next.wrapping_add(1); - return Ok(guard); + return Err(TryRecvError::Lagged(missed)); } - - guard.drop_no_rem_dec(); - self.next = next; - - return Err(TryRecvError::Lagged(missed)); } self.next = self.next.wrapping_add(1); - // If the `CLOSED` bit it set on the slot, the channel is closed - // - // `try_rx_lock` could check for this and bail early. If it's return - // value was changed to represent the state of the lock, it could - // match on being closed, empty, or available for reading. - if slot.lock.load(SeqCst) & CLOSED == CLOSED { - guard.drop_no_rem_dec(); + if slot.closed { return Err(TryRecvError::Closed); } - Ok(guard) + Ok(RecvGuard { slot }) } } @@ -777,6 +808,7 @@ where /// receive, `Err(TryRecvError::Empty)` is returned. /// /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv /// [`Receiver`]: crate::sync::broadcast::Receiver /// /// # Examples @@ -797,22 +829,59 @@ where /// } /// ``` pub fn try_recv(&mut self) -> Result { - let guard = self.recv_ref(false)?; + let guard = self.recv_ref(None)?; guard.clone_value().ok_or(TryRecvError::Closed) } - #[doc(hidden)] // TODO: document + #[doc(hidden)] + #[deprecated(since = "0.2.21", note = "use async fn recv()")] pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Some(value) = ok_empty(self.try_recv())? { - return Poll::Ready(Ok(value)); + use Poll::{Pending, Ready}; + + // The borrow checker prohibits calling `self.poll_ref` while passing in + // a mutable ref to a field (as it should). To work around this, + // `waiter` is first *removed* from `self` then `poll_recv` is called. + // + // However, for safety, we must ensure that `waiter` is **not** dropped. + // It could be contained in the intrusive linked list. The `Receiver` + // drop implementation handles cleanup. + // + // The guard pattern is used to ensure that, on return, even due to + // panic, the waiter node is replaced on `self`. + + struct Guard<'a, T> { + waiter: Option>>>, + receiver: &'a mut Receiver, } - self.register_waker(cx.waker()); + impl<'a, T> Drop for Guard<'a, T> { + fn drop(&mut self) { + self.receiver.waiter = self.waiter.take(); + } + } - if let Some(value) = ok_empty(self.try_recv())? { - Poll::Ready(Ok(value)) - } else { - Poll::Pending + let waiter = self.waiter.take().or_else(|| { + Some(Box::pin(UnsafeCell::new(Waiter { + queued: false, + waker: None, + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + }))) + }); + + let guard = Guard { + waiter, + receiver: self, + }; + let res = guard + .receiver + .recv_ref(Some((&guard.waiter.as_ref().unwrap(), cx.waker()))); + + match res { + Ok(guard) => Ready(guard.clone_value().ok_or(RecvError::Closed)), + Err(TryRecvError::Closed) => Ready(Err(RecvError::Closed)), + Err(TryRecvError::Lagged(n)) => Ready(Err(RecvError::Lagged(n))), + Err(TryRecvError::Empty) => Pending, } } @@ -881,44 +950,14 @@ where /// assert_eq!(30, rx.recv().await.unwrap()); /// } pub async fn recv(&mut self) -> Result { - use crate::future::poll_fn; - - poll_fn(|cx| self.poll_recv(cx)).await - } - - fn register_waker(&self, cx: &Waker) { - self.wait.waker.register_by_ref(cx); - - if !self.wait.queued.load(SeqCst) { - // Set `queued` before queuing. - self.wait.queued.store(true, SeqCst); - - let mut curr = self.shared.wait_stack.load(SeqCst); - - // The ref count is decremented in `notify_rx` when all nodes are - // removed from the waiter stack. - let node = Arc::into_raw(self.wait.clone()) as *mut _; - - loop { - // Safety: `queued == false` means the caller has exclusive - // access to `self.wait.next`. - self.wait.next.with_mut(|ptr| unsafe { *ptr = curr }); - - let res = self - .shared - .wait_stack - .compare_exchange(curr, node, SeqCst, SeqCst); - - match res { - Ok(_) => return, - Err(actual) => curr = actual, - } - } - } + let fut = Recv::<_, T>::new(Borrow(self)); + fut.await } } #[cfg(feature = "stream")] +#[doc(hidden)] +#[deprecated(since = "0.2.21", note = "use `into_stream()`")] impl crate::stream::Stream for Receiver where T: Clone, @@ -929,6 +968,7 @@ where mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { + #[allow(deprecated)] self.poll_recv(cx).map(|v| match v { Ok(v) => Some(Ok(v)), lag @ Err(RecvError::Lagged(_)) => Some(lag), @@ -941,13 +981,30 @@ impl Drop for Receiver { fn drop(&mut self) { let mut tail = self.shared.tail.lock().unwrap(); + if let Some(waiter) = &self.waiter { + // safety: tail lock is held + let queued = waiter.with(|ptr| unsafe { (*ptr).queued }); + + if queued { + // Remove the node + // + // safety: tail lock is held and the wait node is verified to be in + // the list. + unsafe { + waiter.with_mut(|ptr| { + tail.waiters.remove((&mut *ptr).into()); + }); + } + } + } + tail.rx_cnt -= 1; let until = tail.pos; drop(tail); while self.next != until { - match self.recv_ref(true) { + match self.recv_ref(None) { Ok(_) => {} // The channel is closed Err(TryRecvError::Closed) => break, @@ -960,105 +1017,198 @@ impl Drop for Receiver { } } -impl Drop for Shared { - fn drop(&mut self) { - // Clear the wait stack - let mut curr = self.wait_stack.with_mut(|ptr| *ptr as *const WaitNode); +impl Recv +where + R: AsMut>, +{ + fn new(receiver: R) -> Recv { + Recv { + receiver, + waiter: UnsafeCell::new(Waiter { + queued: false, + waker: None, + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + }), + _p: std::marker::PhantomData, + } + } - while !curr.is_null() { - let waiter = unsafe { Arc::from_raw(curr) }; - curr = waiter.next.with(|ptr| unsafe { *ptr }); + /// A custom `project` implementation is used in place of `pin-project-lite` + /// as a custom drop implementation is needed. + fn project(self: Pin<&mut Self>) -> (&mut Receiver, &UnsafeCell) { + unsafe { + // Safety: Receiver is Unpin + is_unpin::<&mut Receiver>(); + + let me = self.get_unchecked_mut(); + (me.receiver.as_mut(), &me.waiter) } } } -impl fmt::Debug for Sender { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "broadcast::Sender") +impl Future for Recv +where + R: AsMut>, + T: Clone, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let (receiver, waiter) = self.project(); + + let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { + Ok(value) => value, + Err(TryRecvError::Empty) => return Poll::Pending, + Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))), + Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)), + }; + + Poll::Ready(guard.clone_value().ok_or(RecvError::Closed)) } } -impl fmt::Debug for Receiver { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "broadcast::Receiver") +cfg_stream! { + use futures_core::Stream; + + impl Receiver { + /// Convert the receiver into a `Stream`. + /// + /// The conversion allows using `Receiver` with APIs that require stream + /// values. + /// + /// # Examples + /// + /// ``` + /// use tokio::stream::StreamExt; + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = broadcast::channel(128); + /// + /// tokio::spawn(async move { + /// for i in 0..10_i32 { + /// tx.send(i).unwrap(); + /// } + /// }); + /// + /// // Streams must be pinned to iterate. + /// tokio::pin! { + /// let stream = rx + /// .into_stream() + /// .filter(Result::is_ok) + /// .map(Result::unwrap) + /// .filter(|v| v % 2 == 0) + /// .map(|v| v + 1); + /// } + /// + /// while let Some(i) = stream.next().await { + /// println!("{}", i); + /// } + /// } + /// ``` + pub fn into_stream(self) -> impl Stream> { + Recv::new(Borrow(self)) + } } -} -impl Slot { - /// Tries to lock the slot for a receiver. If `false`, then a sender holds the - /// lock and the calling task will be notified once the sender has released - /// the lock. - fn try_rx_lock(&self) -> bool { - let mut curr = self.lock.load(SeqCst); - - loop { - if curr & WRITER == WRITER { - // Locked by sender - return false; - } + impl Stream for Recv + where + R: AsMut>, + T: Clone, + { + type Item = Result; - // Only increment (by `READER`) if the `WRITER` bit is not set. - let res = self - .lock - .compare_exchange(curr, curr + READER, SeqCst, SeqCst); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let (receiver, waiter) = self.project(); - match res { - Ok(_) => return true, - Err(actual) => curr = actual, - } + let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { + Ok(value) => value, + Err(TryRecvError::Empty) => return Poll::Pending, + Err(TryRecvError::Lagged(n)) => return Poll::Ready(Some(Err(RecvError::Lagged(n)))), + Err(TryRecvError::Closed) => return Poll::Ready(None), + }; + + Poll::Ready(guard.clone_value().map(Ok)) } } +} - fn rx_unlock(&self, tail: &Mutex, condvar: &Condvar, rem_dec: bool) { - if rem_dec { - // Decrement the remaining counter - if 1 == self.rem.fetch_sub(1, SeqCst) { - // Last receiver, drop the value - self.write.val.with_mut(|ptr| unsafe { *ptr = None }); +impl Drop for Recv +where + R: AsMut>, +{ + fn drop(&mut self) { + // Acquire the tail lock. This is required for safety before accessing + // the waiter node. + let mut tail = self.receiver.as_mut().shared.tail.lock().unwrap(); + + // safety: tail lock is held + let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued }); + + if queued { + // Remove the node + // + // safety: tail lock is held and the wait node is verified to be in + // the list. + unsafe { + self.waiter.with_mut(|ptr| { + tail.waiters.remove((&mut *ptr).into()); + }); } } + } +} - if WRITER == self.lock.fetch_sub(READER, SeqCst) - READER { - // First acquire the lock to make sure our sender is waiting on the - // condition variable, otherwise the notification could be lost. - mem::drop(tail.lock().unwrap()); - // Wake up senders - condvar.notify_all(); - } +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + type Handle = NonNull; + type Target = Waiter; + + fn as_raw(handle: &NonNull) -> NonNull { + *handle + } + + unsafe fn from_raw(ptr: NonNull) -> NonNull { + ptr + } + + unsafe fn pointers(mut target: NonNull) -> NonNull> { + NonNull::from(&mut target.as_mut().pointers) } } -impl<'a, T> RecvGuard<'a, T> { - fn pos(&self) -> u64 { - self.slot.write.pos.with(|ptr| unsafe { *ptr }) +impl fmt::Debug for Sender { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "broadcast::Sender") } +} +impl fmt::Debug for Receiver { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "broadcast::Receiver") + } +} + +impl<'a, T> RecvGuard<'a, T> { fn clone_value(&self) -> Option where T: Clone, { - self.slot.write.val.with(|ptr| unsafe { (*ptr).clone() }) - } - - fn drop_no_rem_dec(self) { - self.slot.rx_unlock(self.tail, self.condvar, false); - - mem::forget(self); + self.slot.val.with(|ptr| unsafe { (*ptr).clone() }) } } impl<'a, T> Drop for RecvGuard<'a, T> { fn drop(&mut self) { - self.slot.rx_unlock(self.tail, self.condvar, true) - } -} - -fn ok_empty(res: Result) -> Result, RecvError> { - match res { - Ok(value) => Ok(Some(value)), - Err(TryRecvError::Empty) => Ok(None), - Err(TryRecvError::Lagged(n)) => Err(RecvError::Lagged(n)), - Err(TryRecvError::Closed) => Err(RecvError::Closed), + // Decrement the remaining counter + if 1 == self.slot.rem.fetch_sub(1, SeqCst) { + // Safety: Last receiver, drop the value + self.slot.val.with_mut(|ptr| unsafe { *ptr = None }); + } } } @@ -1084,3 +1234,5 @@ impl fmt::Display for TryRecvError { } impl std::error::Error for TryRecvError {} + +fn is_unpin() {} diff --git a/tokio/src/sync/cancellation_token.rs b/tokio/src/sync/cancellation_token.rs new file mode 100644 index 00000000000..d60d8e0202c --- /dev/null +++ b/tokio/src/sync/cancellation_token.rs @@ -0,0 +1,861 @@ +//! An asynchronously awaitable `CancellationToken`. +//! The token allows to signal a cancellation request to one or more tasks. + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; +use crate::util::intrusive_double_linked_list::{LinkedList, ListNode}; + +use core::future::Future; +use core::pin::Pin; +use core::ptr::NonNull; +use core::sync::atomic::Ordering; +use core::task::{Context, Poll, Waker}; + +/// A token which can be used to signal a cancellation request to one or more +/// tasks. +/// +/// Tasks can call [`CancellationToken::cancelled()`] in order to +/// obtain a Future which will be resolved when cancellation is requested. +/// +/// Cancellation can be requested through the [`CancellationToken::cancel`] method. +/// +/// # Examples +/// +/// ```ignore +/// use tokio::select; +/// use tokio::scope::CancellationToken; +/// +/// #[tokio::main] +/// async fn main() { +/// let token = CancellationToken::new(); +/// let cloned_token = token.clone(); +/// +/// let join_handle = tokio::spawn(async move { +/// // Wait for either cancellation or a very long time +/// select! { +/// _ = cloned_token.cancelled() => { +/// // The token was cancelled +/// 5 +/// } +/// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => { +/// 99 +/// } +/// } +/// }); +/// +/// tokio::spawn(async move { +/// tokio::time::delay_for(std::time::Duration::from_millis(10)).await; +/// token.cancel(); +/// }); +/// +/// assert_eq!(5, join_handle.await.unwrap()); +/// } +/// ``` +pub struct CancellationToken { + inner: NonNull, +} + +// Safety: The CancellationToken is thread-safe and can be moved between threads, +// since all methods are internally synchronized. +unsafe impl Send for CancellationToken {} +unsafe impl Sync for CancellationToken {} + +/// A Future that is resolved once the corresponding [`CancellationToken`] +/// was cancelled +#[must_use = "futures do nothing unless polled"] +pub struct WaitForCancellationFuture<'a> { + /// The CancellationToken that is associated with this WaitForCancellationFuture + cancellation_token: Option<&'a CancellationToken>, + /// Node for waiting at the cancellation_token + wait_node: ListNode, + /// Whether this future was registered at the token yet as a waiter + is_registered: bool, +} + +// Safety: Futures can be sent between threads as long as the underlying +// cancellation_token is thread-safe (Sync), +// which allows to poll/register/unregister from a different thread. +unsafe impl<'a> Send for WaitForCancellationFuture<'a> {} + +// ===== impl CancellationToken ===== + +impl core::fmt::Debug for CancellationToken { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("CancellationToken") + .field("is_cancelled", &self.is_cancelled()) + .finish() + } +} + +impl Clone for CancellationToken { + fn clone(&self) -> Self { + // Safety: The state inside a `CancellationToken` is always valid, since + // is reference counted + let inner = self.state(); + + // Tokens are cloned by increasing their refcount + let current_state = inner.snapshot(); + inner.increment_refcount(current_state); + + CancellationToken { inner: self.inner } + } +} + +impl Drop for CancellationToken { + fn drop(&mut self) { + let token_state_pointer = self.inner; + + // Safety: The state inside a `CancellationToken` is always valid, since + // is reference counted + let inner = unsafe { &mut *self.inner.as_ptr() }; + + let mut current_state = inner.snapshot(); + + // We need to safe the parent, since the state might be released by the + // next call + let parent = inner.parent; + + // Drop our own refcount + current_state = inner.decrement_refcount(current_state); + + // If this was the last reference, unregister from the parent + if current_state.refcount == 0 { + if let Some(mut parent) = parent { + // Safety: Since we still retain a reference on the parent, it must be valid. + let parent = unsafe { parent.as_mut() }; + parent.unregister_child(token_state_pointer, current_state); + } + } + } +} + +impl CancellationToken { + /// Creates a new CancellationToken in the non-cancelled state. + pub fn new() -> CancellationToken { + let state = Box::new(CancellationTokenState::new( + None, + StateSnapshot { + cancel_state: CancellationState::NotCancelled, + has_parent_ref: false, + refcount: 1, + }, + )); + + // Safety: We just created the Box. The pointer is guaranteed to be + // not null + CancellationToken { + inner: unsafe { NonNull::new_unchecked(Box::into_raw(state)) }, + } + } + + /// Returns a reference to the utilized `CancellationTokenState`. + fn state(&self) -> &CancellationTokenState { + // Safety: The state inside a `CancellationToken` is always valid, since + // is reference counted + unsafe { &*self.inner.as_ptr() } + } + + /// Creates a `CancellationToken` which will get cancelled whenever the + /// current token gets cancelled. + /// + /// If the current token is already cancelled, the child token will get + /// returned in cancelled state. + /// + /// # Examples + /// + /// ```ignore + /// use tokio::select; + /// use tokio::scope::CancellationToken; + /// + /// #[tokio::main] + /// async fn main() { + /// let token = CancellationToken::new(); + /// let child_token = token.child_token(); + /// + /// let join_handle = tokio::spawn(async move { + /// // Wait for either cancellation or a very long time + /// select! { + /// _ = child_token.cancelled() => { + /// // The token was cancelled + /// 5 + /// } + /// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => { + /// 99 + /// } + /// } + /// }); + /// + /// tokio::spawn(async move { + /// tokio::time::delay_for(std::time::Duration::from_millis(10)).await; + /// token.cancel(); + /// }); + /// + /// assert_eq!(5, join_handle.await.unwrap()); + /// } + /// ``` + pub fn child_token(&self) -> CancellationToken { + let inner = self.state(); + + // Increment the refcount of this token. It will be referenced by the + // child, independent of whether the child is immediately cancelled or + // not. + let _current_state = inner.increment_refcount(inner.snapshot()); + + let mut unpacked_child_state = StateSnapshot { + has_parent_ref: true, + refcount: 1, + cancel_state: CancellationState::NotCancelled, + }; + let mut child_token_state = Box::new(CancellationTokenState::new( + Some(self.inner), + unpacked_child_state, + )); + + { + let mut guard = inner.synchronized.lock().unwrap(); + if guard.is_cancelled { + // This task was already cancelled. In this case we should not + // insert the child into the list, since it would never get removed + // from the list. + (*child_token_state.synchronized.lock().unwrap()).is_cancelled = true; + unpacked_child_state.cancel_state = CancellationState::Cancelled; + // Since it's not in the list, the parent doesn't need to retain + // a reference to it. + unpacked_child_state.has_parent_ref = false; + child_token_state + .state + .store(unpacked_child_state.pack(), Ordering::SeqCst); + } else { + if let Some(mut first_child) = guard.first_child { + child_token_state.from_parent.next_peer = Some(first_child); + // Safety: We manipulate other child task inside the Mutex + // and retain a parent reference on it. The child token can't + // get invalidated while the Mutex is held. + unsafe { + first_child.as_mut().from_parent.prev_peer = + Some((&mut *child_token_state).into()) + }; + } + guard.first_child = Some((&mut *child_token_state).into()); + } + }; + + let child_token_ptr = Box::into_raw(child_token_state); + // Safety: We just created the pointer from a `Box` + CancellationToken { + inner: unsafe { NonNull::new_unchecked(child_token_ptr) }, + } + } + + /// Cancel the [`CancellationToken`] and all child tokens which had been + /// derived from it. + /// + /// This will wake up all tasks which are waiting for cancellation. + pub fn cancel(&self) { + self.state().cancel(); + } + + /// Returns `true` if the `CancellationToken` had been cancelled + pub fn is_cancelled(&self) -> bool { + self.state().is_cancelled() + } + + /// Returns a `Future` that gets fulfilled when cancellation is requested. + pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { + WaitForCancellationFuture { + cancellation_token: Some(self), + wait_node: ListNode::new(WaitQueueEntry::new()), + is_registered: false, + } + } + + unsafe fn register( + &self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + self.state().register(wait_node, cx) + } + + fn check_for_cancellation( + &self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + self.state().check_for_cancellation(wait_node, cx) + } + + fn unregister(&self, wait_node: &mut ListNode) { + self.state().unregister(wait_node) + } +} + +// ===== impl WaitForCancellationFuture ===== + +impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitForCancellationFuture").finish() + } +} + +impl<'a> Future for WaitForCancellationFuture<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + // Safety: We do not move anything out of `WaitForCancellationFuture` + let mut_self: &mut WaitForCancellationFuture<'_> = unsafe { Pin::get_unchecked_mut(self) }; + + let cancellation_token = mut_self + .cancellation_token + .expect("polled WaitForCancellationFuture after completion"); + + let poll_res = if !mut_self.is_registered { + // Safety: The `ListNode` is pinned through the Future, + // and we will unregister it in `WaitForCancellationFuture::drop` + // before the Future is dropped and the memory reference is invalidated. + unsafe { cancellation_token.register(&mut mut_self.wait_node, cx) } + } else { + cancellation_token.check_for_cancellation(&mut mut_self.wait_node, cx) + }; + + if let Poll::Ready(()) = poll_res { + // The cancellation_token was signalled + mut_self.cancellation_token = None; + // A signalled Token means the Waker won't be enqueued anymore + mut_self.is_registered = false; + mut_self.wait_node.task = None; + } else { + // This `Future` and its stored `Waker` stay registered at the + // `CancellationToken` + mut_self.is_registered = true; + } + + poll_res + } +} + +impl<'a> Drop for WaitForCancellationFuture<'a> { + fn drop(&mut self) { + // If this WaitForCancellationFuture has been polled and it was added to the + // wait queue at the cancellation_token, it must be removed before dropping. + // Otherwise the cancellation_token would access invalid memory. + if let Some(token) = self.cancellation_token { + if self.is_registered { + token.unregister(&mut self.wait_node); + } + } + } +} + +/// Tracks how the future had interacted with the [`CancellationToken`] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum PollState { + /// The task has never interacted with the [`CancellationToken`]. + New, + /// The task was added to the wait queue at the [`CancellationToken`]. + Waiting, + /// The task has been polled to completion. + Done, +} + +/// Tracks the WaitForCancellationFuture waiting state. +/// Access to this struct is synchronized through the mutex in the CancellationToken. +struct WaitQueueEntry { + /// The task handle of the waiting task + task: Option, + // Current polling state. This state is only updated inside the Mutex of + // the CancellationToken. + state: PollState, +} + +impl WaitQueueEntry { + /// Creates a new WaitQueueEntry + fn new() -> WaitQueueEntry { + WaitQueueEntry { + task: None, + state: PollState::New, + } + } +} + +struct SynchronizedState { + waiters: LinkedList, + first_child: Option>, + is_cancelled: bool, +} + +impl SynchronizedState { + fn new() -> Self { + Self { + waiters: LinkedList::new(), + first_child: None, + is_cancelled: false, + } + } +} + +/// Information embedded in child tokens which is synchronized through the Mutex +/// in their parent. +struct SynchronizedThroughParent { + next_peer: Option>, + prev_peer: Option>, +} + +/// Possible states of a `CancellationToken` +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum CancellationState { + NotCancelled = 0, + Cancelling = 1, + Cancelled = 2, +} + +impl CancellationState { + fn pack(self) -> usize { + self as usize + } + + fn unpack(value: usize) -> Self { + match value { + 0 => CancellationState::NotCancelled, + 1 => CancellationState::Cancelling, + 2 => CancellationState::Cancelled, + _ => unreachable!("Invalid value"), + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +struct StateSnapshot { + /// The amount of references to this particular CancellationToken. + /// `CancellationToken` structs hold these references to a `CancellationTokenState`. + /// Also the state is referenced by the state of each child. + refcount: usize, + /// Whether the state is still referenced by it's parent and can therefore + /// not be freed. + has_parent_ref: bool, + /// Whether the token is cancelled + cancel_state: CancellationState, +} + +impl StateSnapshot { + /// Packs the snapshot into a `usize` + fn pack(self) -> usize { + self.refcount << 3 | if self.has_parent_ref { 4 } else { 0 } | self.cancel_state.pack() + } + + /// Unpacks the snapshot from a `usize` + fn unpack(value: usize) -> Self { + let refcount = value >> 3; + let has_parent_ref = value & 4 != 0; + let cancel_state = CancellationState::unpack(value & 0x03); + + StateSnapshot { + refcount, + has_parent_ref, + cancel_state, + } + } + + /// Whether this `CancellationTokenState` is still referenced by any + /// `CancellationToken`. + fn has_refs(&self) -> bool { + self.refcount != 0 || self.has_parent_ref + } +} + +/// The maximum permitted amount of references to a CancellationToken. This +/// is derived from the intent to never use more than 32bit in the `Snapshot`. +const MAX_REFS: u32 = (std::u32::MAX - 7) >> 3; + +/// Internal state of the `CancellationToken` pair above +struct CancellationTokenState { + state: AtomicUsize, + parent: Option>, + from_parent: SynchronizedThroughParent, + synchronized: Mutex, +} + +impl CancellationTokenState { + fn new( + parent: Option>, + state: StateSnapshot, + ) -> CancellationTokenState { + CancellationTokenState { + parent, + from_parent: SynchronizedThroughParent { + prev_peer: None, + next_peer: None, + }, + state: AtomicUsize::new(state.pack()), + synchronized: Mutex::new(SynchronizedState::new()), + } + } + + /// Returns a snapshot of the current atomic state of the token + fn snapshot(&self) -> StateSnapshot { + StateSnapshot::unpack(self.state.load(Ordering::SeqCst)) + } + + fn atomic_update_state(&self, mut current_state: StateSnapshot, func: F) -> StateSnapshot + where + F: Fn(StateSnapshot) -> StateSnapshot, + { + let mut current_packed_state = current_state.pack(); + loop { + let next_state = func(current_state); + match self.state.compare_exchange( + current_packed_state, + next_state.pack(), + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => { + return next_state; + } + Err(actual) => { + current_packed_state = actual; + current_state = StateSnapshot::unpack(actual); + } + } + } + } + + fn increment_refcount(&self, current_state: StateSnapshot) -> StateSnapshot { + self.atomic_update_state(current_state, |mut state: StateSnapshot| { + if state.refcount >= MAX_REFS as usize { + eprintln!("[ERROR] Maximum reference count for CancellationToken was exceeded"); + std::process::abort(); + } + state.refcount += 1; + state + }) + } + + fn decrement_refcount(&self, current_state: StateSnapshot) -> StateSnapshot { + let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| { + state.refcount -= 1; + state + }); + + // Drop the State if it is not referenced anymore + if !current_state.has_refs() { + // Safety: `CancellationTokenState` is always stored in refcounted + // Boxes + let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) }; + } + + current_state + } + + fn remove_parent_ref(&self, current_state: StateSnapshot) -> StateSnapshot { + let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| { + state.has_parent_ref = false; + state + }); + + // Drop the State if it is not referenced anymore + if !current_state.has_refs() { + // Safety: `CancellationTokenState` is always stored in refcounted + // Boxes + let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) }; + } + + current_state + } + + /// Unregisters a child from the parent token. + /// The child tokens state is not exactly known at this point in time. + /// If the parent token is cancelled, the child token gets removed from the + /// parents list, and might therefore already have been freed. If the parent + /// token is not cancelled, the child token is still valid. + fn unregister_child( + &mut self, + mut child_state: NonNull, + current_child_state: StateSnapshot, + ) { + let removed_child = { + // Remove the child toke from the parents linked list + let mut guard = self.synchronized.lock().unwrap(); + if !guard.is_cancelled { + // Safety: Since the token was not cancelled, the child must + // still be in the list and valid. + let mut child_state = unsafe { child_state.as_mut() }; + debug_assert!(child_state.snapshot().has_parent_ref); + + if guard.first_child == Some(child_state.into()) { + guard.first_child = child_state.from_parent.next_peer; + } + // Safety: If peers wouldn't be valid anymore, they would try + // to remove themselves from the list. This would require locking + // the Mutex that we currently own. + unsafe { + if let Some(mut prev_peer) = child_state.from_parent.prev_peer { + prev_peer.as_mut().from_parent.next_peer = + child_state.from_parent.next_peer; + } + if let Some(mut next_peer) = child_state.from_parent.next_peer { + next_peer.as_mut().from_parent.prev_peer = + child_state.from_parent.prev_peer; + } + } + child_state.from_parent.prev_peer = None; + child_state.from_parent.next_peer = None; + + // The child is no longer referenced by the parent, since we were able + // to remove its reference from the parents list. + true + } else { + // Do not touch the linked list anymore. If the parent is cancelled + // it will move all childs outside of the Mutex and manipulate + // the pointers there. Manipulating the pointers here too could + // lead to races. Therefore leave them just as as and let the + // parent deal with it. The parent will make sure to retain a + // reference to this state as long as it manipulates the list + // pointers. Therefore the pointers are not dangling. + false + } + }; + + if removed_child { + // If the token removed itself from the parents list, it can reset + // the the parent ref status. If it is isn't able to do so, because the + // parent removed it from the list, there is no need to do this. + // The parent ref acts as as another reference count. Therefore + // removing this reference can free the object. + // Safety: The token was in the list. This means the parent wasn't + // cancelled before, and the token must still be alive. + unsafe { child_state.as_mut().remove_parent_ref(current_child_state) }; + } + + // Decrement the refcount on the parent and free it if necessary + self.decrement_refcount(self.snapshot()); + } + + fn cancel(&self) { + // Move the state of the CancellationToken from `NotCancelled` to `Cancelling` + let mut current_state = self.snapshot(); + + let state_after_cancellation = loop { + if current_state.cancel_state != CancellationState::NotCancelled { + // Another task already initiated the cancellation + return; + } + + let mut next_state = current_state; + next_state.cancel_state = CancellationState::Cancelling; + match self.state.compare_exchange( + current_state.pack(), + next_state.pack(), + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => break next_state, + Err(actual) => current_state = StateSnapshot::unpack(actual), + } + }; + + // This task cancelled the token + + // Take the task list out of the Token + // We do not want to cancel child token inside this lock. If one of the + // child tasks would have additional child tokens, we would recursively + // take locks. + + // Doing this action has an impact if the child token is dropped concurrently: + // It will try to deregister itself from the parent task, but can not find + // itself in the task list anymore. Therefore it needs to assume the parent + // has extracted the list and will process it. It may not modify the list. + // This is OK from a memory safety perspective, since the parent still + // retains a reference to the child task until it finished iterating over + // it. + + let mut first_child = { + let mut guard = self.synchronized.lock().unwrap(); + // Save the cancellation also inside the Mutex + // This allows child tokens which want to detach themselves to detect + // that this is no longer required since the parent cleared the list. + guard.is_cancelled = true; + + // Wakeup all waiters + // This happens inside the lock to make cancellation reliable + // If we would access waiters outside of the lock, the pointers + // may no longer be valid. + // Typically this shouldn't be an issue, since waking a task should + // only move it from the blocked into the ready state and not have + // further side effects. + + // Use a reverse iterator, so that the oldest waiter gets + // scheduled first + guard.waiters.reverse_drain(|waiter| { + // We are not allowed to move the `Waker` out of the list node. + // The `Future` relies on the fact that the old `Waker` stays there + // as long as the `Future` has not completed in order to perform + // the `will_wake()` check. + // Therefore `wake_by_ref` is used instead of `wake()` + if let Some(handle) = &mut waiter.task { + handle.wake_by_ref(); + } + // Mark the waiter to have been removed from the list. + waiter.state = PollState::Done; + }); + + guard.first_child.take() + }; + + while let Some(mut child) = first_child { + // Safety: We know this is a valid pointer since it is in our child pointer + // list. It can't have been freed in between, since we retain a a reference + // to each child. + let mut_child = unsafe { child.as_mut() }; + + // Get the next child and clean up list pointers + first_child = mut_child.from_parent.next_peer; + mut_child.from_parent.prev_peer = None; + mut_child.from_parent.next_peer = None; + + // Cancel the child task + mut_child.cancel(); + + // Drop the parent reference. This `CancellationToken` is not interested + // in interacting with the child anymore. + // This is ONLY allowed once we promised not to touch the state anymore + // after this interaction. + mut_child.remove_parent_ref(mut_child.snapshot()); + } + + // The cancellation has completed + // At this point in time tasks which registered a wait node can be sure + // that this wait node already had been dequeued from the list without + // needing to inspect the list. + self.atomic_update_state(state_after_cancellation, |mut state| { + state.cancel_state = CancellationState::Cancelled; + state + }); + } + + /// Returns `true` if the `CancellationToken` had been cancelled + fn is_cancelled(&self) -> bool { + let current_state = self.snapshot(); + current_state.cancel_state != CancellationState::NotCancelled + } + + /// Registers a waiting task at the `CancellationToken`. + /// Safety: This method is only safe as long as the waiting waiting task + /// will properly unregister the wait node before it gets moved. + unsafe fn register( + &self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + debug_assert_eq!(PollState::New, wait_node.state); + let current_state = self.snapshot(); + + // Perform an optimistic cancellation check before. This is not strictly + // necessary since we also check for cancellation in the Mutex, but + // reduces the necessary work to be performed for tasks which already + // had been cancelled. + if current_state.cancel_state != CancellationState::NotCancelled { + return Poll::Ready(()); + } + + // So far the token is not cancelled. However it could be cancelld before + // we get the chance to store the `Waker`. Therfore we need to check + // for cancellation again inside the mutex. + let mut guard = self.synchronized.lock().unwrap(); + if guard.is_cancelled { + // Cancellation was signalled + wait_node.state = PollState::Done; + Poll::Ready(()) + } else { + // Added the task to the wait queue + wait_node.task = Some(cx.waker().clone()); + wait_node.state = PollState::Waiting; + guard.waiters.add_front(wait_node); + Poll::Pending + } + } + + fn check_for_cancellation( + &self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + debug_assert!( + wait_node.task.is_some(), + "Method can only be called after task had been registered" + ); + + let current_state = self.snapshot(); + + if current_state.cancel_state != CancellationState::NotCancelled { + // If the cancellation had been fully completed we know that our `Waker` + // is no longer registered at the `CancellationToken`. + // Otherwise the cancel call may or may not yet have iterated + // through the waiters list and removed the wait nodes. + // If it hasn't yet, we need to remove it. Otherwise an attempt to + // reuse the `wait_node´ might get freed due to the `WaitForCancellationFuture` + // getting dropped before the cancellation had interacted with it. + if current_state.cancel_state != CancellationState::Cancelled { + self.unregister(wait_node); + } + Poll::Ready(()) + } else { + // Check if we need to swap the `Waker`. This will make the check more + // expensive, since the `Waker` is synchronized through the Mutex. + // If we don't need to perform a `Waker` update, an atomic check for + // cancellation is sufficient. + let need_waker_update = wait_node + .task + .as_ref() + .map(|waker| waker.will_wake(cx.waker())) + .unwrap_or(true); + + if need_waker_update { + let guard = self.synchronized.lock().unwrap(); + if guard.is_cancelled { + // Cancellation was signalled. Since this cancellation signal + // is set inside the Mutex, the old waiter must already have + // been removed from the waiting list + debug_assert_eq!(PollState::Done, wait_node.state); + wait_node.task = None; + Poll::Ready(()) + } else { + // The WaitForCancellationFuture is already in the queue. + // The CancellationToken can't have been cancelled, + // since this would change the is_cancelled flag inside the mutex. + // Therefore we just have to update the Waker. A follow-up + // cancellation will always use the new waker. + wait_node.task = Some(cx.waker().clone()); + Poll::Pending + } + } else { + // Do nothing. If the token gets cancelled, this task will get + // woken again and can fetch the cancellation. + Poll::Pending + } + } + } + + fn unregister(&self, wait_node: &mut ListNode) { + debug_assert!( + wait_node.task.is_some(), + "waiter can not be active without task" + ); + + let mut guard = self.synchronized.lock().unwrap(); + // WaitForCancellationFuture only needs to get removed if it has been added to + // the wait queue of the CancellationToken. + // This has happened in the PollState::Waiting case. + if let PollState::Waiting = wait_node.state { + // Safety: Due to the state, we know that the node must be part + // of the waiter list + if !unsafe { guard.waiters.remove(wait_node) } { + // Panic if the address isn't found. This can only happen if the contract was + // violated, e.g. the WaitQueueEntry got moved after the initial poll. + panic!("Future could not be removed from wait queue"); + } + wait_node.state = PollState::Done; + } + wait_node.task = None; + } +} diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 359b14f5e4b..3d96106d2df 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -32,7 +32,7 @@ //! single producer to a single consumer. This channel is usually used to send //! the result of a computation to a waiter. //! -//! **Example:** using a `oneshot` channel to receive the result of a +//! **Example:** using a [`oneshot` channel][oneshot] to receive the result of a //! computation. //! //! ``` @@ -58,11 +58,12 @@ //! } //! ``` //! -//! Note, if the task produces the the computation result as its final action -//! before terminating, the [`JoinHandle`] can be used to receive the -//! computation result instead of allocating resources for the `oneshot` -//! channel. Awaiting on [`JoinHandle`] returns `Result`. If the task panics, -//! the `Joinhandle` yields `Err` with the panic cause. +//! Note, if the task produces a computation result as its final +//! action before terminating, the [`JoinHandle`] can be used to +//! receive that value instead of allocating resources for the +//! `oneshot` channel. Awaiting on [`JoinHandle`] returns `Result`. If +//! the task panics, the `Joinhandle` yields `Err` with the panic +//! cause. //! //! **Example:** //! @@ -84,6 +85,7 @@ //! } //! ``` //! +//! [oneshot]: oneshot //! [`JoinHandle`]: crate::task::JoinHandle //! //! ## `mpsc` channel @@ -230,9 +232,11 @@ //! } //! ``` //! +//! [mpsc]: mpsc +//! //! ## `broadcast` channel //! -//! The [`broadcast` channel][broadcast] supports sending **many** values from +//! The [`broadcast` channel] supports sending **many** values from //! **many** producers to **many** consumers. Each consumer will receive //! **each** value. This channel can be used to implement "fan out" style //! patterns common with pub / sub or "chat" systems. @@ -265,12 +269,14 @@ //! } //! ``` //! +//! [`broadcast` channel]: crate::sync::broadcast +//! //! ## `watch` channel //! -//! The [`watch` channel][watch] supports sending **many** values from a -//! **single** producer to **many** consumers. However, only the **most recent** -//! value is stored in the channel. Consumers are notified when a new value is -//! sent, but there is no guarantee that consumers will see **all** values. +//! The [`watch` channel] supports sending **many** values from a **single** +//! producer to **many** consumers. However, only the **most recent** value is +//! stored in the channel. Consumers are notified when a new value is sent, but +//! there is no guarantee that consumers will see **all** values. //! //! The [`watch` channel] is similar to a [`broadcast` channel] with capacity 1. //! @@ -278,9 +284,9 @@ //! changes or signalling program state changes, such as transitioning to //! shutdown. //! -//! **Example:** use a `watch` channel to notify tasks of configuration changes. -//! In this example, a configuration file is checked periodically. When the file -//! changes, the configuration changes are signalled to consumers. +//! **Example:** use a [`watch` channel] to notify tasks of configuration +//! changes. In this example, a configuration file is checked periodically. When +//! the file changes, the configuration changes are signalled to consumers. //! //! ``` //! use tokio::sync::watch; @@ -393,6 +399,9 @@ //! } //! ``` //! +//! [`watch` channel]: crate::sync::watch +//! [`broadcast` channel]: crate::sync::broadcast +//! //! # State synchronization //! //! The remaining synchronization primitives focus on synchronizing state. @@ -400,23 +409,23 @@ //! operate in a similar way as their `std` counterparts parts but will wait //! asynchronously instead of blocking the thread. //! -//! * [`Barrier`][Barrier] Ensures multiple tasks will wait for each other to +//! * [`Barrier`](Barrier) Ensures multiple tasks will wait for each other to //! reach a point in the program, before continuing execution all together. //! -//! * [`Mutex`][Mutex] Mutual Exclusion mechanism, which ensures that at most +//! * [`Mutex`](Mutex) Mutual Exclusion mechanism, which ensures that at most //! one thread at a time is able to access some data. //! -//! * [`Notify`][Notify] Basic task notification. `Notify` supports notifying a +//! * [`Notify`](Notify) Basic task notification. `Notify` supports notifying a //! receiving task without sending data. In this case, the task wakes up and //! resumes processing. //! -//! * [`RwLock`][RwLock] Provides a mutual exclusion mechanism which allows +//! * [`RwLock`](RwLock) Provides a mutual exclusion mechanism which allows //! multiple readers at the same time, while allowing only one writer at a //! time. In some cases, this can be more efficient than a mutex. //! -//! * [`Semaphore`][Semaphore] Limits the amount of concurrency. A semaphore +//! * [`Semaphore`](Semaphore) Limits the amount of concurrency. A semaphore //! holds a number of permits, which tasks may request in order to enter a -//! critical section. Semaphores are useful for implementing limiting of +//! critical section. Semaphores are useful for implementing limiting or //! bounding of any kind. cfg_sync! { @@ -425,6 +434,11 @@ cfg_sync! { pub mod broadcast; + cfg_unstable! { + mod cancellation_token; + pub use cancellation_token::{CancellationToken, WaitForCancellationFuture}; + } + pub mod mpsc; mod mutex; diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 34663957883..148ee3ad766 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -277,7 +277,7 @@ where use super::block::Read::*; // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); self.inner.rx_fields.with_mut(|rx_fields_ptr| { let rx_fields = unsafe { &mut *rx_fields_ptr }; @@ -287,6 +287,7 @@ where match rx_fields.list.pop(&self.inner.tx) { Some(Value(value)) => { self.inner.semaphore.add_permit(); + coop.made_progress(); return Ready(Some(value)); } Some(Closed) => { @@ -297,6 +298,7 @@ where // which ensures that if dropping the tx handle is // visible, then all messages sent are also visible. assert!(self.inner.semaphore.is_idle()); + coop.made_progress(); return Ready(None); } None => {} // fall through @@ -314,6 +316,7 @@ where try_recv!(); if rx_fields.rx_closed && self.inner.semaphore.is_idle() { + coop.made_progress(); Ready(None) } else { Pending @@ -439,11 +442,15 @@ impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) { permit: &mut Permit, ) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); permit .poll_acquire(cx, 1, &self.0) .map_err(|_| ClosedError::new()) + .map(move |r| { + coop.made_progress(); + r + }) } fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> { diff --git a/tokio/src/sync/mpsc/mod.rs b/tokio/src/sync/mpsc/mod.rs index 4cfd6150f30..c489c9f99ff 100644 --- a/tokio/src/sync/mpsc/mod.rs +++ b/tokio/src/sync/mpsc/mod.rs @@ -6,13 +6,17 @@ //! Similar to `std`, channel creation provides [`Receiver`] and [`Sender`] //! handles. [`Receiver`] implements `Stream` and allows a task to read values //! out of the channel. If there is no message to read, the current task will be -//! notified when a new value is sent. [`Sender`] implements the `Sink` trait -//! and allows sending messages into the channel. If the channel is at capacity, -//! the send is rejected and the task will be notified when additional capacity -//! is available. In other words, the channel provides backpressure. +//! notified when a new value is sent. If the channel is at capacity, the send +//! is rejected and the task will be notified when additional capacity is +//! available. In other words, the channel provides backpressure. //! -//! Unbounded channels are also available using the `unbounded_channel` -//! constructor. +//! This module provides two variants of the channel: bounded and unbounded. The +//! bounded variant has a limit on the number of messages that the channel can +//! store, and if this limit is reached, trying to send another message will +//! wait until a message is received from the channel. An unbounded channel has +//! an infinite capacity, so the `send` method never does any kind of sleeping. +//! This makes the [`UnboundedSender`] usable from both synchronous and +//! asynchronous code. //! //! # Disconnection //! @@ -33,8 +37,32 @@ //! consumes the channel to completion, at which point the receiver can be //! dropped. //! +//! # Communicating between sync and async code +//! +//! When you want to communicate between synchronous and asynchronous code, there +//! are two situations to consider: +//! +//! **Bounded channel**: If you need a bounded channel, you should use a bounded +//! Tokio `mpsc` channel for both directions of communication. To call the async +//! [`send`][bounded-send] or [`recv`][bounded-recv] methods in sync code, you +//! will need to use [`Handle::block_on`], which allow you to execute an async +//! method in synchronous code. This is necessary because a bounded channel may +//! need to wait for additional capacity to become available. +//! +//! **Unbounded channel**: You should use the kind of channel that matches where +//! the receiver is. So for sending a message _from async to sync_, you should +//! use [the standard library unbounded channel][std-unbounded] or +//! [crossbeam][crossbeam-unbounded]. Similarly, for sending a message _from sync +//! to async_, you should use an unbounded Tokio `mpsc` channel. +//! //! [`Sender`]: crate::sync::mpsc::Sender //! [`Receiver`]: crate::sync::mpsc::Receiver +//! [bounded-send]: crate::sync::mpsc::Sender::send() +//! [bounded-recv]: crate::sync::mpsc::Receiver::recv() +//! [`UnboundedSender`]: crate::sync::mpsc::UnboundedSender +//! [`Handle::block_on`]: crate::runtime::Handle::block_on() +//! [std-unbounded]: std::sync::mpsc::channel +//! [crossbeam-unbounded]: https://docs.rs/crossbeam/*/crossbeam/channel/fn.unbounded.html pub(super) mod block; diff --git a/tokio/src/sync/mpsc/unbounded.rs b/tokio/src/sync/mpsc/unbounded.rs index ba543fe4c87..1b2288ab08c 100644 --- a/tokio/src/sync/mpsc/unbounded.rs +++ b/tokio/src/sync/mpsc/unbounded.rs @@ -163,9 +163,13 @@ impl UnboundedSender { /// Attempts to send a message on this `UnboundedSender` without blocking. /// + /// This method is not marked async because sending a message to an unbounded channel + /// never requires any form of waiting. Because of this, the `send` method can be + /// used in both synchronous and asynchronous code without problems. + /// /// If the receive half of the channel is closed, either due to [`close`] - /// being called or the [`UnboundedReceiver`] having been dropped, - /// the function returns an error. The error includes the value passed to `send`. + /// being called or the [`UnboundedReceiver`] having been dropped, this + /// function returns an error. The error includes the value passed to `send`. /// /// [`close`]: UnboundedReceiver::close /// [`UnboundedReceiver`]: UnboundedReceiver diff --git a/tokio/src/sync/mutex.rs b/tokio/src/sync/mutex.rs index e0618a5d6e4..642058be626 100644 --- a/tokio/src/sync/mutex.rs +++ b/tokio/src/sync/mutex.rs @@ -1,4 +1,3 @@ -use crate::coop::CoopFutureExt; use crate::sync::batch_semaphore as semaphore; use std::cell::UnsafeCell; @@ -10,17 +9,35 @@ use std::sync::Arc; /// An asynchronous `Mutex`-like type. /// /// This type acts similarly to an asynchronous [`std::sync::Mutex`], with one -/// major difference: [`lock`] does not block. Another difference is that the -/// lock guard can be held across await points. +/// major difference: [`lock`] does not block and the lock guard can be held +/// across await points. /// -/// There are some situations where you should prefer the mutex from the -/// standard library. Generally this is the case if: +/// # Which kind of mutex should you use? /// -/// 1. The lock does not need to be held across await points. -/// 2. The duration of any single lock is near-instant. +/// Contrary to popular belief, it is ok and often preferred to use the ordinary +/// [`Mutex`][std] from the standard library in asynchronous code. This section +/// will help you decide on which kind of mutex you should use. /// -/// On the other hand, the Tokio mutex is for the situation where the lock -/// needs to be held for longer periods of time, or across await points. +/// The primary use case of the async mutex is to provide shared mutable access +/// to IO resources such as a database connection. If the data stored behind the +/// mutex is just data, it is often better to use a blocking mutex such as the +/// one in the standard library or [`parking_lot`]. This is because the feature +/// that the async mutex offers over the blocking mutex is that it is possible +/// to keep the mutex locked across an `.await` point, which is rarely necessary +/// for data. +/// +/// A common pattern is to wrap the `Arc>` in a struct that provides +/// non-async methods for performing operations on the data within, and only +/// lock the mutex inside these methods. The [mini-redis] example provides an +/// illustration of this pattern. +/// +/// Additionally, when you _do_ want shared access to an IO resource, it is +/// often better to spawn a task to manage the IO resource, and to use message +/// passing to communicate with that task. +/// +/// [std]: std::sync::Mutex +/// [`parking_lot`]: https://docs.rs/parking_lot +/// [mini-redis]: https://github.com/tokio-rs/mini-redis/blob/master/src/db.rs /// /// # Examples: /// @@ -94,15 +111,14 @@ use std::sync::Arc; /// /// [`Mutex`]: struct@Mutex /// [`MutexGuard`]: struct@MutexGuard -/// [`Arc`]: https://doc.rust-lang.org/std/sync/struct.Arc.html -/// [`std::sync::Mutex`]: https://doc.rust-lang.org/std/sync/struct.Mutex.html -/// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html +/// [`Arc`]: struct@std::sync::Arc +/// [`std::sync::Mutex`]: struct@std::sync::Mutex +/// [`Send`]: trait@std::marker::Send /// [`lock`]: method@Mutex::lock - #[derive(Debug)] -pub struct Mutex { - c: UnsafeCell, +pub struct Mutex { s: semaphore::Semaphore, + c: UnsafeCell, } /// A handle to a held `Mutex`. @@ -113,7 +129,7 @@ pub struct Mutex { /// /// The lock is automatically released whenever the guard is dropped, at which /// point `lock` will succeed yet again. -pub struct MutexGuard<'a, T> { +pub struct MutexGuard<'a, T: ?Sized> { lock: &'a Mutex, } @@ -132,17 +148,17 @@ pub struct MutexGuard<'a, T> { /// point `lock` will succeed yet again. /// /// [`Arc`]: std::sync::Arc -pub struct OwnedMutexGuard { +pub struct OwnedMutexGuard { lock: Arc>, } // As long as T: Send, it's fine to send and share Mutex between threads. // If T was not Send, sending and sharing a Mutex would be bad, since you can // access T through Mutex. -unsafe impl Send for Mutex where T: Send {} -unsafe impl Sync for Mutex where T: Send {} -unsafe impl<'a, T> Sync for MutexGuard<'a, T> where T: Send + Sync {} -unsafe impl Sync for OwnedMutexGuard where T: Send + Sync {} +unsafe impl Send for Mutex where T: ?Sized + Send {} +unsafe impl Sync for Mutex where T: ?Sized + Send {} +unsafe impl Sync for MutexGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl Sync for OwnedMutexGuard where T: ?Sized + Send + Sync {} /// Error returned from the [`Mutex::try_lock`] function. /// @@ -154,7 +170,7 @@ pub struct TryLockError(()); impl fmt::Display for TryLockError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "{}", "operation would block") + write!(fmt, "operation would block") } } @@ -184,7 +200,7 @@ fn bounds() { check_static_val(arc_mutex.lock_owned()); } -impl Mutex { +impl Mutex { /// Creates a new lock in an unlocked state ready for use. /// /// # Examples @@ -194,7 +210,10 @@ impl Mutex { /// /// let lock = Mutex::new(5); /// ``` - pub fn new(t: T) -> Self { + pub fn new(t: T) -> Self + where + T: Sized, + { Self { c: UnsafeCell::new(t), s: semaphore::Semaphore::new(1), @@ -255,7 +274,7 @@ impl Mutex { } async fn acquire(&self) { - self.s.acquire(1).cooperate().await.unwrap_or_else(|_| { + self.s.acquire(1).await.unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and // we own it exclusively, which means that this can never happen. unreachable!() @@ -331,7 +350,10 @@ impl Mutex { /// assert_eq!(n, 1); /// } /// ``` - pub fn into_inner(self) -> T { + pub fn into_inner(self) -> T + where + T: Sized, + { self.c.into_inner() } } @@ -353,32 +375,32 @@ where // === impl MutexGuard === -impl<'a, T> Drop for MutexGuard<'a, T> { +impl Drop for MutexGuard<'_, T> { fn drop(&mut self) { self.lock.s.release(1) } } -impl<'a, T> Deref for MutexGuard<'a, T> { +impl Deref for MutexGuard<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { unsafe { &*self.lock.c.get() } } } -impl<'a, T> DerefMut for MutexGuard<'a, T> { +impl DerefMut for MutexGuard<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { unsafe { &mut *self.lock.c.get() } } } -impl<'a, T: fmt::Debug> fmt::Debug for MutexGuard<'a, T> { +impl fmt::Debug for MutexGuard<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(&**self, f) } } -impl<'a, T: fmt::Display> fmt::Display for MutexGuard<'a, T> { +impl fmt::Display for MutexGuard<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&**self, f) } @@ -386,32 +408,32 @@ impl<'a, T: fmt::Display> fmt::Display for MutexGuard<'a, T> { // === impl OwnedMutexGuard === -impl Drop for OwnedMutexGuard { +impl Drop for OwnedMutexGuard { fn drop(&mut self) { self.lock.s.release(1) } } -impl Deref for OwnedMutexGuard { +impl Deref for OwnedMutexGuard { type Target = T; fn deref(&self) -> &Self::Target { unsafe { &*self.lock.c.get() } } } -impl DerefMut for OwnedMutexGuard { +impl DerefMut for OwnedMutexGuard { fn deref_mut(&mut self) -> &mut Self::Target { unsafe { &mut *self.lock.c.get() } } } -impl fmt::Debug for OwnedMutexGuard { +impl fmt::Debug for OwnedMutexGuard { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(&**self, f) } } -impl fmt::Display for OwnedMutexGuard { +impl fmt::Display for OwnedMutexGuard { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&**self, f) } diff --git a/tokio/src/sync/oneshot.rs b/tokio/src/sync/oneshot.rs index 62ad484eec3..17767e7f7f8 100644 --- a/tokio/src/sync/oneshot.rs +++ b/tokio/src/sync/oneshot.rs @@ -144,8 +144,11 @@ impl Sender { /// Attempts to send a value on this channel, returning it back if it could /// not be sent. /// - /// The function consumes `self` as only one value may ever be sent on a - /// one-shot channel. + /// This method consumes `self` as only one value may ever be sent on a oneshot + /// channel. It is not marked async because sending a message to an oneshot + /// channel never requires any form of waiting. Because of this, the `send` + /// method can be used in both synchronous and asynchronous code without + /// problems. /// /// A successful send occurs when it is determined that the other end of the /// channel has not hung up already. An unsuccessful send would be one where @@ -197,13 +200,14 @@ impl Sender { #[doc(hidden)] // TODO: remove pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); let inner = self.inner.as_ref().unwrap(); let mut state = State::load(&inner.state, Acquire); if state.is_closed() { + coop.made_progress(); return Poll::Ready(()); } @@ -216,6 +220,7 @@ impl Sender { if state.is_closed() { // Set the flag again so that the waker is released in drop State::set_tx_task(&inner.state); + coop.made_progress(); return Ready(()); } else { unsafe { inner.drop_tx_task() }; @@ -233,6 +238,7 @@ impl Sender { state = State::set_tx_task(&inner.state); if state.is_closed() { + coop.made_progress(); return Ready(()); } } @@ -360,7 +366,7 @@ impl Receiver { /// Prevents the associated [`Sender`] handle from sending a value. /// /// Any `send` operation which happens after calling `close` is guaranteed - /// to fail. After calling `close`, `Receiver::poll`] should be called to + /// to fail. After calling `close`, [`try_recv`] should be called to /// receive a value if one was sent **before** the call to `close` /// completed. /// @@ -368,6 +374,7 @@ impl Receiver { /// value will not be sent into the channel and never received. /// /// [`Sender`]: Sender + /// [`try_recv`]: Receiver::try_recv /// /// # Examples /// @@ -548,17 +555,19 @@ impl Inner { fn poll_recv(&self, cx: &mut Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); // Load the state let mut state = State::load(&self.state, Acquire); if state.is_complete() { + coop.made_progress(); match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), } } else if state.is_closed() { + coop.made_progress(); Ready(Err(RecvError(()))) } else { if state.is_rx_task_set() { @@ -572,6 +581,7 @@ impl Inner { // Set the flag again so that the waker is released in drop State::set_rx_task(&self.state); + coop.made_progress(); return match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), @@ -592,6 +602,7 @@ impl Inner { state = State::set_rx_task(&self.state); if state.is_complete() { + coop.made_progress(); match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), diff --git a/tokio/src/sync/rwlock.rs b/tokio/src/sync/rwlock.rs index 68cf710e84b..3d2a2f7a8fc 100644 --- a/tokio/src/sync/rwlock.rs +++ b/tokio/src/sync/rwlock.rs @@ -1,4 +1,3 @@ -use crate::coop::CoopFutureExt; use crate::sync::batch_semaphore::{AcquireError, Semaphore}; use std::cell::UnsafeCell; use std::ops; @@ -32,8 +31,8 @@ const MAX_READS: usize = 10; /// /// The type parameter `T` represents the data that this lock protects. It is /// required that `T` satisfies [`Send`] to be shared across threads. The RAII guards -/// returned from the locking methods implement [`Deref`](https://doc.rust-lang.org/std/ops/trait.Deref.html) -/// (and [`DerefMut`](https://doc.rust-lang.org/std/ops/trait.DerefMut.html) +/// returned from the locking methods implement [`Deref`](trait@std::ops::Deref) +/// (and [`DerefMut`](trait@std::ops::DerefMut) /// for the `write` methods) to allow access to the content of the lock. /// /// # Examples @@ -66,10 +65,10 @@ const MAX_READS: usize = 10; /// [`RwLock`]: struct@RwLock /// [`RwLockReadGuard`]: struct@RwLockReadGuard /// [`RwLockWriteGuard`]: struct@RwLockWriteGuard -/// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html +/// [`Send`]: trait@std::marker::Send /// [_write-preferring_]: https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Priority_policies #[derive(Debug)] -pub struct RwLock { +pub struct RwLock { //semaphore to coordinate read and write access to T s: Semaphore, @@ -85,7 +84,7 @@ pub struct RwLock { /// /// [`read`]: method@RwLock::read #[derive(Debug)] -pub struct RwLockReadGuard<'a, T> { +pub struct RwLockReadGuard<'a, T: ?Sized> { permit: ReleasingPermit<'a, T>, lock: &'a RwLock, } @@ -99,29 +98,29 @@ pub struct RwLockReadGuard<'a, T> { /// [`write`]: method@RwLock::write /// [`RwLock`]: struct@RwLock #[derive(Debug)] -pub struct RwLockWriteGuard<'a, T> { +pub struct RwLockWriteGuard<'a, T: ?Sized> { permit: ReleasingPermit<'a, T>, lock: &'a RwLock, } // Wrapper arround Permit that releases on Drop #[derive(Debug)] -struct ReleasingPermit<'a, T> { +struct ReleasingPermit<'a, T: ?Sized> { num_permits: u16, lock: &'a RwLock, } -impl<'a, T> ReleasingPermit<'a, T> { +impl<'a, T: ?Sized> ReleasingPermit<'a, T> { async fn acquire( lock: &'a RwLock, num_permits: u16, ) -> Result, AcquireError> { - lock.s.acquire(num_permits).cooperate().await?; + lock.s.acquire(num_permits.into()).await?; Ok(Self { num_permits, lock }) } } -impl<'a, T> Drop for ReleasingPermit<'a, T> { +impl Drop for ReleasingPermit<'_, T> { fn drop(&mut self) { self.lock.s.release(self.num_permits as usize); } @@ -154,12 +153,12 @@ fn bounds() { // As long as T: Send + Sync, it's fine to send and share RwLock between threads. // If T were not Send, sending and sharing a RwLock would be bad, since you can access T through // RwLock. -unsafe impl Send for RwLock where T: Send {} -unsafe impl Sync for RwLock where T: Send + Sync {} -unsafe impl<'a, T> Sync for RwLockReadGuard<'a, T> where T: Send + Sync {} -unsafe impl<'a, T> Sync for RwLockWriteGuard<'a, T> where T: Send + Sync {} +unsafe impl Send for RwLock where T: ?Sized + Send {} +unsafe impl Sync for RwLock where T: ?Sized + Send + Sync {} +unsafe impl Sync for RwLockReadGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} -impl RwLock { +impl RwLock { /// Creates a new instance of an `RwLock` which is unlocked. /// /// # Examples @@ -169,7 +168,10 @@ impl RwLock { /// /// let lock = RwLock::new(5); /// ``` - pub fn new(value: T) -> RwLock { + pub fn new(value: T) -> RwLock + where + T: Sized, + { RwLock { c: UnsafeCell::new(value), s: Semaphore::new(MAX_READS), @@ -251,12 +253,15 @@ impl RwLock { } /// Consumes the lock, returning the underlying data. - pub fn into_inner(self) -> T { + pub fn into_inner(self) -> T + where + T: Sized, + { self.c.into_inner() } } -impl ops::Deref for RwLockReadGuard<'_, T> { +impl ops::Deref for RwLockReadGuard<'_, T> { type Target = T; fn deref(&self) -> &T { @@ -264,7 +269,7 @@ impl ops::Deref for RwLockReadGuard<'_, T> { } } -impl ops::Deref for RwLockWriteGuard<'_, T> { +impl ops::Deref for RwLockWriteGuard<'_, T> { type Target = T; fn deref(&self) -> &T { @@ -272,7 +277,7 @@ impl ops::Deref for RwLockWriteGuard<'_, T> { } } -impl ops::DerefMut for RwLockWriteGuard<'_, T> { +impl ops::DerefMut for RwLockWriteGuard<'_, T> { fn deref_mut(&mut self) -> &mut T { unsafe { &mut *self.lock.c.get() } } diff --git a/tokio/src/sync/semaphore.rs b/tokio/src/sync/semaphore.rs index c1dd975f282..2489d34aaaf 100644 --- a/tokio/src/sync/semaphore.rs +++ b/tokio/src/sync/semaphore.rs @@ -1,5 +1,4 @@ use super::batch_semaphore as ll; // low level implementation -use crate::coop::CoopFutureExt; use std::sync::Arc; /// Counting semaphore performing asynchronous permit aquisition. @@ -81,13 +80,15 @@ impl Semaphore { } /// Adds `n` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. pub fn add_permits(&self, n: usize) { self.ll_sem.release(n); } /// Acquires permit from the semaphore. pub async fn acquire(&self) -> SemaphorePermit<'_> { - self.ll_sem.acquire(1).cooperate().await.unwrap(); + self.ll_sem.acquire(1).await.unwrap(); SemaphorePermit { sem: &self, permits: 1, @@ -111,7 +112,7 @@ impl Semaphore { /// /// [`Arc`]: std::sync::Arc pub async fn acquire_owned(self: Arc) -> OwnedSemaphorePermit { - self.ll_sem.acquire(1).cooperate().await.unwrap(); + self.ll_sem.acquire(1).await.unwrap(); OwnedSemaphorePermit { sem: self.clone(), permits: 1, diff --git a/tokio/src/sync/semaphore_ll.rs b/tokio/src/sync/semaphore_ll.rs index 0bdc4e27617..25d25ac88ab 100644 --- a/tokio/src/sync/semaphore_ll.rs +++ b/tokio/src/sync/semaphore_ll.rs @@ -333,8 +333,9 @@ impl Semaphore { self.add_permits_locked(0, true); } - /// Adds `n` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. pub(crate) fn add_permits(&self, n: usize) { if n == 0 { return; @@ -749,7 +750,7 @@ impl Permit { /// Forgets the permit **without** releasing it back to the semaphore. /// /// After calling `forget`, `poll_acquire` is able to acquire new permit - /// from the sempahore. + /// from the semaphore. /// /// Repeatedly calling `forget` without associated calls to `add_permit` /// will result in the semaphore losing all permits. @@ -853,8 +854,8 @@ impl TryAcquireError { impl fmt::Display for TryAcquireError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - TryAcquireError::Closed => write!(fmt, "{}", "semaphore closed"), - TryAcquireError::NoPermits => write!(fmt, "{}", "no permits available"), + TryAcquireError::Closed => write!(fmt, "semaphore closed"), + TryAcquireError::NoPermits => write!(fmt, "no permits available"), } } } diff --git a/tokio/src/sync/tests/loom_cancellation_token.rs b/tokio/src/sync/tests/loom_cancellation_token.rs new file mode 100644 index 00000000000..e9c9f3dd980 --- /dev/null +++ b/tokio/src/sync/tests/loom_cancellation_token.rs @@ -0,0 +1,155 @@ +use crate::sync::CancellationToken; + +use loom::{future::block_on, thread}; +use tokio_test::assert_ok; + +#[test] +fn cancel_token() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn cancel_with_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + let child_token = token.child_token(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + block_on(async { + child_token.cancelled().await; + }); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_no_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(token2); + }); + + let th3 = thread::spawn(move || { + drop(token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_with_childs() { + loom::model(|| { + let token1 = CancellationToken::new(); + let child_token1 = token1.child_token(); + let child_token2 = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(child_token1); + }); + + let th3 = thread::spawn(move || { + drop(child_token2); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_and_cancel_token() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + drop(child_token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn cancel_parent_and_child() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + child_token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} diff --git a/tokio/src/sync/tests/mod.rs b/tokio/src/sync/tests/mod.rs index d571754c011..6ba8c1f9b6a 100644 --- a/tokio/src/sync/tests/mod.rs +++ b/tokio/src/sync/tests/mod.rs @@ -7,6 +7,8 @@ cfg_not_loom! { cfg_loom! { mod loom_atomic_waker; mod loom_broadcast; + #[cfg(tokio_unstable)] + mod loom_cancellation_token; mod loom_list; mod loom_mpsc; mod loom_notify; diff --git a/tokio/src/sync/tests/semaphore_batch.rs b/tokio/src/sync/tests/semaphore_batch.rs index 60f3f231e76..9342cd1cb3c 100644 --- a/tokio/src/sync/tests/semaphore_batch.rs +++ b/tokio/src/sync/tests/semaphore_batch.rs @@ -236,7 +236,7 @@ fn close_semaphore_notifies_permit2() { #[test] fn cancel_acquire_releases_permits() { let s = Semaphore::new(10); - let _permit1 = s.try_acquire(4).expect("uncontended try_acquire succeeds"); + s.try_acquire(4).expect("uncontended try_acquire succeeds"); assert_eq!(6, s.available_permits()); let mut acquire = task::spawn(s.acquire(8)); diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index ba609a8c6d7..13033d9e726 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -338,7 +338,6 @@ impl Sender { // Notify all watchers notify_all(&*shared); - // Return the old value Ok(()) } diff --git a/tokio/src/task/blocking.rs b/tokio/src/task/blocking.rs index 0ef6053528e..ed60f4c4734 100644 --- a/tokio/src/task/blocking.rs +++ b/tokio/src/task/blocking.rs @@ -29,7 +29,7 @@ cfg_rt_threaded! { /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code /// [threaded scheduler]: fn@crate::runtime::Builder::threaded_scheduler /// [`spawn_blocking`]: fn@crate::task::spawn_blocking - /// [`join!`]: ../macro.join.html + /// [`join!`]: macro@join /// [`thread::spawn`]: fn@std::thread::spawn /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout /// @@ -114,6 +114,19 @@ cfg_blocking! { F: FnOnce() -> R + Send + 'static, R: Send + 'static, { + #[cfg(feature = "tracing")] + let f = { + let span = tracing::trace_span!( + target: "tokio::task", + "task", + kind = %"blocking", + function = %std::any::type_name::(), + ); + move || { + let _g = span.enter(); + f() + } + }; crate::runtime::spawn_blocking(f) } } diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 346fe437f45..3c409edfb90 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -105,7 +105,7 @@ cfg_rt_util! { /// } /// ``` /// - /// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html + /// [`Send`]: trait@std::marker::Send /// [local task set]: struct@LocalSet /// [`Runtime::block_on`]: method@crate::runtime::Runtime::block_on /// [`task::spawn_local`]: fn@spawn_local @@ -195,6 +195,7 @@ cfg_rt_util! { F: Future + 'static, F::Output: 'static, { + let future = crate::util::trace::task(future, "local"); CURRENT.with(|maybe_cx| { let cx = maybe_cx .expect("`spawn_local` called from outside of a `task::LocalSet`"); @@ -277,6 +278,7 @@ impl LocalSet { F: Future + 'static, F::Output: 'static, { + let future = crate::util::trace::task(future, "local"); let (task, handle) = unsafe { task::joinable_local(future) }; self.context.tasks.borrow_mut().queue.push_back(task); handle @@ -454,24 +456,20 @@ impl Future for LocalSet { // Register the waker before starting to work self.context.shared.waker.register_by_ref(cx.waker()); - // Reset any previous task budget while polling tasks spawned on the - // `LocalSet`, ensuring that each has its own separate budget. - crate::coop::reset(|| { - if self.with(|| self.tick()) { - // If `tick` returns true, we need to notify the local future again: - // there are still tasks remaining in the run queue. - cx.waker().wake_by_ref(); - Poll::Pending - } else if self.context.tasks.borrow().owned.is_empty() { - // If the scheduler has no remaining futures, we're done! - Poll::Ready(()) - } else { - // There are still futures in the local set, but we've polled all the - // futures in the run queue. Therefore, we can just return Pending - // since the remaining futures will be woken from somewhere else. - Poll::Pending - } - }) + if self.with(|| self.tick()) { + // If `tick` returns true, we need to notify the local future again: + // there are still tasks remaining in the run queue. + cx.waker().wake_by_ref(); + Poll::Pending + } else if self.context.tasks.borrow().owned.is_empty() { + // If the scheduler has no remaining futures, we're done! + Poll::Ready(()) + } else { + // There are still futures in the local set, but we've polled all the + // futures in the run queue. Therefore, we can just return Pending + // since the remaining futures will be woken from somewhere else. + Poll::Pending + } } } @@ -525,23 +523,19 @@ impl Future for RunUntil<'_, T> { .register_by_ref(cx.waker()); let _no_blocking = crate::runtime::enter::disallow_blocking(); - // Reset any previous task budget so that the future passed to - // `run_until` and any tasks spawned on the `LocalSet` have their - // own budgets. - crate::coop::reset(|| { - let f = me.future; - if let Poll::Ready(output) = crate::coop::budget(|| f.poll(cx)) { - return Poll::Ready(output); - } - - if me.local_set.tick() { - // If `tick` returns `true`, we need to notify the local future again: - // there are still tasks remaining in the run queue. - cx.waker().wake_by_ref(); - } - - Poll::Pending - }) + let f = me.future; + + if let Poll::Ready(output) = crate::coop::budget(|| f.poll(cx)) { + return Poll::Ready(output); + } + + if me.local_set.tick() { + // If `tick` returns `true`, we need to notify the local future again: + // there are still tasks remaining in the run queue. + cx.waker().wake_by_ref(); + } + + Poll::Pending }) } } diff --git a/tokio/src/task/spawn.rs b/tokio/src/task/spawn.rs index fa5ff13b01e..d6e771184f2 100644 --- a/tokio/src/task/spawn.rs +++ b/tokio/src/task/spawn.rs @@ -129,6 +129,7 @@ doc_rt_core! { { let spawn_handle = runtime::context::spawn_handle() .expect("must be called from the context of Tokio runtime configured with either `basic_scheduler` or `threaded_scheduler`"); + let task = crate::util::trace::task(task, "task"); spawn_handle.spawn(task) } } diff --git a/tokio/src/task/task_local.rs b/tokio/src/task/task_local.rs index cbff272bd6a..1679ee3ba12 100644 --- a/tokio/src/task/task_local.rs +++ b/tokio/src/task/task_local.rs @@ -89,7 +89,7 @@ macro_rules! __task_local_inner { /// }).await; /// # } /// ``` -/// [`std::thread::LocalKey`]: https://doc.rust-lang.org/std/thread/struct.LocalKey.html +/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey pub struct LocalKey { #[doc(hidden)] pub inner: thread::LocalKey>>, diff --git a/tokio/src/time/clock.rs b/tokio/src/time/clock.rs index 4ac24af3d05..bd67d7a31d7 100644 --- a/tokio/src/time/clock.rs +++ b/tokio/src/time/clock.rs @@ -56,8 +56,9 @@ cfg_test_util! { /// Pause time /// /// The current value of `Instant::now()` is saved and all subsequent calls - /// to `Instant::now()` will return the saved value. This is useful for - /// running tests that are dependent on time. + /// to `Instant::now()` until the timer wheel is checked again will return the saved value. + /// Once the timer wheel is checked, time will immediately advance to the next registered + /// `Delay`. This is useful for running tests that depend on time. /// /// # Panics /// diff --git a/tokio/src/time/delay.rs b/tokio/src/time/delay.rs index 8088c9955cd..744c7e16aea 100644 --- a/tokio/src/time/delay.rs +++ b/tokio/src/time/delay.rs @@ -29,10 +29,29 @@ pub fn delay_until(deadline: Instant) -> Delay { /// operates at millisecond granularity and should not be used for tasks that /// require high-resolution timers. /// +/// To run something regularly on a schedule, see [`interval`]. +/// /// # Cancellation /// /// Canceling a delay is done by dropping the returned future. No additional /// cleanup work is required. +/// +/// # Examples +/// +/// Wait 100ms and print "100 ms have elapsed". +/// +/// ``` +/// use tokio::time::{delay_for, Duration}; +/// +/// #[tokio::main] +/// async fn main() { +/// delay_for(Duration::from_millis(100)).await; +/// println!("100 ms have elapsed"); +/// } +/// ``` +/// +/// [`interval`]: crate::time::interval() +#[cfg_attr(docsrs, doc(alias = "sleep"))] pub fn delay_for(duration: Duration) -> Delay { delay_until(Instant::now() + duration) } diff --git a/tokio/src/time/delay_queue.rs b/tokio/src/time/delay_queue.rs index 989b42e81d9..55ec7cd68d1 100644 --- a/tokio/src/time/delay_queue.rs +++ b/tokio/src/time/delay_queue.rs @@ -111,17 +111,17 @@ use std::task::{self, Poll}; /// } /// ``` /// -/// [`insert`]: #method.insert -/// [`insert_at`]: #method.insert_at +/// [`insert`]: method@Self::insert +/// [`insert_at`]: method@Self::insert_at /// [`Key`]: struct@Key /// [`Stream`]: https://docs.rs/futures/0.1/futures/stream/trait.Stream.html -/// [`poll`]: #method.poll -/// [`Stream::poll`]: #method.poll +/// [`poll`]: method@Self::poll +/// [`Stream::poll`]: method@Self::poll /// [`DelayQueue`]: struct@DelayQueue /// [`delay_for`]: fn@super::delay_for -/// [`slab`]: https://docs.rs/slab -/// [`capacity`]: #method.capacity -/// [`reserve`]: #method.reserve +/// [`slab`]: slab +/// [`capacity`]: method@Self::capacity +/// [`reserve`]: method@Self::reserve #[derive(Debug)] pub struct DelayQueue { /// Stores data associated with entries @@ -295,9 +295,9 @@ impl DelayQueue { /// # } /// ``` /// - /// [`poll`]: #method.poll - /// [`remove`]: #method.remove - /// [`reset`]: #method.reset + /// [`poll`]: method@Self::poll + /// [`remove`]: method@Self::remove + /// [`reset`]: method@Self::reset /// [`Key`]: struct@Key /// [type]: # pub fn insert_at(&mut self, value: T, when: Instant) -> Key { @@ -403,9 +403,9 @@ impl DelayQueue { /// # } /// ``` /// - /// [`poll`]: #method.poll - /// [`remove`]: #method.remove - /// [`reset`]: #method.reset + /// [`poll`]: method@Self::poll + /// [`remove`]: method@Self::remove + /// [`reset`]: method@Self::reset /// [`Key`]: struct@Key /// [type]: # pub fn insert(&mut self, value: T, timeout: Duration) -> Key { @@ -574,7 +574,7 @@ impl DelayQueue { /// /// Note that this method has no effect on the allocated capacity. /// - /// [`poll`]: #method.poll + /// [`poll`]: method@Self::poll /// /// # Examples /// diff --git a/tokio/src/time/driver/atomic_stack.rs b/tokio/src/time/driver/atomic_stack.rs index 7e5a83fa521..d27579f920f 100644 --- a/tokio/src/time/driver/atomic_stack.rs +++ b/tokio/src/time/driver/atomic_stack.rs @@ -118,7 +118,7 @@ impl Drop for AtomicStackEntries { fn drop(&mut self) { for entry in self { // Flag the entry as errored - entry.error(); + entry.error(Error::shutdown()); } } } diff --git a/tokio/src/time/driver/entry.rs b/tokio/src/time/driver/entry.rs index 8e1e6b2f92e..974465c19be 100644 --- a/tokio/src/time/driver/entry.rs +++ b/tokio/src/time/driver/entry.rs @@ -5,8 +5,8 @@ use crate::time::{Duration, Error, Instant}; use std::cell::UnsafeCell; use std::ptr; -use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering::SeqCst; +use std::sync::atomic::{AtomicBool, AtomicU8}; use std::sync::{Arc, Weak}; use std::task::{self, Poll}; use std::u64; @@ -30,7 +30,8 @@ pub(crate) struct Entry { /// Timer internals. Using a weak pointer allows the timer to shutdown /// without all `Delay` instances having completed. /// - /// When `None`, the entry has not yet been linked with a timer instance. + /// When empty, it means that the entry has not yet been linked with a + /// timer instance. inner: Weak, /// Tracks the entry state. This value contains the following information: @@ -44,6 +45,11 @@ pub(crate) struct Entry { /// instant, this value is changed. state: AtomicU64, + /// Stores the actual error. If `state` indicates that an error occurred, + /// this is guaranteed to be a non-zero value representing the first error + /// that occurred. Otherwise its value is undefined. + error: AtomicU8, + /// Task to notify once the deadline is reached. waker: AtomicWaker, @@ -109,8 +115,9 @@ impl Entry { let entry: Entry; // Increment the number of active timeouts - if inner.increment().is_err() { - entry = Entry::new2(deadline, duration, Weak::new(), ERROR) + if let Err(err) = inner.increment() { + entry = Entry::new2(deadline, duration, Weak::new(), ERROR); + entry.error(err); } else { let when = inner.normalize_deadline(deadline); let state = if when <= inner.elapsed() { @@ -122,8 +129,8 @@ impl Entry { } let entry = Arc::new(entry); - if inner.queue(&entry).is_err() { - entry.error(); + if let Err(err) = inner.queue(&entry) { + entry.error(err); } entry @@ -189,7 +196,12 @@ impl Entry { self.waker.wake(); } - pub(crate) fn error(&self) { + pub(crate) fn error(&self, error: Error) { + // Record the precise nature of the error, if there isn't already an + // error present. If we don't actually transition to the error state + // below, that's fine, as the error details we set here will be ignored. + self.error.compare_and_swap(0, error.as_u8(), SeqCst); + // Only transition to the error state if not currently elapsed let mut curr = self.state.load(SeqCst); @@ -234,7 +246,7 @@ impl Entry { if is_elapsed(curr) { return Poll::Ready(if curr == ERROR { - Err(Error::shutdown()) + Err(Error::from_u8(self.error.load(SeqCst))) } else { Ok(()) }); @@ -246,7 +258,7 @@ impl Entry { if is_elapsed(curr) { return Poll::Ready(if curr == ERROR { - Err(Error::shutdown()) + Err(Error::from_u8(self.error.load(SeqCst))) } else { Ok(()) }); @@ -309,6 +321,7 @@ impl Entry { waker: AtomicWaker::new(), state: AtomicU64::new(state), queued: AtomicBool::new(false), + error: AtomicU8::new(0), next_atomic: UnsafeCell::new(ptr::null_mut()), when: UnsafeCell::new(None), next_stack: UnsafeCell::new(None), diff --git a/tokio/src/time/driver/mod.rs b/tokio/src/time/driver/mod.rs index 4616816f3f4..92a8474a7e0 100644 --- a/tokio/src/time/driver/mod.rs +++ b/tokio/src/time/driver/mod.rs @@ -26,29 +26,29 @@ use std::sync::Arc; use std::usize; use std::{cmp, fmt}; -/// Time implementation that drives [`Delay`], [`Interval`], and [`Timeout`]. +/// Time implementation that drives [`Delay`][delay], [`Interval`][interval], and [`Timeout`][timeout]. /// /// A `Driver` instance tracks the state necessary for managing time and -/// notifying the [`Delay`] instances once their deadlines are reached. +/// notifying the [`Delay`][delay] instances once their deadlines are reached. /// -/// It is expected that a single instance manages many individual [`Delay`] +/// It is expected that a single instance manages many individual [`Delay`][delay] /// instances. The `Driver` implementation is thread-safe and, as such, is able /// to handle callers from across threads. /// -/// After creating the `Driver` instance, the caller must repeatedly call -/// [`turn`]. The time driver will perform no work unless [`turn`] is called -/// repeatedly. +/// After creating the `Driver` instance, the caller must repeatedly call `park` +/// or `park_timeout`. The time driver will perform no work unless `park` or +/// `park_timeout` is called repeatedly. /// /// The driver has a resolution of one millisecond. Any unit of time that falls /// between milliseconds are rounded up to the next millisecond. /// -/// When an instance is dropped, any outstanding [`Delay`] instance that has not +/// When an instance is dropped, any outstanding [`Delay`][delay] instance that has not /// elapsed will be notified with an error. At this point, calling `poll` on the -/// [`Delay`] instance will result in `Err` being returned. +/// [`Delay`][delay] instance will result in panic. /// /// # Implementation /// -/// THe time driver is based on the [paper by Varghese and Lauck][paper]. +/// The time driver is based on the [paper by Varghese and Lauck][paper]. /// /// A hashed timing wheel is a vector of slots, where each slot handles a time /// slice. As time progresses, the timer walks over the slot for the current @@ -73,9 +73,14 @@ use std::{cmp, fmt}; /// When the timer processes entries at level zero, it will notify all the /// `Delay` instances as their deadlines have been reached. For all higher /// levels, all entries will be redistributed across the wheel at the next level -/// down. Eventually, as time progresses, entries will [`Delay`] instances will +/// down. Eventually, as time progresses, entries will [`Delay`][delay] instances will /// either be canceled (dropped) or their associated entries will reach level /// zero and be notified. +/// +/// [paper]: http://www.cs.columbia.edu/~nahum/w6998/papers/ton97-timing-wheels.pdf +/// [delay]: crate::time::Delay +/// [timeout]: crate::time::Timeout +/// [interval]: crate::time::Interval #[derive(Debug)] pub(crate) struct Driver { /// Shared state @@ -119,7 +124,7 @@ where T: Park, { /// Creates a new `Driver` instance that uses `park` to block the current - /// thread and `now` to get the current `Instant`. + /// thread and `clock` to get the current `Instant`. /// /// Specifying the source of time is useful when testing. pub(crate) fn new(park: T, clock: Clock) -> Driver { @@ -220,7 +225,7 @@ where // The entry's deadline is invalid, so error it and update the // internal state accordingly. entry.set_when_internal(None); - entry.error(); + entry.error(Error::invalid()); } } } @@ -312,7 +317,7 @@ impl Drop for Driver { let mut poll = wheel::Poll::new(u64::MAX); while let Some(entry) = self.wheel.poll(&mut poll, &mut ()) { - entry.error(); + entry.error(Error::shutdown()); } } } diff --git a/tokio/src/time/driver/registration.rs b/tokio/src/time/driver/registration.rs index b77357e7353..3a0b34501b0 100644 --- a/tokio/src/time/driver/registration.rs +++ b/tokio/src/time/driver/registration.rs @@ -40,9 +40,12 @@ impl Registration { pub(crate) fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); - self.entry.poll_elapsed(cx) + self.entry.poll_elapsed(cx).map(move |r| { + coop.made_progress(); + r + }) } } diff --git a/tokio/src/time/error.rs b/tokio/src/time/error.rs index 0667b97ac1e..2f93d67115b 100644 --- a/tokio/src/time/error.rs +++ b/tokio/src/time/error.rs @@ -24,10 +24,12 @@ use std::fmt; #[derive(Debug)] pub struct Error(Kind); -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] +#[repr(u8)] enum Kind { - Shutdown, - AtCapacity, + Shutdown = 1, + AtCapacity = 2, + Invalid = 3, } impl Error { @@ -56,6 +58,32 @@ impl Error { _ => false, } } + + /// Create an error representing a misconfigured timer. + pub fn invalid() -> Error { + Error(Invalid) + } + + /// Returns `true` if the error was caused by the timer being misconfigured. + pub fn is_invalid(&self) -> bool { + match self.0 { + Kind::Invalid => true, + _ => false, + } + } + + pub(crate) fn as_u8(&self) -> u8 { + self.0 as u8 + } + + pub(crate) fn from_u8(n: u8) -> Self { + Error(match n { + 1 => Shutdown, + 2 => AtCapacity, + 3 => Invalid, + _ => panic!("u8 does not correspond to any time error variant"), + }) + } } impl error::Error for Error {} @@ -66,6 +94,7 @@ impl fmt::Display for Error { let descr = match self.0 { Shutdown => "the timer is shutdown, must be called from the context of Tokio runtime", AtCapacity => "timer is at capacity and cannot create a new entry", + Invalid => "timer duration exceeds maximum duration", }; write!(fmt, "{}", descr) } diff --git a/tokio/src/time/interval.rs b/tokio/src/time/interval.rs index 090e2d1f05a..1fa21e66418 100644 --- a/tokio/src/time/interval.rs +++ b/tokio/src/time/interval.rs @@ -33,6 +33,37 @@ use std::task::{Context, Poll}; /// // approximately 20ms have elapsed. /// } /// ``` +/// +/// A simple example using `interval` to execute a task every two seconds. +/// +/// The difference between `interval` and [`delay_for`] is that an `interval` +/// measures the time since the last tick, which means that `.tick().await` +/// may wait for a shorter time than the duration specified for the interval +/// if some time has passed between calls to `.tick().await`. +/// +/// If the tick in the example below was replaced with [`delay_for`], the task +/// would only be executed once every three seconds, and not every two +/// seconds. +/// +/// ``` +/// use tokio::time; +/// +/// async fn task_that_takes_a_second() { +/// println!("hello"); +/// time::delay_for(time::Duration::from_secs(1)).await +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let mut interval = time::interval(time::Duration::from_secs(2)); +/// for _i in 0..5 { +/// interval.tick().await; +/// task_that_takes_a_second().await; +/// } +/// } +/// ``` +/// +/// [`delay_for`]: crate::time::delay_for() pub fn interval(period: Duration) -> Interval { assert!(period > Duration::new(0, 0), "`period` must be non-zero."); diff --git a/tokio/src/time/mod.rs b/tokio/src/time/mod.rs index 7070d6b2573..c532b2c175f 100644 --- a/tokio/src/time/mod.rs +++ b/tokio/src/time/mod.rs @@ -24,7 +24,7 @@ //! //! # Examples //! -//! Wait 100ms and print "Hello World!" +//! Wait 100ms and print "100 ms have elapsed" //! //! ``` //! use tokio::time::delay_for; @@ -58,6 +58,38 @@ //! } //! # } //! ``` +//! +//! A simple example using [`interval`] to execute a task every two seconds. +//! +//! The difference between [`interval`] and [`delay_for`] is that an +//! [`interval`] measures the time since the last tick, which means that +//! `.tick().await` may wait for a shorter time than the duration specified +//! for the interval if some time has passed between calls to `.tick().await`. +//! +//! If the tick in the example below was replaced with [`delay_for`], the task +//! would only be executed once every three seconds, and not every two +//! seconds. +//! +//! ``` +//! use tokio::time; +//! +//! async fn task_that_takes_a_second() { +//! println!("hello"); +//! time::delay_for(time::Duration::from_secs(1)).await +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let mut interval = time::interval(time::Duration::from_secs(2)); +//! for _i in 0..5 { +//! interval.tick().await; +//! task_that_takes_a_second().await; +//! } +//! } +//! ``` +//! +//! [`delay_for`]: crate::time::delay_for() +//! [`interval`]: crate::time::interval() mod clock; pub(crate) use self::clock::Clock; diff --git a/tokio/src/time/tests/test_delay.rs b/tokio/src/time/tests/test_delay.rs index f843434be49..b708f6fc045 100644 --- a/tokio/src/time/tests/test_delay.rs +++ b/tokio/src/time/tests/test_delay.rs @@ -1,5 +1,3 @@ -#![warn(rust_2018_idioms)] - use crate::park::{Park, Unpark}; use crate::time::driver::{Driver, Entry, Handle}; use crate::time::Clock; diff --git a/tokio/src/time/throttle.rs b/tokio/src/time/throttle.rs index 435bef63815..d53a6f76211 100644 --- a/tokio/src/time/throttle.rs +++ b/tokio/src/time/throttle.rs @@ -16,7 +16,7 @@ use pin_project_lite::pin_project; /// # Example /// /// Create a throttled stream. -/// ```rust,norun +/// ```rust,no_run /// use std::time::Duration; /// use tokio::stream::StreamExt; /// use tokio::time::throttle; diff --git a/tokio/src/time/timeout.rs b/tokio/src/time/timeout.rs index 401856a881a..efc3dc5c069 100644 --- a/tokio/src/time/timeout.rs +++ b/tokio/src/time/timeout.rs @@ -6,6 +6,7 @@ use crate::time::{delay_until, Delay, Duration, Instant}; +use pin_project_lite::pin_project; use std::fmt; use std::future::Future; use std::pin::Pin; @@ -99,12 +100,16 @@ where } } -/// Future returned by [`timeout`](timeout) and [`timeout_at`](timeout_at). -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[derive(Debug)] -pub struct Timeout { - value: T, - delay: Delay, +pin_project! { + /// Future returned by [`timeout`](timeout) and [`timeout_at`](timeout_at). + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[derive(Debug)] + pub struct Timeout { + #[pin] + value: T, + #[pin] + delay: Delay, + } } /// Error returned by `Timeout`. @@ -146,24 +151,18 @@ where { type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - // First, try polling the future + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let me = self.project(); - // Safety: we never move `self.value` - unsafe { - let p = self.as_mut().map_unchecked_mut(|me| &mut me.value); - if let Poll::Ready(v) = p.poll(cx) { - return Poll::Ready(Ok(v)); - } + // First, try polling the future + if let Poll::Ready(v) = me.value.poll(cx) { + return Poll::Ready(Ok(v)); } // Now check the timer - // Safety: X_X! - unsafe { - match self.map_unchecked_mut(|me| &mut me.delay).poll(cx) { - Poll::Ready(()) => Poll::Ready(Err(Elapsed(()))), - Poll::Pending => Poll::Pending, - } + match me.delay.poll(cx) { + Poll::Ready(()) => Poll::Ready(Err(Elapsed(()))), + Poll::Pending => Poll::Pending, } } } diff --git a/tokio/src/util/intrusive_double_linked_list.rs b/tokio/src/util/intrusive_double_linked_list.rs new file mode 100644 index 00000000000..083fa31d3ec --- /dev/null +++ b/tokio/src/util/intrusive_double_linked_list.rs @@ -0,0 +1,788 @@ +//! An intrusive double linked list of data + +#![allow(dead_code, unreachable_pub)] + +use core::{ + marker::PhantomPinned, + ops::{Deref, DerefMut}, + ptr::NonNull, +}; + +/// A node which carries data of type `T` and is stored in an intrusive list +#[derive(Debug)] +pub struct ListNode { + /// The previous node in the list. `None` if there is no previous node. + prev: Option>>, + /// The next node in the list. `None` if there is no previous node. + next: Option>>, + /// The data which is associated to this list item + data: T, + /// Prevents `ListNode`s from being `Unpin`. They may never be moved, since + /// the list semantics require addresses to be stable. + _pin: PhantomPinned, +} + +impl ListNode { + /// Creates a new node with the associated data + pub fn new(data: T) -> ListNode { + Self { + prev: None, + next: None, + data, + _pin: PhantomPinned, + } + } +} + +impl Deref for ListNode { + type Target = T; + + fn deref(&self) -> &T { + &self.data + } +} + +impl DerefMut for ListNode { + fn deref_mut(&mut self) -> &mut T { + &mut self.data + } +} + +/// An intrusive linked list of nodes, where each node carries associated data +/// of type `T`. +#[derive(Debug)] +pub struct LinkedList { + head: Option>>, + tail: Option>>, +} + +impl LinkedList { + /// Creates an empty linked list + pub fn new() -> Self { + LinkedList:: { + head: None, + tail: None, + } + } + + /// Adds a node at the front of the linked list. + /// Safety: This function is only safe as long as `node` is guaranteed to + /// get removed from the list before it gets moved or dropped. + /// In addition to this `node` may not be added to another other list before + /// it is removed from the current one. + pub unsafe fn add_front(&mut self, node: &mut ListNode) { + node.next = self.head; + node.prev = None; + if let Some(mut head) = self.head { + head.as_mut().prev = Some(node.into()) + }; + self.head = Some(node.into()); + if self.tail.is_none() { + self.tail = Some(node.into()); + } + } + + /// Inserts a node into the list in a way that the list keeps being sorted. + /// Safety: This function is only safe as long as `node` is guaranteed to + /// get removed from the list before it gets moved or dropped. + /// In addition to this `node` may not be added to another other list before + /// it is removed from the current one. + pub unsafe fn add_sorted(&mut self, node: &mut ListNode) + where + T: PartialOrd, + { + if self.head.is_none() { + // First node in the list + self.head = Some(node.into()); + self.tail = Some(node.into()); + return; + } + + let mut prev: Option>> = None; + let mut current = self.head; + + while let Some(mut current_node) = current { + if node.data < current_node.as_ref().data { + // Need to insert before the current node + current_node.as_mut().prev = Some(node.into()); + match prev { + Some(mut prev) => { + prev.as_mut().next = Some(node.into()); + } + None => { + // We are inserting at the beginning of the list + self.head = Some(node.into()); + } + } + node.next = current; + node.prev = prev; + return; + } + prev = current; + current = current_node.as_ref().next; + } + + // We looped through the whole list and the nodes data is bigger or equal + // than everything we found up to now. + // Insert at the end. Since we checked before that the list isn't empty, + // tail always has a value. + node.prev = self.tail; + node.next = None; + self.tail.as_mut().unwrap().as_mut().next = Some(node.into()); + self.tail = Some(node.into()); + } + + /// Returns the first node in the linked list without removing it from the list + /// The function is only safe as long as valid pointers are stored inside + /// the linked list. + /// The returned pointer is only guaranteed to be valid as long as the list + /// is not mutated + pub fn peek_first(&self) -> Option<&mut ListNode> { + // Safety: When the node was inserted it was promised that it is alive + // until it gets removed from the list. + // The returned node has a pointer which constrains it to the lifetime + // of the list. This is ok, since the Node is supposed to outlive + // its insertion in the list. + unsafe { + self.head + .map(|mut node| &mut *(node.as_mut() as *mut ListNode)) + } + } + + /// Returns the last node in the linked list without removing it from the list + /// The function is only safe as long as valid pointers are stored inside + /// the linked list. + /// The returned pointer is only guaranteed to be valid as long as the list + /// is not mutated + pub fn peek_last(&self) -> Option<&mut ListNode> { + // Safety: When the node was inserted it was promised that it is alive + // until it gets removed from the list. + // The returned node has a pointer which constrains it to the lifetime + // of the list. This is ok, since the Node is supposed to outlive + // its insertion in the list. + unsafe { + self.tail + .map(|mut node| &mut *(node.as_mut() as *mut ListNode)) + } + } + + /// Removes the first node from the linked list + pub fn remove_first(&mut self) -> Option<&mut ListNode> { + #![allow(clippy::debug_assert_with_mut_call)] + + // Safety: When the node was inserted it was promised that it is alive + // until it gets removed from the list + unsafe { + let mut head = self.head?; + self.head = head.as_mut().next; + + let first_ref = head.as_mut(); + match first_ref.next { + None => { + // This was the only node in the list + debug_assert_eq!(Some(first_ref.into()), self.tail); + self.tail = None; + } + Some(mut next) => { + next.as_mut().prev = None; + } + } + + first_ref.prev = None; + first_ref.next = None; + Some(&mut *(first_ref as *mut ListNode)) + } + } + + /// Removes the last node from the linked list and returns it + pub fn remove_last(&mut self) -> Option<&mut ListNode> { + #![allow(clippy::debug_assert_with_mut_call)] + + // Safety: When the node was inserted it was promised that it is alive + // until it gets removed from the list + unsafe { + let mut tail = self.tail?; + self.tail = tail.as_mut().prev; + + let last_ref = tail.as_mut(); + match last_ref.prev { + None => { + // This was the last node in the list + debug_assert_eq!(Some(last_ref.into()), self.head); + self.head = None; + } + Some(mut prev) => { + prev.as_mut().next = None; + } + } + + last_ref.prev = None; + last_ref.next = None; + Some(&mut *(last_ref as *mut ListNode)) + } + } + + /// Returns whether the linked list doesn not contain any node + pub fn is_empty(&self) -> bool { + if self.head.is_some() { + return false; + } + + debug_assert!(self.tail.is_none()); + true + } + + /// Removes the given `node` from the linked list. + /// Returns whether the `node` was removed. + /// It is also only safe if it is known that the `node` is either part of this + /// list, or of no list at all. If `node` is part of another list, the + /// behavior is undefined. + pub unsafe fn remove(&mut self, node: &mut ListNode) -> bool { + #![allow(clippy::debug_assert_with_mut_call)] + + match node.prev { + None => { + // This might be the first node in the list. If it is not, the + // node is not in the list at all. Since our precondition is that + // the node must either be in this list or in no list, we check that + // the node is really in no list. + if self.head != Some(node.into()) { + debug_assert!(node.next.is_none()); + return false; + } + self.head = node.next; + } + Some(mut prev) => { + debug_assert_eq!(prev.as_ref().next, Some(node.into())); + prev.as_mut().next = node.next; + } + } + + match node.next { + None => { + // This must be the last node in our list. Otherwise the list + // is inconsistent. + debug_assert_eq!(self.tail, Some(node.into())); + self.tail = node.prev; + } + Some(mut next) => { + debug_assert_eq!(next.as_mut().prev, Some(node.into())); + next.as_mut().prev = node.prev; + } + } + + node.next = None; + node.prev = None; + + true + } + + /// Drains the list iby calling a callback on each list node + /// + /// The method does not return an iterator since stopping or deferring + /// draining the list is not permitted. If the method would push nodes to + /// an iterator we could not guarantee that the nodes do not get utilized + /// after having been removed from the list anymore. + pub fn drain(&mut self, mut func: F) + where + F: FnMut(&mut ListNode), + { + let mut current = self.head; + self.head = None; + self.tail = None; + + while let Some(mut node) = current { + // Safety: The nodes have not been removed from the list yet and must + // therefore contain valid data. The nodes can also not be added to + // the list again during iteration, since the list is mutably borrowed. + unsafe { + let node_ref = node.as_mut(); + current = node_ref.next; + + node_ref.next = None; + node_ref.prev = None; + + // Note: We do not reset the pointers from the next element in the + // list to the current one since we will iterate over the whole + // list anyway, and therefore clean up all pointers. + + func(node_ref); + } + } + } + + /// Drains the list in reverse order by calling a callback on each list node + /// + /// The method does not return an iterator since stopping or deferring + /// draining the list is not permitted. If the method would push nodes to + /// an iterator we could not guarantee that the nodes do not get utilized + /// after having been removed from the list anymore. + pub fn reverse_drain(&mut self, mut func: F) + where + F: FnMut(&mut ListNode), + { + let mut current = self.tail; + self.head = None; + self.tail = None; + + while let Some(mut node) = current { + // Safety: The nodes have not been removed from the list yet and must + // therefore contain valid data. The nodes can also not be added to + // the list again during iteration, since the list is mutably borrowed. + unsafe { + let node_ref = node.as_mut(); + current = node_ref.prev; + + node_ref.next = None; + node_ref.prev = None; + + // Note: We do not reset the pointers from the next element in the + // list to the current one since we will iterate over the whole + // list anyway, and therefore clean up all pointers. + + func(node_ref); + } + } + } +} + +#[cfg(all(test, feature = "std"))] // Tests make use of Vec at the moment +mod tests { + use super::*; + + fn collect_list(mut list: LinkedList) -> Vec { + let mut result = Vec::new(); + list.drain(|node| { + result.push(**node); + }); + result + } + + fn collect_reverse_list(mut list: LinkedList) -> Vec { + let mut result = Vec::new(); + list.reverse_drain(|node| { + result.push(**node); + }); + result + } + + unsafe fn add_nodes(list: &mut LinkedList, nodes: &mut [&mut ListNode]) { + for node in nodes.iter_mut() { + list.add_front(node); + } + } + + unsafe fn assert_clean(node: &mut ListNode) { + assert!(node.next.is_none()); + assert!(node.prev.is_none()); + } + + #[test] + fn insert_and_iterate() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut setup = |list: &mut LinkedList| { + assert_eq!(true, list.is_empty()); + list.add_front(&mut c); + assert_eq!(31, **list.peek_first().unwrap()); + assert_eq!(false, list.is_empty()); + list.add_front(&mut b); + assert_eq!(7, **list.peek_first().unwrap()); + list.add_front(&mut a); + assert_eq!(5, **list.peek_first().unwrap()); + }; + + let mut list = LinkedList::new(); + setup(&mut list); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31].to_vec(), items); + + let mut list = LinkedList::new(); + setup(&mut list); + let items: Vec = collect_reverse_list(list); + assert_eq!([31, 7, 5].to_vec(), items); + } + } + + #[test] + fn add_sorted() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + let mut d = ListNode::new(99); + + let mut list = LinkedList::new(); + list.add_sorted(&mut a); + let items: Vec = collect_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + list.add_sorted(&mut a); + let items: Vec = collect_reverse_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]); + list.add_sorted(&mut a); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]); + list.add_sorted(&mut a); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]); + list.add_sorted(&mut b); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]); + list.add_sorted(&mut b); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]); + list.add_sorted(&mut c); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]); + list.add_sorted(&mut c); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + list.add_sorted(&mut d); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + list.add_sorted(&mut d); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + } + } + + #[test] + fn drain_and_collect() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + + let taken_items: Vec = collect_list(list); + assert_eq!([5, 7, 31].to_vec(), taken_items); + } + } + + #[test] + fn peek_last() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + + let last = list.peek_last(); + assert_eq!(31, **last.unwrap()); + list.remove_last(); + + let last = list.peek_last(); + assert_eq!(7, **last.unwrap()); + list.remove_last(); + + let last = list.peek_last(); + assert_eq!(5, **last.unwrap()); + list.remove_last(); + + let last = list.peek_last(); + assert!(last.is_none()); + } + } + + #[test] + fn remove_first() { + unsafe { + // We iterate forward and backwards through the manipulated lists + // to make sure pointers in both directions are still ok. + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_list(list); + assert_eq!([7, 31].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert_eq!([31, 7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_list(list); + assert_eq!([7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert_eq!([7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(list.is_empty()); + let items: Vec = collect_list(list); + assert!(items.is_empty()); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + let removed = list.remove_first().unwrap(); + assert_clean(removed); + assert!(list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert!(items.is_empty()); + } + } + + #[test] + fn remove_last() { + unsafe { + // We iterate forward and backwards through the manipulated lists + // to make sure pointers in both directions are still ok. + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_list(list); + assert_eq!([5, 7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert_eq!([7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(list.is_empty()); + let items: Vec = collect_list(list); + assert!(items.is_empty()); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + let removed = list.remove_last().unwrap(); + assert_clean(removed); + assert!(list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert!(items.is_empty()); + } + } + + #[test] + fn remove_by_address() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + { + // Remove first + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean((&mut a).into()); + // a should be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(Some((&mut b).into()), list.head); + assert_eq!(Some((&mut c).into()), b.next); + assert_eq!(Some((&mut b).into()), c.prev); + let items: Vec = collect_list(list); + assert_eq!([7, 31].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean((&mut a).into()); + // a should be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(Some((&mut c).into()), b.next); + assert_eq!(Some((&mut b).into()), c.prev); + let items: Vec = collect_reverse_list(list); + assert_eq!([31, 7].to_vec(), items); + } + + { + // Remove middle + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean((&mut b).into()); + assert_eq!(Some((&mut c).into()), a.next); + assert_eq!(Some((&mut a).into()), c.prev); + let items: Vec = collect_list(list); + assert_eq!([5, 31].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean((&mut b).into()); + assert_eq!(Some((&mut c).into()), a.next); + assert_eq!(Some((&mut a).into()), c.prev); + let items: Vec = collect_reverse_list(list); + assert_eq!([31, 5].to_vec(), items); + } + + { + // Remove last + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut c)); + assert_clean((&mut c).into()); + assert!(b.next.is_none()); + assert_eq!(Some((&mut b).into()), list.tail); + let items: Vec = collect_list(list); + assert_eq!([5, 7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut c)); + assert_clean((&mut c).into()); + assert!(b.next.is_none()); + assert_eq!(Some((&mut b).into()), list.tail); + let items: Vec = collect_reverse_list(list); + assert_eq!([7, 5].to_vec(), items); + } + + { + // Remove first of two + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean((&mut a).into()); + // a should be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(Some((&mut b).into()), list.head); + assert_eq!(Some((&mut b).into()), list.tail); + assert!(b.next.is_none()); + assert!(b.prev.is_none()); + let items: Vec = collect_list(list); + assert_eq!([7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean((&mut a).into()); + // a should be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(Some((&mut b).into()), list.head); + assert_eq!(Some((&mut b).into()), list.tail); + assert!(b.next.is_none()); + assert!(b.prev.is_none()); + let items: Vec = collect_reverse_list(list); + assert_eq!([7].to_vec(), items); + } + + { + // Remove last of two + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean((&mut b).into()); + assert_eq!(Some((&mut a).into()), list.head); + assert_eq!(Some((&mut a).into()), list.tail); + assert!(a.next.is_none()); + assert!(a.prev.is_none()); + let items: Vec = collect_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean((&mut b).into()); + assert_eq!(Some((&mut a).into()), list.head); + assert_eq!(Some((&mut a).into()), list.tail); + assert!(a.next.is_none()); + assert!(a.prev.is_none()); + let items: Vec = collect_reverse_list(list); + assert_eq!([5].to_vec(), items); + } + + { + // Remove last item + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean((&mut a).into()); + assert!(list.head.is_none()); + assert!(list.tail.is_none()); + let items: Vec = collect_list(list); + assert!(items.is_empty()); + } + + { + // Remove missing + let mut list = LinkedList::new(); + list.add_front(&mut b); + list.add_front(&mut a); + assert_eq!(false, list.remove(&mut c)); + } + } + } +} diff --git a/tokio/src/util/mod.rs b/tokio/src/util/mod.rs index a093395c020..6dda08ca411 100644 --- a/tokio/src/util/mod.rs +++ b/tokio/src/util/mod.rs @@ -19,6 +19,10 @@ cfg_rt_threaded! { pub(crate) use try_lock::TryLock; } +pub(crate) mod trace; + #[cfg(any(feature = "macros", feature = "stream"))] #[cfg_attr(not(feature = "macros"), allow(unreachable_pub))] pub use rand::thread_rng_n; + +pub(crate) mod intrusive_double_linked_list; diff --git a/tokio/src/util/trace.rs b/tokio/src/util/trace.rs new file mode 100644 index 00000000000..d8c6120d97c --- /dev/null +++ b/tokio/src/util/trace.rs @@ -0,0 +1,57 @@ +cfg_trace! { + cfg_rt_core! { + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + use pin_project_lite::pin_project; + + use tracing::Span; + + pin_project! { + /// A future that has been instrumented with a `tracing` span. + #[derive(Debug, Clone)] + pub(crate) struct Instrumented { + #[pin] + inner: T, + span: Span, + } + } + + impl Future for Instrumented { + type Output = T::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let _enter = this.span.enter(); + this.inner.poll(cx) + } + } + + impl Instrumented { + pub(crate) fn new(inner: T, span: Span) -> Self { + Self { inner, span } + } + } + + #[inline] + pub(crate) fn task(task: F, kind: &'static str) -> Instrumented { + let span = tracing::trace_span!( + target: "tokio::task", + "task", + %kind, + future = %std::any::type_name::(), + ); + Instrumented::new(task, span) + } + } +} + +cfg_not_trace! { + cfg_rt_core! { + #[inline] + pub(crate) fn task(task: F, _: &'static str) -> F { + // nop + task + } + } +} diff --git a/tokio/tests/async_send_sync.rs b/tokio/tests/async_send_sync.rs index 45d11bd441a..afe053b1010 100644 --- a/tokio/tests/async_send_sync.rs +++ b/tokio/tests/async_send_sync.rs @@ -259,3 +259,6 @@ async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSync<()>): Send & Syn async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSend<()>): Send & !Sync); async_assert_fn!(tokio::time::timeout_at(Instant, BoxFuture<()>): !Send & !Sync); async_assert_fn!(tokio::time::Interval::tick(_): Send & Sync); + +#[cfg(tokio_unstable)] +assert_value!(tokio::sync::CancellationToken: Send & Sync); diff --git a/tokio/tests/fs_dir.rs b/tokio/tests/fs_dir.rs index eaff59da4f9..6355ef05fcb 100644 --- a/tokio/tests/fs_dir.rs +++ b/tokio/tests/fs_dir.rs @@ -2,7 +2,7 @@ #![cfg(feature = "full")] use tokio::fs; -use tokio_test::assert_ok; +use tokio_test::{assert_err, assert_ok}; use std::sync::{Arc, Mutex}; use tempfile::tempdir; @@ -28,6 +28,23 @@ async fn create_all() { assert!(new_dir_2.is_dir()); } +#[tokio::test] +async fn build_dir() { + let base_dir = tempdir().unwrap(); + let new_dir = base_dir.path().join("foo").join("bar"); + let new_dir_2 = new_dir.clone(); + + assert_ok!(fs::DirBuilder::new().recursive(true).create(new_dir).await); + + assert!(new_dir_2.is_dir()); + assert_err!( + fs::DirBuilder::new() + .recursive(false) + .create(new_dir_2) + .await + ); +} + #[tokio::test] async fn remove() { let base_dir = tempdir().unwrap(); diff --git a/tokio/tests/fs_file_mocked.rs b/tokio/tests/fs_file_mocked.rs index 0c5722404ea..2e7e8b7cf48 100644 --- a/tokio/tests/fs_file_mocked.rs +++ b/tokio/tests/fs_file_mocked.rs @@ -257,27 +257,27 @@ fn flush_while_idle() { #[test] fn read_with_buffer_larger_than_max() { // Chunks - let a = 16 * 1024; - let b = a * 2; - let c = a * 3; - let d = a * 4; + let chunk_a = 16 * 1024; + let chunk_b = chunk_a * 2; + let chunk_c = chunk_a * 3; + let chunk_d = chunk_a * 4; - assert_eq!(d / 1024, 64); + assert_eq!(chunk_d / 1024, 64); let mut data = vec![]; - for i in 0..(d - 1) { + for i in 0..(chunk_d - 1) { data.push((i % 151) as u8); } let (mock, file) = sys::File::mock(); - mock.read(&data[0..a]) - .read(&data[a..b]) - .read(&data[b..c]) - .read(&data[c..]); + mock.read(&data[0..chunk_a]) + .read(&data[chunk_a..chunk_b]) + .read(&data[chunk_b..chunk_c]) + .read(&data[chunk_c..]); let mut file = File::from_std(file); - let mut actual = vec![0; d]; + let mut actual = vec![0; chunk_d]; let mut pos = 0; while pos < data.len() { @@ -288,7 +288,7 @@ fn read_with_buffer_larger_than_max() { assert!(t.is_woken()); let n = assert_ready_ok!(t.poll()); - assert!(n <= a); + assert!(n <= chunk_a); pos += n; } @@ -300,23 +300,23 @@ fn read_with_buffer_larger_than_max() { #[test] fn write_with_buffer_larger_than_max() { // Chunks - let a = 16 * 1024; - let b = a * 2; - let c = a * 3; - let d = a * 4; + let chunk_a = 16 * 1024; + let chunk_b = chunk_a * 2; + let chunk_c = chunk_a * 3; + let chunk_d = chunk_a * 4; - assert_eq!(d / 1024, 64); + assert_eq!(chunk_d / 1024, 64); let mut data = vec![]; - for i in 0..(d - 1) { + for i in 0..(chunk_d - 1) { data.push((i % 151) as u8); } let (mock, file) = sys::File::mock(); - mock.write(&data[0..a]) - .write(&data[a..b]) - .write(&data[b..c]) - .write(&data[c..]); + mock.write(&data[0..chunk_a]) + .write(&data[chunk_a..chunk_b]) + .write(&data[chunk_b..chunk_c]) + .write(&data[chunk_c..]); let mut file = File::from_std(file); @@ -325,17 +325,17 @@ fn write_with_buffer_larger_than_max() { let mut first = true; while !rem.is_empty() { - let mut t = task::spawn(file.write(rem)); + let mut task = task::spawn(file.write(rem)); if !first { - assert_pending!(t.poll()); + assert_pending!(task.poll()); pool::run_one(); - assert!(t.is_woken()); + assert!(task.is_woken()); } first = false; - let n = assert_ready_ok!(t.poll()); + let n = assert_ready_ok!(task.poll()); rem = &rem[n..]; } diff --git a/tokio/tests/io_mem_stream.rs b/tokio/tests/io_mem_stream.rs new file mode 100644 index 00000000000..3335214cb9d --- /dev/null +++ b/tokio/tests/io_mem_stream.rs @@ -0,0 +1,83 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; + +#[tokio::test] +async fn ping_pong() { + let (mut a, mut b) = duplex(32); + + let mut buf = [0u8; 4]; + + a.write_all(b"ping").await.unwrap(); + b.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ping"); + + b.write_all(b"pong").await.unwrap(); + a.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"pong"); +} + +#[tokio::test] +async fn across_tasks() { + let (mut a, mut b) = duplex(32); + + let t1 = tokio::spawn(async move { + a.write_all(b"ping").await.unwrap(); + let mut buf = [0u8; 4]; + a.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"pong"); + }); + + let t2 = tokio::spawn(async move { + let mut buf = [0u8; 4]; + b.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ping"); + b.write_all(b"pong").await.unwrap(); + }); + + t1.await.unwrap(); + t2.await.unwrap(); +} + +#[tokio::test] +async fn disconnect() { + let (mut a, mut b) = duplex(32); + + let t1 = tokio::spawn(async move { + a.write_all(b"ping").await.unwrap(); + // and dropped + }); + + let t2 = tokio::spawn(async move { + let mut buf = [0u8; 32]; + let n = b.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..n], b"ping"); + + let n = b.read(&mut buf).await.unwrap(); + assert_eq!(n, 0); + }); + + t1.await.unwrap(); + t2.await.unwrap(); +} + +#[tokio::test] +async fn max_write_size() { + let (mut a, mut b) = duplex(32); + + let t1 = tokio::spawn(async move { + let n = a.write(&[0u8; 64]).await.unwrap(); + assert_eq!(n, 32); + let n = a.write(&[0u8; 64]).await.unwrap(); + assert_eq!(n, 4); + }); + + let t2 = tokio::spawn(async move { + let mut buf = [0u8; 4]; + b.read_exact(&mut buf).await.unwrap(); + }); + + t1.await.unwrap(); + t2.await.unwrap(); +} diff --git a/tokio/tests/io_read_line.rs b/tokio/tests/io_read_line.rs index 57ae37cef3e..15841c9b49d 100644 --- a/tokio/tests/io_read_line.rs +++ b/tokio/tests/io_read_line.rs @@ -1,8 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::AsyncBufReadExt; -use tokio_test::assert_ok; +use std::io::ErrorKind; +use tokio::io::{AsyncBufReadExt, BufReader, Error}; +use tokio_test::{assert_ok, io::Builder}; use std::io::Cursor; @@ -27,3 +28,80 @@ async fn read_line() { assert_eq!(n, 0); assert_eq!(buf, ""); } + +#[tokio::test] +async fn read_line_not_all_ready() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"ld\nFizzBuz") + .read(b"z\n1\n2") + .build(); + + let mut read = BufReader::new(mock); + + let mut line = "We say ".to_string(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, "Hello World\n".len()); + assert_eq!(line.as_str(), "We say Hello World\n"); + + line = "I solve ".to_string(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, "FizzBuzz\n".len()); + assert_eq!(line.as_str(), "I solve FizzBuzz\n"); + + line.clear(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, 2); + assert_eq!(line.as_str(), "1\n"); + + line.clear(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, 1); + assert_eq!(line.as_str(), "2"); +} + +#[tokio::test] +async fn read_line_invalid_utf8() { + let mock = Builder::new().read(b"Hello Wor\xffld.\n").build(); + + let mut read = BufReader::new(mock); + + let mut line = "Foo".to_string(); + let err = read.read_line(&mut line).await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::InvalidData); + assert_eq!(err.to_string(), "stream did not contain valid UTF-8"); + assert_eq!(line.as_str(), "Foo"); +} + +#[tokio::test] +async fn read_line_fail() { + let mock = Builder::new() + .read(b"Hello Wor") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let mut line = "Foo".to_string(); + let err = read.read_line(&mut line).await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); + assert_eq!(line.as_str(), "FooHello Wor"); +} + +#[tokio::test] +async fn read_line_fail_and_utf8_fail() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"\xff\xff\xff") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let mut line = "Foo".to_string(); + let err = read.read_line(&mut line).await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); + assert_eq!(line.as_str(), "Foo"); +} diff --git a/tokio/tests/io_read_until.rs b/tokio/tests/io_read_until.rs index 4e0e0d10d34..61800a0d9c1 100644 --- a/tokio/tests/io_read_until.rs +++ b/tokio/tests/io_read_until.rs @@ -1,8 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::AsyncBufReadExt; -use tokio_test::assert_ok; +use std::io::ErrorKind; +use tokio::io::{AsyncBufReadExt, BufReader, Error}; +use tokio_test::{assert_ok, io::Builder}; #[tokio::test] async fn read_until() { @@ -21,3 +22,53 @@ async fn read_until() { assert_eq!(n, 0); assert_eq!(buf, []); } + +#[tokio::test] +async fn read_until_not_all_ready() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"ld#Fizz\xffBuz") + .read(b"z#1#2") + .build(); + + let mut read = BufReader::new(mock); + + let mut chunk = b"We say ".to_vec(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, b"Hello World#".len()); + assert_eq!(chunk, b"We say Hello World#"); + + chunk = b"I solve ".to_vec(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, b"Fizz\xffBuzz\n".len()); + assert_eq!(chunk, b"I solve Fizz\xffBuzz#"); + + chunk.clear(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, 2); + assert_eq!(chunk, b"1#"); + + chunk.clear(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, 1); + assert_eq!(chunk, b"2"); +} + +#[tokio::test] +async fn read_until_fail() { + let mock = Builder::new() + .read(b"Hello \xffWor") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let mut chunk = b"Foo".to_vec(); + let err = read + .read_until(b'#', &mut chunk) + .await + .expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); + assert_eq!(chunk, b"FooHello \xffWor"); +} diff --git a/tokio/tests/macros_join.rs b/tokio/tests/macros_join.rs index d9b748d9a7b..169e898f97d 100644 --- a/tokio/tests/macros_join.rs +++ b/tokio/tests/macros_join.rs @@ -1,3 +1,4 @@ +#![allow(clippy::blacklisted_name)] use tokio::sync::oneshot; use tokio_test::{assert_pending, assert_ready, task}; diff --git a/tokio/tests/macros_select.rs b/tokio/tests/macros_select.rs index c08e816a015..6f027f3bdfc 100644 --- a/tokio/tests/macros_select.rs +++ b/tokio/tests/macros_select.rs @@ -1,3 +1,4 @@ +#![allow(clippy::blacklisted_name)] use tokio::sync::{mpsc, oneshot}; use tokio::task; use tokio_test::{assert_ok, assert_pending, assert_ready}; diff --git a/tokio/tests/macros_test.rs b/tokio/tests/macros_test.rs new file mode 100644 index 00000000000..8e68b8a4417 --- /dev/null +++ b/tokio/tests/macros_test.rs @@ -0,0 +1,19 @@ +use tokio::test; + +#[test] +async fn test_macro_can_be_used_via_use() { + tokio::spawn(async { + assert_eq!(1 + 1, 2); + }) + .await + .unwrap(); +} + +#[tokio::test] +async fn test_macro_is_resilient_to_shadowing() { + tokio::spawn(async { + assert_eq!(1 + 1, 2); + }) + .await + .unwrap(); +} diff --git a/tokio/tests/macros_try_join.rs b/tokio/tests/macros_try_join.rs index faa55421a2b..a9251532664 100644 --- a/tokio/tests/macros_try_join.rs +++ b/tokio/tests/macros_try_join.rs @@ -1,3 +1,4 @@ +#![allow(clippy::blacklisted_name)] use tokio::sync::oneshot; use tokio_test::{assert_pending, assert_ready, task}; diff --git a/tokio/tests/rt_basic.rs b/tokio/tests/rt_basic.rs index b9e373b88f8..0885992d7d2 100644 --- a/tokio/tests/rt_basic.rs +++ b/tokio/tests/rt_basic.rs @@ -29,6 +29,7 @@ fn spawned_task_does_not_progress_without_block_on() { #[test] fn no_extra_poll() { + use pin_project_lite::pin_project; use std::pin::Pin; use std::sync::{ atomic::{AtomicUsize, Ordering::SeqCst}, @@ -37,9 +38,12 @@ fn no_extra_poll() { use std::task::{Context, Poll}; use tokio::stream::{Stream, StreamExt}; - struct TrackPolls { - npolls: Arc, - s: S, + pin_project! { + struct TrackPolls { + npolls: Arc, + #[pin] + s: S, + } } impl Stream for TrackPolls @@ -48,11 +52,9 @@ fn no_extra_poll() { { type Item = S::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // safety: we do not move s - let this = unsafe { self.get_unchecked_mut() }; + let this = self.project(); this.npolls.fetch_add(1, SeqCst); - // safety: we are pinned, and so is s - unsafe { Pin::new_unchecked(&mut this.s) }.poll_next(cx) + this.s.poll_next(cx) } } @@ -65,7 +67,7 @@ fn no_extra_poll() { let mut rt = rt(); - rt.spawn(async move { while let Some(_) = rx.next().await {} }); + rt.spawn(async move { while rx.next().await.is_some() {} }); rt.block_on(async { tokio::task::yield_now().await; }); diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs index 9f2d3d66890..71101d46cef 100644 --- a/tokio/tests/rt_common.rs +++ b/tokio/tests/rt_common.rs @@ -601,6 +601,19 @@ rt_test! { } #[test] + // IOCP requires setting the "max thread" concurrency value. The sane, + // default, is to set this to the number of cores. Threads that poll I/O + // become associated with the IOCP handle. Once those threads sleep for any + // reason (mutex), they yield their ownership. + // + // This test hits an edge case on windows where more threads than cores are + // created, none of those threads ever yield due to being at capacity, so + // IOCP gets "starved". + // + // For now, this is a very edge case that is probably not a real production + // concern. There also isn't a great/obvious solution to take. For now, the + // test is disabled. + #[cfg(not(windows))] fn io_driver_called_when_under_load() { let mut rt = rt(); diff --git a/tokio/tests/rt_threaded.rs b/tokio/tests/rt_threaded.rs index ad063348f68..b5ec96dec35 100644 --- a/tokio/tests/rt_threaded.rs +++ b/tokio/tests/rt_threaded.rs @@ -7,6 +7,7 @@ use tokio::runtime::{self, Runtime}; use tokio::sync::oneshot; use tokio_test::{assert_err, assert_ok}; +use futures::future::poll_fn; use std::future::Future; use std::pin::Pin; use std::sync::atomic::AtomicUsize; @@ -322,6 +323,64 @@ fn multi_threadpool() { done_rx.recv().unwrap(); } +// When `block_in_place` returns, it attempts to reclaim the yielded runtime +// worker. In this case, the remainder of the task is on the runtime worker and +// must take part in the cooperative task budgeting system. +// +// The test ensures that, when this happens, attempting to consume from a +// channel yields occasionally even if there are values ready to receive. +#[test] +fn coop_and_block_in_place() { + use tokio::sync::mpsc; + + let mut rt = tokio::runtime::Builder::new() + .threaded_scheduler() + // Setting max threads to 1 prevents another thread from claiming the + // runtime worker yielded as part of `block_in_place` and guarantees the + // same thread will reclaim the worker at the end of the + // `block_in_place` call. + .max_threads(1) + .build() + .unwrap(); + + rt.block_on(async move { + let (mut tx, mut rx) = mpsc::channel(1024); + + // Fill the channel + for _ in 0..1024 { + tx.send(()).await.unwrap(); + } + + drop(tx); + + tokio::spawn(async move { + // Block in place without doing anything + tokio::task::block_in_place(|| {}); + + // Receive all the values, this should trigger a `Pending` as the + // coop limit will be reached. + poll_fn(|cx| { + while let Poll::Ready(v) = { + tokio::pin! { + let fut = rx.recv(); + } + + Pin::new(&mut fut).poll(cx) + } { + if v.is_none() { + panic!("did not yield"); + } + } + + Poll::Ready(()) + }) + .await + }) + .await + .unwrap(); + }); +} + // Testing this does not panic #[test] fn max_threads() { diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index 4fb7c0aa7bd..e37695b37d9 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -49,7 +49,7 @@ macro_rules! assert_closed { }; } -trait AssertSend: Send {} +trait AssertSend: Send + Sync {} impl AssertSend for broadcast::Sender {} impl AssertSend for broadcast::Receiver {} @@ -90,10 +90,13 @@ fn send_two_recv() { } #[tokio::test] -async fn send_recv_stream() { +async fn send_recv_into_stream_ready() { use tokio::stream::StreamExt; - let (tx, mut rx) = broadcast::channel::(8); + let (tx, rx) = broadcast::channel::(8); + tokio::pin! { + let rx = rx.into_stream(); + } assert_ok!(tx.send(1)); assert_ok!(tx.send(2)); @@ -106,6 +109,26 @@ async fn send_recv_stream() { assert_eq!(None, rx.next().await); } +#[tokio::test] +async fn send_recv_into_stream_pending() { + use tokio::stream::StreamExt; + + let (tx, rx) = broadcast::channel::(8); + + tokio::pin! { + let rx = rx.into_stream(); + } + + let mut recv = task::spawn(rx.next()); + assert_pending!(recv.poll()); + + assert_ok!(tx.send(1)); + + assert!(recv.is_woken()); + let val = assert_ready!(recv.poll()); + assert_eq!(val, Some(Ok(1))); +} + #[test] fn send_recv_bounded() { let (tx, mut rx) = broadcast::channel(16); @@ -160,6 +183,23 @@ fn send_two_recv_bounded() { assert_eq!(val2, "world"); } +#[test] +fn change_tasks() { + let (tx, mut rx) = broadcast::channel(1); + + let mut recv = Box::pin(rx.recv()); + + let mut task1 = task::spawn(&mut recv); + assert_pending!(task1.poll()); + + let mut task2 = task::spawn(&mut recv); + assert_pending!(task2.poll()); + + tx.send("hello").unwrap(); + + assert!(task2.is_woken()); +} + #[test] fn send_slow_rx() { let (tx, mut rx1) = broadcast::channel(16); @@ -451,6 +491,39 @@ fn lagging_receiver_recovers_after_wrap_open() { assert_empty!(rx); } +#[tokio::test] +async fn send_recv_stream_ready_deprecated() { + use tokio::stream::StreamExt; + + let (tx, mut rx) = broadcast::channel::(8); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + + assert_eq!(Some(Ok(1)), rx.next().await); + assert_eq!(Some(Ok(2)), rx.next().await); + + drop(tx); + + assert_eq!(None, rx.next().await); +} + +#[tokio::test] +async fn send_recv_stream_pending_deprecated() { + use tokio::stream::StreamExt; + + let (tx, mut rx) = broadcast::channel::(8); + + let mut recv = task::spawn(rx.next()); + assert_pending!(recv.poll()); + + assert_ok!(tx.send(1)); + + assert!(recv.is_woken()); + let val = assert_ready!(recv.poll()); + assert_eq!(val, Some(Ok(1))); +} + fn is_closed(err: broadcast::RecvError) -> bool { match err { broadcast::RecvError::Closed => true, diff --git a/tokio/tests/sync_cancellation_token.rs b/tokio/tests/sync_cancellation_token.rs new file mode 100644 index 00000000000..de543c94b1f --- /dev/null +++ b/tokio/tests/sync_cancellation_token.rs @@ -0,0 +1,220 @@ +#![cfg(tokio_unstable)] + +use tokio::pin; +use tokio::sync::CancellationToken; + +use core::future::Future; +use core::task::{Context, Poll}; +use futures_test::task::new_count_waker; + +#[test] +fn cancel_token() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + assert_eq!(false, token.is_cancelled()); + + let wait_fut = token.cancelled(); + pin!(wait_fut); + + assert_eq!( + Poll::Pending, + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + let wait_fut_2 = token.cancelled(); + pin!(wait_fut_2); + + token.cancel(); + assert_eq!(wake_counter, 1); + assert_eq!(true, token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_child_token_through_parent() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token = token.child_token(); + assert!(!child_token.is_cancelled()); + + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + token.cancel(); + assert_eq!(wake_counter, 2); + assert_eq!(true, token.is_cancelled()); + assert_eq!(true, child_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_child_token_without_parent() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token_1 = token.child_token(); + + let child_fut = child_token_1.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + child_token_1.cancel(); + assert_eq!(wake_counter, 1); + assert_eq!(false, token.is_cancelled()); + assert_eq!(true, child_token_1.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + let child_token_2 = token.child_token(); + let child_fut_2 = child_token_2.cancelled(); + pin!(child_fut_2); + + assert_eq!( + Poll::Pending, + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + token.cancel(); + assert_eq!(wake_counter, 3); + assert_eq!(true, token.is_cancelled()); + assert_eq!(true, child_token_2.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn create_child_token_after_parent_was_cancelled() { + for drop_child_first in [true, false].iter().cloned() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + token.cancel(); + + let child_token = token.child_token(); + assert!(child_token.is_cancelled()); + + { + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + drop(child_fut); + drop(parent_fut); + } + + if drop_child_first { + drop(child_token); + drop(token); + } else { + drop(token); + drop(child_token); + } + } +} + +#[test] +fn drop_multiple_child_tokens() { + for drop_first_child_first in &[true, false] { + let token = CancellationToken::new(); + let mut child_tokens = [None, None, None]; + for i in 0..child_tokens.len() { + child_tokens[i] = Some(token.child_token()); + } + + assert!(!token.is_cancelled()); + assert!(!child_tokens[0].as_ref().unwrap().is_cancelled()); + + for i in 0..child_tokens.len() { + if *drop_first_child_first { + child_tokens[i] = None; + } else { + child_tokens[child_tokens.len() - 1 - i] = None; + } + assert!(!token.is_cancelled()); + } + + drop(token); + } +} + +#[test] +fn drop_parent_before_child_tokens() { + let token = CancellationToken::new(); + let child1 = token.child_token(); + let child2 = token.child_token(); + + drop(token); + assert!(!child1.is_cancelled()); + + drop(child1); + drop(child2); +} diff --git a/tokio/tests/sync_mutex_owned.rs b/tokio/tests/sync_mutex_owned.rs index eef966fd41d..394a6708bd2 100644 --- a/tokio/tests/sync_mutex_owned.rs +++ b/tokio/tests/sync_mutex_owned.rs @@ -36,7 +36,7 @@ fn straight_execution() { fn readiness() { let l = Arc::new(Mutex::new(100)); let mut t1 = spawn(l.clone().lock_owned()); - let mut t2 = spawn(l.clone().lock_owned()); + let mut t2 = spawn(l.lock_owned()); let g = assert_ready!(t1.poll()); diff --git a/tokio/tests/task_blocking.rs b/tokio/tests/task_blocking.rs index 72fed01e961..50c070a355a 100644 --- a/tokio/tests/task_blocking.rs +++ b/tokio/tests/task_blocking.rs @@ -96,3 +96,83 @@ fn no_block_in_basic_block_on() { task::block_in_place(|| {}); }); } + +#[test] +fn can_enter_basic_rt_from_within_block_in_place() { + let mut outer = tokio::runtime::Builder::new() + .threaded_scheduler() + .build() + .unwrap(); + + outer.block_on(async { + tokio::task::block_in_place(|| { + let mut inner = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + + inner.block_on(async {}) + }) + }); +} + +#[test] +fn useful_panic_message_when_dropping_rt_in_rt() { + use std::panic::{catch_unwind, AssertUnwindSafe}; + + let mut outer = tokio::runtime::Builder::new() + .threaded_scheduler() + .build() + .unwrap(); + + let result = catch_unwind(AssertUnwindSafe(|| { + outer.block_on(async { + let _ = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + }); + })); + + assert!(result.is_err()); + let err = result.unwrap_err(); + let err: &'static str = err.downcast_ref::<&'static str>().unwrap(); + + assert!( + err.find("Cannot drop a runtime").is_some(), + "Wrong panic message: {:?}", + err + ); +} + +#[test] +fn can_shutdown_with_zero_timeout_in_runtime() { + let mut outer = tokio::runtime::Builder::new() + .threaded_scheduler() + .build() + .unwrap(); + + outer.block_on(async { + let rt = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + rt.shutdown_timeout(Duration::from_nanos(0)); + }); +} + +#[test] +fn can_shutdown_now_in_runtime() { + let mut outer = tokio::runtime::Builder::new() + .threaded_scheduler() + .build() + .unwrap(); + + outer.block_on(async { + let rt = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + rt.shutdown_background(); + }); +} diff --git a/tokio/tests/task_local_set.rs b/tokio/tests/task_local_set.rs index 38c7c939238..bf80b8ee5f5 100644 --- a/tokio/tests/task_local_set.rs +++ b/tokio/tests/task_local_set.rs @@ -365,7 +365,7 @@ fn drop_cancels_remote_tasks() { let mut rt = rt(); let local = LocalSet::new(); - local.spawn_local(async move { while let Some(_) = rx.recv().await {} }); + local.spawn_local(async move { while rx.recv().await.is_some() {} }); local.block_on(&mut rt, async { time::delay_for(Duration::from_millis(1)).await; }); diff --git a/tokio/tests/tcp_accept.rs b/tokio/tests/tcp_accept.rs index ff62fb96a2b..9f5b441468d 100644 --- a/tokio/tests/tcp_accept.rs +++ b/tokio/tests/tcp_accept.rs @@ -39,6 +39,7 @@ test_accept! { (ip_port_tuple, ("127.0.0.1".parse::().unwrap(), 0)), } +use pin_project_lite::pin_project; use std::pin::Pin; use std::sync::{ atomic::{AtomicUsize, Ordering::SeqCst}, @@ -47,9 +48,12 @@ use std::sync::{ use std::task::{Context, Poll}; use tokio::stream::{Stream, StreamExt}; -struct TrackPolls { - npolls: Arc, - s: S, +pin_project! { + struct TrackPolls { + npolls: Arc, + #[pin] + s: S, + } } impl Stream for TrackPolls @@ -58,11 +62,9 @@ where { type Item = S::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // safety: we do not move s - let this = unsafe { self.get_unchecked_mut() }; + let this = self.project(); this.npolls.fetch_add(1, SeqCst); - // safety: we are pinned, and so is s - unsafe { Pin::new_unchecked(&mut this.s) }.poll_next(cx) + this.s.poll_next(cx) } } @@ -80,7 +82,7 @@ async fn no_extra_poll() { s: listener.incoming(), }; assert_ok!(tx.send(Arc::clone(&incoming.npolls))); - while let Some(_) = incoming.next().await { + while incoming.next().await.is_some() { accepted_tx.send(()).unwrap(); } }); diff --git a/tokio/tests/tcp_into_split.rs b/tokio/tests/tcp_into_split.rs index 6561fa30a9b..86ed461923d 100644 --- a/tokio/tests/tcp_into_split.rs +++ b/tokio/tests/tcp_into_split.rs @@ -3,7 +3,6 @@ use std::io::{Error, ErrorKind, Result}; use std::io::{Read, Write}; -use std::sync::{Arc, Barrier}; use std::{net, thread}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -89,12 +88,9 @@ async fn drop_write() -> Result<()> { let listener = net::TcpListener::bind("127.0.0.1:0")?; let addr = listener.local_addr()?; - let barrier = Arc::new(Barrier::new(2)); - let barrier2 = barrier.clone(); - let handle = thread::spawn(move || { let (mut stream, _) = listener.accept().unwrap(); - stream.write(MSG).unwrap(); + stream.write_all(MSG).unwrap(); let mut read_buf = [0u8; 32]; let res = match stream.read(&mut read_buf) { @@ -106,8 +102,6 @@ async fn drop_write() -> Result<()> { Err(err) => Err(err), }; - barrier2.wait(); - drop(stream); res @@ -132,8 +126,6 @@ async fn drop_write() -> Result<()> { Err(err) => panic!("Unexpected error: {}.", err), } - barrier.wait(); - handle.join().unwrap().unwrap(); Ok(()) } diff --git a/tokio/tests/tcp_split.rs b/tokio/tests/tcp_split.rs index 42f797708cc..7171dac4635 100644 --- a/tokio/tests/tcp_split.rs +++ b/tokio/tests/tcp_split.rs @@ -17,7 +17,7 @@ async fn split() -> Result<()> { let handle = thread::spawn(move || { let (mut stream, _) = listener.accept().unwrap(); - stream.write(MSG).unwrap(); + stream.write_all(MSG).unwrap(); let mut read_buf = [0u8; 32]; let read_len = stream.read(&mut read_buf).unwrap(); diff --git a/tokio/tests/time_delay_queue.rs b/tokio/tests/time_delay_queue.rs index 3cf2d1cd059..f04576d5a18 100644 --- a/tokio/tests/time_delay_queue.rs +++ b/tokio/tests/time_delay_queue.rs @@ -1,3 +1,4 @@ +#![allow(clippy::blacklisted_name)] #![warn(rust_2018_idioms)] #![cfg(feature = "full")] diff --git a/tokio/tests/uds_datagram.rs b/tokio/tests/uds_datagram.rs index dd9952378f7..d3c3535e7f4 100644 --- a/tokio/tests/uds_datagram.rs +++ b/tokio/tests/uds_datagram.rs @@ -3,6 +3,7 @@ #![cfg(unix)] use tokio::net::UnixDatagram; +use tokio::try_join; use std::io; @@ -41,3 +42,92 @@ async fn echo() -> io::Result<()> { Ok(()) } + +// Even though we use sync non-blocking io we still need a reactor. +#[tokio::test] +async fn try_send_recv_never_block() -> io::Result<()> { + let mut recv_buf = [0u8; 16]; + let payload = b"PAYLOAD"; + let mut count = 0; + + let (mut dgram1, mut dgram2) = UnixDatagram::pair()?; + + // Send until we hit the OS `net.unix.max_dgram_qlen`. + loop { + match dgram1.try_send(payload) { + Err(err) => match err.kind() { + io::ErrorKind::WouldBlock | io::ErrorKind::Other => break, + _ => unreachable!("unexpected error {:?}", err), + }, + Ok(len) => { + assert_eq!(len, payload.len()); + } + } + count += 1; + } + + // Read every dgram we sent. + while count > 0 { + let len = dgram2.try_recv(&mut recv_buf[..])?; + assert_eq!(len, payload.len()); + assert_eq!(payload, &recv_buf[..len]); + count -= 1; + } + + let err = dgram2.try_recv(&mut recv_buf[..]).unwrap_err(); + match err.kind() { + io::ErrorKind::WouldBlock => (), + _ => unreachable!("unexpected error {:?}", err), + } + + Ok(()) +} + +#[tokio::test] +async fn split() -> std::io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("split.sock"); + let socket = UnixDatagram::bind(path.clone())?; + let (mut r, mut s) = socket.into_split(); + + let msg = b"hello"; + let ((), ()) = try_join! { + async { + s.send_to(msg, path).await?; + io::Result::Ok(()) + }, + async { + let mut recv_buf = [0u8; 32]; + let (len, _) = r.recv_from(&mut recv_buf[..]).await?; + assert_eq!(&recv_buf[..len], msg); + Ok(()) + }, + }?; + + Ok(()) +} + +#[tokio::test] +async fn reunite() -> std::io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("reunite.sock"); + let socket = UnixDatagram::bind(path)?; + let (s, r) = socket.into_split(); + assert!(s.reunite(r).is_ok()); + Ok(()) +} + +#[tokio::test] +async fn reunite_error() -> std::io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("reunit.sock"); + let dir = tempfile::tempdir().unwrap(); + let path1 = dir.path().join("reunit.sock"); + let socket = UnixDatagram::bind(path)?; + let socket1 = UnixDatagram::bind(path1)?; + + let (s, _) = socket.into_split(); + let (_, r1) = socket1.into_split(); + assert!(s.reunite(r1).is_err()); + Ok(()) +}