From f7323a41ef262c46f0b4960995fc141a66c2c8a7 Mon Sep 17 00:00:00 2001 From: Thom Chiovoloni Date: Wed, 9 Jun 2021 17:57:44 -0700 Subject: [PATCH] Add named functions to perform set operations --- src/lib.rs | 283 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 282 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 9e5ac8dd..3504ec95 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -164,6 +164,16 @@ //! - `toggle`: the specified flags will be inserted if not present, and removed //! if they are. //! - `set`: inserts or removes the specified flags depending on the passed value +//! - `intersection`: returns a new set of flags, containing only the flags present +//! in both `self` and `other` (the argument to the function). +//! - `union`: returns a new set of flags, containing any flags present in +//! either `self` or `other` (the argument to the function). +//! - `difference`: returns a new set of flags, containing all flags present in +//! `self` without any of the flags present in `other` (the +//! argument to the function). +//! - `symmetric_difference`: returns a new set of flags, containing all flags +//! present in either `self` or `other` (the argument +//! to the function), but not both. //! //! ## Default //! @@ -653,7 +663,6 @@ macro_rules! __impl_bitflags { pub const fn contains(&self, other: $BitFlags) -> bool { (self.bits & other.bits) == other.bits } - /// Inserts the specified flags in-place. #[inline] pub fn insert(&mut self, other: $BitFlags) { @@ -681,6 +690,96 @@ macro_rules! __impl_bitflags { self.remove(other); } } + + /// Returns the intersection between the flags in `self` and + /// `other`. + /// + /// Specifically, the returned set contains only the flags which are + /// present in *both* `self` *and* `other`. + /// + /// This is equivalent to using the `&` operator (e.g. + /// [`ops::BitAnd`]), as in `flags & other`. + /// + /// [`ops::BitAnd`]: https://doc.rust-lang.org/std/ops/trait.BitAnd.html + #[inline] + #[must_use] + pub const fn intersection(self, other: $BitFlags) -> Self { + Self { bits: self.bits & other.bits } + } + + /// Returns the union of between the flags in `self` and `other`. + /// + /// Specifically, the returned set contains all flags which are + /// present in *either* `self` *or* `other`, including any which are + /// present in both (see [`Self::symmetric_difference`] if that + /// is undesirable). + /// + /// This is equivalent to using the `|` operator (e.g. + /// [`ops::BitOr`]), as in `flags | other`. + /// + /// [`ops::BitOr`]: https://doc.rust-lang.org/std/ops/trait.BitOr.html + #[inline] + #[must_use] + pub const fn union(self, other: $BitFlags) -> Self { + Self { bits: self.bits | other.bits } + } + + /// Returns the difference between the flags in `self` and `other`. + /// + /// Specifically, the returned set contains all flags present in + /// `self`, except for the ones present in `other`. + /// + /// It is also conceptually equivalent to the "bit-clear" operation: + /// `flags & !other` (and this syntax is also supported). + /// + /// This is equivalent to using the `-` operator (e.g. + /// [`ops::Sub`]), as in `flags - other`. + /// + /// [`ops::Sub`]: https://doc.rust-lang.org/std/ops/trait.Sub.html + #[inline] + #[must_use] + pub const fn difference(self, other: $BitFlags) -> Self { + Self { bits: self.bits & !other.bits } + } + + /// Returns the [symmetric difference][sym-diff] between the flags + /// in `self` and `other`. + /// + /// Specifically, the returned set contains the flags present which + /// are present in `self` or `other`, but that are not present in + /// both. Equivalently, it contains the flags present in *exactly + /// one* of the sets `self` and `other`. + /// + /// This is equivalent to using the `^` operator (e.g. + /// [`ops::BitXor`]), as in `flags ^ other`. + /// + /// [sym-diff]: https://en.wikipedia.org/wiki/Symmetric_difference + /// [`ops::BitXor`]: https://doc.rust-lang.org/std/ops/trait.BitXor.html + #[inline] + #[must_use] + pub const fn symmetric_difference(self, other: $BitFlags) -> Self { + Self { bits: self.bits ^ other.bits } + } + + /// Returns the complement of this set of flags. + /// + /// Specifically, the returned set contains all the flags which are + /// not set in `self`, but which are allowed for this type. + /// + /// Alternatively, it can be thought of as the set difference + /// between [`Self::all()`] and `self` (e.g. `Self::all() - self`) + /// + /// This is equivalent to using the `!` operator (e.g. + /// [`ops::Not`]), as in `!flags`. + /// + /// [`Self::all()`]: Self::all + /// [`ops::Not`]: https://doc.rust-lang.org/std/ops/trait.Not.html + #[inline] + #[must_use] + pub const fn complement(self) -> Self { + Self::from_bits_truncate(!self.bits) + } + } impl $crate::_core::ops::BitOr for $BitFlags { @@ -1156,6 +1255,188 @@ mod tests { assert_eq!(e3, Flags::A | Flags::B | extra); } + #[test] + fn test_set_ops_basic() { + let ab = Flags::A.union(Flags::B); + let ac = Flags::A.union(Flags::C); + let bc = Flags::B.union(Flags::C); + assert_eq!(ab.bits, 0b011); + assert_eq!(bc.bits, 0b110); + assert_eq!(ac.bits, 0b101); + + assert_eq!(ab, Flags::B.union(Flags::A)); + assert_eq!(ac, Flags::C.union(Flags::A)); + assert_eq!(bc, Flags::C.union(Flags::B)); + + assert_eq!(ac, Flags::A | Flags::C); + assert_eq!(bc, Flags::B | Flags::C); + assert_eq!(ab.union(bc), Flags::ABC); + + assert_eq!(ac, Flags::A | Flags::C); + assert_eq!(bc, Flags::B | Flags::C); + + assert_eq!(ac.union(bc), ac | bc); + assert_eq!(ac.union(bc), Flags::ABC); + assert_eq!(bc.union(ac), Flags::ABC); + + assert_eq!(ac.intersection(bc), ac & bc); + assert_eq!(ac.intersection(bc), Flags::C); + assert_eq!(bc.intersection(ac), Flags::C); + + assert_eq!(ac.difference(bc), ac - bc); + assert_eq!(bc.difference(ac), bc - ac); + assert_eq!(ac.difference(bc), Flags::A); + assert_eq!(bc.difference(ac), Flags::B); + + assert_eq!(bc.complement(), !bc); + assert_eq!(bc.complement(), Flags::A); + assert_eq!(ac.symmetric_difference(bc), Flags::A.union(Flags::B)); + assert_eq!(bc.symmetric_difference(ac), Flags::A.union(Flags::B)); + } + + #[test] + fn test_set_ops_const() { + // These just test that these compile and don't cause use-site panics + // (would be possible if we had some sort of UB), which is enoug + const INTERSECT: Flags = Flags::all().intersection(Flags::C); + const UNION: Flags = Flags::A.union(Flags::C); + const DIFFERENCE: Flags = Flags::all().difference(Flags::A); + const COMPLEMENT: Flags = Flags::C.complement(); + const SYM_DIFFERENCE: Flags = UNION.symmetric_difference(DIFFERENCE); + assert_eq!(INTERSECT, Flags::C); + assert_eq!(UNION, Flags::A | Flags::C); + assert_eq!(DIFFERENCE, Flags::all() - Flags::A); + assert_eq!(COMPLEMENT, !Flags::C); + assert_eq!(SYM_DIFFERENCE, (Flags::A | Flags::C) ^ (Flags::all() - Flags::A)); + } + + #[test] + fn test_set_ops_unchecked() { + let extra = unsafe { Flags::from_bits_unchecked(0b1000) }; + let e1 = Flags::A.union(Flags::C).union(extra); + let e2 = Flags::B.union(Flags::C); + assert_eq!(e1.bits, 0b1101); + assert_eq!(e1.union(e2), (Flags::ABC | extra)); + assert_eq!(e1.intersection(e2), Flags::C); + assert_eq!(e1.difference(e2), Flags::A | extra); + assert_eq!(e2.difference(e1), Flags::B); + assert_eq!(e2.complement(), Flags::A); + assert_eq!(e1.complement(), Flags::B); + assert_eq!(e1.symmetric_difference(e2), Flags::A | Flags::B | extra); // toggle + } + + #[test] + fn test_set_ops_exhaustive() { + // Define a flag that contains gaps to help exercise edge-cases, + // especially around "unknown" flags (e.g. ones outside of `all()` + // `from_bits_unchecked`). + // - when lhs and rhs both have different sets of unknown flags. + // - unknown flags at both ends, and in the middle + // - cases with "gaps". + bitflags! { + struct Test: u16 { + // Intentionally no `A` + const B = 0b000000010; + // Intentionally no `C` + const D = 0b000001000; + const E = 0b000010000; + const F = 0b000100000; + const G = 0b001000000; + // Intentionally no `H` + const I = 0b100000000; + } + } + let iter_test_flags = + || (0..=0b111_1111_1111).map(|bits| unsafe { Test::from_bits_unchecked(bits) }); + + for a in iter_test_flags() { + assert_eq!( + a.complement(), + Test::from_bits_truncate(!a.bits), + "wrong result: !({:?})", + a, + ); + assert_eq!(a.complement(), !a, "named != op: !({:?})", a); + for b in iter_test_flags() { + // Check that the named operations produce the expected bitwise + // values. + assert_eq!( + a.union(b).bits, + a.bits | b.bits, + "wrong result: `{:?}` | `{:?}`", + a, + b, + ); + assert_eq!( + a.intersection(b).bits, + a.bits & b.bits, + "wrong result: `{:?}` & `{:?}`", + a, + b, + ); + assert_eq!( + a.symmetric_difference(b).bits, + a.bits ^ b.bits, + "wrong result: `{:?}` ^ `{:?}`", + a, + b, + ); + assert_eq!( + a.difference(b).bits, + a.bits & !b.bits, + "wrong result: `{:?}` - `{:?}`", + a, + b, + ); + // Note: Difference is checked as both `a - b` and `b - a` + assert_eq!( + b.difference(a).bits, + b.bits & !a.bits, + "wrong result: `{:?}` - `{:?}`", + b, + a, + ); + // Check that the named set operations are equivalent to the + // bitwise equivalents + assert_eq!(a.union(b), a | b, "named != op: `{:?}` | `{:?}`", a, b,); + assert_eq!( + a.intersection(b), + a & b, + "named != op: `{:?}` & `{:?}`", + a, + b, + ); + assert_eq!( + a.symmetric_difference(b), + a ^ b, + "named != op: `{:?}` ^ `{:?}`", + a, + b, + ); + assert_eq!(a.difference(b), a - b, "named != op: `{:?}` - `{:?}`", a, b,); + // Note: Difference is checked as both `a - b` and `b - a` + assert_eq!(b.difference(a), b - a, "named != op: `{:?}` - `{:?}`", b, a,); + // Verify that the operations which should be symmetric are + // actually symmetric. + assert_eq!(a.union(b), b.union(a), "asymmetry: `{:?}` | `{:?}`", a, b,); + assert_eq!( + a.intersection(b), + b.intersection(a), + "asymmetry: `{:?}` & `{:?}`", + a, + b, + ); + assert_eq!( + a.symmetric_difference(b), + b.symmetric_difference(a), + "asymmetry: `{:?}` ^ `{:?}`", + a, + b, + ); + } + } + } + #[test] fn test_set() { let mut e1 = Flags::A | Flags::C;