Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow predicates to own object and evaluate against borrowed types #134

Merged
merged 2 commits into from Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
103 changes: 35 additions & 68 deletions src/iter.rs
Expand Up @@ -60,6 +60,11 @@ where
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
///
/// let predicate_fn = predicate::in_iter(vec![String::from("a"), String::from("c"), String::from("e")]).sort();
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
/// ```
pub fn sort(self) -> OrdInPredicate<T> {
let mut items = self.inner.debug;
Expand All @@ -70,33 +75,16 @@ where
}
}

impl<T> Predicate<T> for InPredicate<T>
where
T: PartialEq + fmt::Debug,
{
fn eval(&self, variable: &T) -> bool {
self.inner.debug.contains(variable)
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<'a, T> Predicate<T> for InPredicate<&'a T>
impl<P, T> Predicate<P> for InPredicate<T>
where
T: PartialEq + fmt::Debug + ?Sized,
T: std::borrow::Borrow<P> + PartialEq + fmt::Debug,
P: PartialEq + fmt::Debug + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
self.inner.debug.contains(&variable)
fn eval(&self, variable: &P) -> bool {
self.inner.debug.iter().any(|x| x.borrow() == variable)
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand Down Expand Up @@ -160,6 +148,11 @@ where
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
///
/// let predicate_fn = predicate::in_iter(vec![String::from("a"), String::from("c"), String::from("e")]);
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
/// ```
pub fn in_iter<I, T>(iter: I) -> InPredicate<T>
where
Expand Down Expand Up @@ -188,33 +181,19 @@ where
inner: utils::DebugAdapter<Vec<T>>,
}

impl<T> Predicate<T> for OrdInPredicate<T>
impl<P, T> Predicate<P> for OrdInPredicate<T>
where
T: Ord + fmt::Debug,
T: std::borrow::Borrow<P> + Ord + fmt::Debug,
P: Ord + fmt::Debug + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
self.inner.debug.binary_search(variable).is_ok()
fn eval(&self, variable: &P) -> bool {
self.inner
.debug
.binary_search_by(|x| x.borrow().cmp(variable))
.is_ok()
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<'a, T> Predicate<T> for OrdInPredicate<&'a T>
where
T: Ord + fmt::Debug + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
self.inner.debug.binary_search(&variable).is_ok()
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand Down Expand Up @@ -267,33 +246,16 @@ where
inner: utils::DebugAdapter<HashSet<T>>,
}

impl<T> Predicate<T> for HashableInPredicate<T>
impl<P, T> Predicate<P> for HashableInPredicate<T>
where
T: Hash + Eq + fmt::Debug,
T: std::borrow::Borrow<P> + Hash + Eq + fmt::Debug,
P: Hash + Eq + fmt::Debug + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
fn eval(&self, variable: &P) -> bool {
self.inner.debug.contains(variable)
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<'a, T> Predicate<T> for HashableInPredicate<&'a T>
where
T: Hash + Eq + fmt::Debug + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
self.inner.debug.contains(&variable)
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand Down Expand Up @@ -351,6 +313,11 @@ where
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
///
/// let predicate_fn = predicate::in_hash(vec![String::from("a"), String::from("c"), String::from("e")]);
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
/// ```
pub fn in_hash<I, T>(iter: I) -> HashableInPredicate<T>
where
Expand Down
102 changes: 31 additions & 71 deletions src/ord.rs
@@ -1,4 +1,4 @@
// Copyright (c) 2018 The predicates-rs Project Developers.
// Copyright (c) 2018, 2022 The predicates-rs Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/license/LICENSE-2.0> or the MIT license
Expand Down Expand Up @@ -35,26 +35,24 @@ impl fmt::Display for EqOps {
///
/// This is created by the `predicate::{eq, ne}` functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EqPredicate<T>
where
T: fmt::Debug + PartialEq,
{
pub struct EqPredicate<T> {
constant: T,
op: EqOps,
}

impl<T> Predicate<T> for EqPredicate<T>
impl<P, T> Predicate<P> for EqPredicate<T>
where
T: fmt::Debug + PartialEq,
T: std::borrow::Borrow<P> + fmt::Debug,
P: fmt::Debug + PartialEq + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
fn eval(&self, variable: &P) -> bool {
match self.op {
EqOps::Equal => variable.eq(&self.constant),
EqOps::NotEqual => variable.ne(&self.constant),
EqOps::Equal => variable.eq(self.constant.borrow()),
EqOps::NotEqual => variable.ne(self.constant.borrow()),
}
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand All @@ -64,32 +62,11 @@ where
}
}

impl<'a, T> Predicate<T> for EqPredicate<&'a T>
where
T: fmt::Debug + PartialEq + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
match self.op {
EqOps::Equal => variable.eq(self.constant),
EqOps::NotEqual => variable.ne(self.constant),
}
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<T> reflection::PredicateReflection for EqPredicate<T> where T: fmt::Debug + PartialEq {}
impl<T> reflection::PredicateReflection for EqPredicate<T> where T: fmt::Debug {}

impl<T> fmt::Display for EqPredicate<T>
where
T: fmt::Debug + PartialEq,
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let palette = crate::Palette::current();
Expand Down Expand Up @@ -120,6 +97,10 @@ where
/// let predicate_fn = predicate::eq("Hello");
/// assert_eq!(true, predicate_fn.eval("Hello"));
/// assert_eq!(false, predicate_fn.eval("Goodbye"));
///
/// let predicate_fn = predicate::eq(String::from("Hello"));
/// assert_eq!(true, predicate_fn.eval("Hello"));
/// assert_eq!(false, predicate_fn.eval("Goodbye"));
/// ```
pub fn eq<T>(constant: T) -> EqPredicate<T>
where
Expand Down Expand Up @@ -178,28 +159,26 @@ impl fmt::Display for OrdOps {
///
/// This is created by the `predicate::{gt, ge, lt, le}` functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OrdPredicate<T>
where
T: fmt::Debug + PartialOrd,
{
pub struct OrdPredicate<T> {
constant: T,
op: OrdOps,
}

impl<T> Predicate<T> for OrdPredicate<T>
impl<P, T> Predicate<P> for OrdPredicate<T>
where
T: fmt::Debug + PartialOrd,
T: std::borrow::Borrow<P> + fmt::Debug,
P: fmt::Debug + PartialOrd + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
fn eval(&self, variable: &P) -> bool {
match self.op {
OrdOps::LessThan => variable.lt(&self.constant),
OrdOps::LessThanOrEqual => variable.le(&self.constant),
OrdOps::GreaterThanOrEqual => variable.ge(&self.constant),
OrdOps::GreaterThan => variable.gt(&self.constant),
OrdOps::LessThan => variable.lt(self.constant.borrow()),
OrdOps::LessThanOrEqual => variable.le(self.constant.borrow()),
OrdOps::GreaterThanOrEqual => variable.ge(self.constant.borrow()),
OrdOps::GreaterThan => variable.gt(self.constant.borrow()),
}
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand All @@ -209,34 +188,11 @@ where
}
}

impl<'a, T> Predicate<T> for OrdPredicate<&'a T>
where
T: fmt::Debug + PartialOrd + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
match self.op {
OrdOps::LessThan => variable.lt(self.constant),
OrdOps::LessThanOrEqual => variable.le(self.constant),
OrdOps::GreaterThanOrEqual => variable.ge(self.constant),
OrdOps::GreaterThan => variable.gt(self.constant),
}
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<T> reflection::PredicateReflection for OrdPredicate<T> where T: fmt::Debug + PartialOrd {}
impl<T> reflection::PredicateReflection for OrdPredicate<T> where T: fmt::Debug {}

impl<T> fmt::Display for OrdPredicate<T>
where
T: fmt::Debug + PartialOrd,
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let palette = crate::Palette::current();
Expand Down Expand Up @@ -267,6 +223,10 @@ where
/// let predicate_fn = predicate::lt("b");
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("c"));
///
/// let predicate_fn = predicate::lt(String::from("b"));
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("c"));
/// ```
pub fn lt<T>(constant: T) -> OrdPredicate<T>
where
Expand Down