Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor signal handling in yes, tee, and timeout #4588

Merged
merged 1 commit into from Mar 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 12 additions & 37 deletions src/uu/tee/src/tee.rs
Expand Up @@ -16,7 +16,7 @@ use uucore::{format_usage, help_about, help_section, help_usage, show_error};
// spell-checker:ignore nopipe

#[cfg(unix)]
use uucore::libc;
use uucore::signals::{enable_pipe_errors, ignore_interrupts};

const ABOUT: &str = help_about!("tee.md");
const USAGE: &str = help_usage!("tee.md");
Expand Down Expand Up @@ -135,44 +135,19 @@ pub fn uu_app() -> Command {
)
}

#[cfg(unix)]
fn ignore_interrupts() -> Result<()> {
let ret = unsafe { libc::signal(libc::SIGINT, libc::SIG_IGN) };
if ret == libc::SIG_ERR {
return Err(Error::new(ErrorKind::Other, ""));
}
Ok(())
}

#[cfg(not(unix))]
fn ignore_interrupts() -> Result<()> {
// Do nothing.
Ok(())
}

#[cfg(unix)]
fn enable_pipe_errors() -> Result<()> {
let ret = unsafe { libc::signal(libc::SIGPIPE, libc::SIG_DFL) };
if ret == libc::SIG_ERR {
return Err(Error::new(ErrorKind::Other, ""));
}
Ok(())
}

#[cfg(not(unix))]
fn enable_pipe_errors() -> Result<()> {
// Do nothing.
Ok(())
}

fn tee(options: &Options) -> Result<()> {
if options.ignore_interrupts {
ignore_interrupts()?;
}
if options.output_error.is_none() {
enable_pipe_errors()?;
}
#[cfg(unix)]
{
// ErrorKind::Other is raised by MultiWriter when all writers have exited.
// This is therefore just a clever way to stop all writers

if options.ignore_interrupts {
ignore_interrupts().map_err(|_| Error::from(ErrorKind::Other))?;
}
if options.output_error.is_none() {
enable_pipe_errors().map_err(|_| Error::from(ErrorKind::Other))?;
}
}
let mut writers: Vec<NamedWriter> = options
.files
.clone()
Expand Down
27 changes: 9 additions & 18 deletions src/uu/timeout/src/timeout.rs
Expand Up @@ -17,8 +17,14 @@ use std::time::Duration;
use uucore::display::Quotable;
use uucore::error::{UClapError, UResult, USimpleError, UUsageError};
use uucore::process::ChildExt;
use uucore::signals::{signal_by_name_or_value, signal_name_by_value};
use uucore::{format_usage, show_error};

#[cfg(unix)]
use uucore::signals::enable_pipe_errors;

use uucore::{
format_usage, show_error,
signals::{signal_by_name_or_value, signal_name_by_value},
};

static ABOUT: &str = "Start COMMAND, and kill it if still running after DURATION.";
const USAGE: &str = "{} [OPTION] DURATION COMMAND...";
Expand Down Expand Up @@ -285,21 +291,6 @@ fn preserve_signal_info(signal: libc::c_int) -> libc::c_int {
signal
}

#[cfg(unix)]
fn enable_pipe_errors() -> std::io::Result<()> {
let ret = unsafe { libc::signal(libc::SIGPIPE, libc::SIG_DFL) };
if ret == libc::SIG_ERR {
return Err(std::io::Error::new(std::io::ErrorKind::Other, ""));
}
Ok(())
}

#[cfg(not(unix))]
fn enable_pipe_errors() -> std::io::Result<()> {
// Do nothing.
Ok(())
}

/// TODO: Improve exit codes, and make them consistent with the GNU Coreutils exit codes.

fn timeout(
Expand All @@ -314,7 +305,7 @@ fn timeout(
if !foreground {
unsafe { libc::setpgid(0, 0) };
}

#[cfg(unix)]
enable_pipe_errors()?;

let process = &mut process::Command::new(&cmd[0])
Expand Down
8 changes: 5 additions & 3 deletions src/uu/yes/Cargo.toml
Expand Up @@ -16,12 +16,14 @@ path = "src/yes.rs"

[dependencies]
clap = { workspace=true }
libc = { workspace=true }
uucore = { workspace=true, features=["pipes"] }

[target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies]
[target.'cfg(unix)'.dependencies]
uucore = { workspace=true, features=["pipes", "signals"] }
nix = { workspace=true }

[target.'cfg(not(unix))'.dependencies]
uucore = { workspace=true, features=["pipes"] }

[[bin]]
name = "yes"
path = "src/main.rs"
25 changes: 5 additions & 20 deletions src/uu/yes/src/yes.rs
Expand Up @@ -7,13 +7,13 @@

/* last synced with: yes (GNU coreutils) 8.13 */

use std::borrow::Cow;
use std::io::{self, Result, Write};

use clap::{Arg, ArgAction, Command};
use std::borrow::Cow;
use std::io::{self, Write};
use uucore::error::{UResult, USimpleError};
#[cfg(unix)]
use uucore::signals::enable_pipe_errors;
use uucore::{format_usage, help_about, help_usage};

#[cfg(any(target_os = "linux", target_os = "android"))]
mod splice;

Expand Down Expand Up @@ -69,25 +69,10 @@ fn prepare_buffer<'a>(input: &'a str, buffer: &'a mut [u8; BUF_SIZE]) -> &'a [u8
}
}

#[cfg(unix)]
fn enable_pipe_errors() -> Result<()> {
let ret = unsafe { libc::signal(libc::SIGPIPE, libc::SIG_DFL) };
if ret == libc::SIG_ERR {
return Err(io::Error::new(io::ErrorKind::Other, ""));
}
Ok(())
}

#[cfg(not(unix))]
fn enable_pipe_errors() -> Result<()> {
// Do nothing.
Ok(())
}

pub fn exec(bytes: &[u8]) -> io::Result<()> {
let stdout = io::stdout();
let mut stdout = stdout.lock();

#[cfg(unix)]
enable_pipe_errors()?;

#[cfg(any(target_os = "linux", target_os = "android"))]
Expand Down
2 changes: 1 addition & 1 deletion src/uucore/Cargo.toml
Expand Up @@ -49,7 +49,7 @@ sm3 = { workspace=true }

[target.'cfg(unix)'.dependencies]
walkdir = { workspace=true, optional=true }
nix = { workspace=true, features = ["fs", "uio", "zerocopy"] }
nix = { workspace=true, features = ["fs", "uio", "zerocopy", "signal"] }

[dev-dependencies]
clap = { workspace=true }
Expand Down
20 changes: 19 additions & 1 deletion src/uucore/src/lib/features/signals.rs
Expand Up @@ -7,7 +7,12 @@

// spell-checker:ignore (vars/api) fcntl setrlimit setitimer
// spell-checker:ignore (vars/signals) ABRT ALRM CHLD SEGV SIGABRT SIGALRM SIGBUS SIGCHLD SIGCONT SIGEMT SIGFPE SIGHUP SIGILL SIGINFO SIGINT SIGIO SIGIOT SIGKILL SIGPIPE SIGPROF SIGPWR SIGQUIT SIGSEGV SIGSTOP SIGSYS SIGTERM SIGTRAP SIGTSTP SIGTHR SIGTTIN SIGTTOU SIGURG SIGUSR SIGVTALRM SIGWINCH SIGXCPU SIGXFSZ STKFLT PWR THR TSTP TTIN TTOU VTALRM XCPU XFSZ

#[cfg(unix)]
use nix::errno::Errno;
#[cfg(unix)]
use nix::sys::signal::{
signal, SigHandler::SigDfl, SigHandler::SigIgn, Signal::SIGINT, Signal::SIGPIPE,
};
pub static DEFAULT_SIGNAL: usize = 15;

/*
Expand Down Expand Up @@ -196,6 +201,19 @@
ALL_SIGNALS.get(signal_value).copied()
}

#[cfg(unix)]
pub fn enable_pipe_errors() -> Result<(), Errno> {
// We pass the error as is, the return value would just be Ok(SigDfl), so we can safely ignore it.
// SAFETY: this function is safe as long as we do not use a custom SigHandler -- we use the default one.
unsafe { signal(SIGPIPE, SigDfl) }.map(|_| ())
}
#[cfg(unix)]
pub fn ignore_interrupts() -> Result<(), Errno> {

Check warning on line 211 in src/uucore/src/lib/features/signals.rs

View check run for this annotation

Codecov / codecov/patch

src/uucore/src/lib/features/signals.rs#L211

Added line #L211 was not covered by tests
// We pass the error as is, the return value would just be Ok(SigIgn), so we can safely ignore it.
// SAFETY: this function is safe as long as we do not use a custom SigHandler -- we use the default one.
unsafe { signal(SIGINT, SigIgn) }.map(|_| ())
}

Check warning on line 215 in src/uucore/src/lib/features/signals.rs

View check run for this annotation

Codecov / codecov/patch

src/uucore/src/lib/features/signals.rs#L214-L215

Added lines #L214 - L215 were not covered by tests

#[test]
fn signal_by_value() {
assert_eq!(signal_by_name_or_value("0"), Some(0));
Expand Down
10 changes: 5 additions & 5 deletions src/uucore/src/lib/mods/error.rs
Expand Up @@ -522,7 +522,7 @@ impl From<std::io::Error> for Box<dyn UError> {
/// // prints "fix me please!: Permission denied"
/// println!("{}", uio_result.unwrap_err());
/// ```
#[cfg(any(target_os = "linux", target_os = "android"))]
#[cfg(unix)]
impl<T> FromIo<UResult<T>> for Result<T, nix::Error> {
fn map_err_context(self, context: impl FnOnce() -> String) -> UResult<T> {
self.map_err(|e| {
Expand All @@ -534,7 +534,7 @@ impl<T> FromIo<UResult<T>> for Result<T, nix::Error> {
}
}

#[cfg(any(target_os = "linux", target_os = "android"))]
#[cfg(unix)]
impl<T> FromIo<UResult<T>> for nix::Error {
fn map_err_context(self, context: impl FnOnce() -> String) -> UResult<T> {
Err(Box::new(UIoError {
Expand All @@ -544,7 +544,7 @@ impl<T> FromIo<UResult<T>> for nix::Error {
}
}

#[cfg(any(target_os = "linux", target_os = "android"))]
#[cfg(unix)]
impl From<nix::Error> for UIoError {
fn from(f: nix::Error) -> Self {
Self {
Expand All @@ -554,7 +554,7 @@ impl From<nix::Error> for UIoError {
}
}

#[cfg(any(target_os = "linux", target_os = "android"))]
#[cfg(unix)]
impl From<nix::Error> for Box<dyn UError> {
fn from(f: nix::Error) -> Self {
let u_error: UIoError = f.into();
Expand Down Expand Up @@ -751,7 +751,7 @@ impl Display for ClapErrorWrapper {
#[cfg(test)]
mod tests {
#[test]
#[cfg(any(target_os = "linux", target_os = "android"))]
#[cfg(unix)]
fn test_nix_error_conversion() {
use super::{FromIo, UIoError};
use nix::errno::Errno;
Expand Down