Skip to content

Commit

Permalink
Merge pull request #134 from rshearman/owned-ord
Browse files Browse the repository at this point in the history
feat: Allow predicates to own object and evaluate against borrowed types
  • Loading branch information
epage committed Dec 29, 2022
2 parents ee57a38 + 7934a3a commit 4e1d03c
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 139 deletions.
103 changes: 35 additions & 68 deletions src/iter.rs
Expand Up @@ -60,6 +60,11 @@ where
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
///
/// let predicate_fn = predicate::in_iter(vec![String::from("a"), String::from("c"), String::from("e")]).sort();
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
/// ```
pub fn sort(self) -> OrdInPredicate<T> {
let mut items = self.inner.debug;
Expand All @@ -70,33 +75,16 @@ where
}
}

impl<T> Predicate<T> for InPredicate<T>
where
T: PartialEq + fmt::Debug,
{
fn eval(&self, variable: &T) -> bool {
self.inner.debug.contains(variable)
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<'a, T> Predicate<T> for InPredicate<&'a T>
impl<P, T> Predicate<P> for InPredicate<T>
where
T: PartialEq + fmt::Debug + ?Sized,
T: std::borrow::Borrow<P> + PartialEq + fmt::Debug,
P: PartialEq + fmt::Debug + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
self.inner.debug.contains(&variable)
fn eval(&self, variable: &P) -> bool {
self.inner.debug.iter().any(|x| x.borrow() == variable)
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand Down Expand Up @@ -160,6 +148,11 @@ where
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
///
/// let predicate_fn = predicate::in_iter(vec![String::from("a"), String::from("c"), String::from("e")]);
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
/// ```
pub fn in_iter<I, T>(iter: I) -> InPredicate<T>
where
Expand Down Expand Up @@ -188,33 +181,19 @@ where
inner: utils::DebugAdapter<Vec<T>>,
}

impl<T> Predicate<T> for OrdInPredicate<T>
impl<P, T> Predicate<P> for OrdInPredicate<T>
where
T: Ord + fmt::Debug,
T: std::borrow::Borrow<P> + Ord + fmt::Debug,
P: Ord + fmt::Debug + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
self.inner.debug.binary_search(variable).is_ok()
fn eval(&self, variable: &P) -> bool {
self.inner
.debug
.binary_search_by(|x| x.borrow().cmp(variable))
.is_ok()
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<'a, T> Predicate<T> for OrdInPredicate<&'a T>
where
T: Ord + fmt::Debug + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
self.inner.debug.binary_search(&variable).is_ok()
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand Down Expand Up @@ -267,33 +246,16 @@ where
inner: utils::DebugAdapter<HashSet<T>>,
}

impl<T> Predicate<T> for HashableInPredicate<T>
impl<P, T> Predicate<P> for HashableInPredicate<T>
where
T: Hash + Eq + fmt::Debug,
T: std::borrow::Borrow<P> + Hash + Eq + fmt::Debug,
P: Hash + Eq + fmt::Debug + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
fn eval(&self, variable: &P) -> bool {
self.inner.debug.contains(variable)
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<'a, T> Predicate<T> for HashableInPredicate<&'a T>
where
T: Hash + Eq + fmt::Debug + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
self.inner.debug.contains(&variable)
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand Down Expand Up @@ -351,6 +313,11 @@ where
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
///
/// let predicate_fn = predicate::in_hash(vec![String::from("a"), String::from("c"), String::from("e")]);
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("b"));
/// assert_eq!(true, predicate_fn.eval("c"));
/// ```
pub fn in_hash<I, T>(iter: I) -> HashableInPredicate<T>
where
Expand Down
102 changes: 31 additions & 71 deletions src/ord.rs
@@ -1,4 +1,4 @@
// Copyright (c) 2018 The predicates-rs Project Developers.
// Copyright (c) 2018, 2022 The predicates-rs Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/license/LICENSE-2.0> or the MIT license
Expand Down Expand Up @@ -35,26 +35,24 @@ impl fmt::Display for EqOps {
///
/// This is created by the `predicate::{eq, ne}` functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EqPredicate<T>
where
T: fmt::Debug + PartialEq,
{
pub struct EqPredicate<T> {
constant: T,
op: EqOps,
}

impl<T> Predicate<T> for EqPredicate<T>
impl<P, T> Predicate<P> for EqPredicate<T>
where
T: fmt::Debug + PartialEq,
T: std::borrow::Borrow<P> + fmt::Debug,
P: fmt::Debug + PartialEq + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
fn eval(&self, variable: &P) -> bool {
match self.op {
EqOps::Equal => variable.eq(&self.constant),
EqOps::NotEqual => variable.ne(&self.constant),
EqOps::Equal => variable.eq(self.constant.borrow()),
EqOps::NotEqual => variable.ne(self.constant.borrow()),
}
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand All @@ -64,32 +62,11 @@ where
}
}

impl<'a, T> Predicate<T> for EqPredicate<&'a T>
where
T: fmt::Debug + PartialEq + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
match self.op {
EqOps::Equal => variable.eq(self.constant),
EqOps::NotEqual => variable.ne(self.constant),
}
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<T> reflection::PredicateReflection for EqPredicate<T> where T: fmt::Debug + PartialEq {}
impl<T> reflection::PredicateReflection for EqPredicate<T> where T: fmt::Debug {}

impl<T> fmt::Display for EqPredicate<T>
where
T: fmt::Debug + PartialEq,
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let palette = crate::Palette::current();
Expand Down Expand Up @@ -120,6 +97,10 @@ where
/// let predicate_fn = predicate::eq("Hello");
/// assert_eq!(true, predicate_fn.eval("Hello"));
/// assert_eq!(false, predicate_fn.eval("Goodbye"));
///
/// let predicate_fn = predicate::eq(String::from("Hello"));
/// assert_eq!(true, predicate_fn.eval("Hello"));
/// assert_eq!(false, predicate_fn.eval("Goodbye"));
/// ```
pub fn eq<T>(constant: T) -> EqPredicate<T>
where
Expand Down Expand Up @@ -178,28 +159,26 @@ impl fmt::Display for OrdOps {
///
/// This is created by the `predicate::{gt, ge, lt, le}` functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OrdPredicate<T>
where
T: fmt::Debug + PartialOrd,
{
pub struct OrdPredicate<T> {
constant: T,
op: OrdOps,
}

impl<T> Predicate<T> for OrdPredicate<T>
impl<P, T> Predicate<P> for OrdPredicate<T>
where
T: fmt::Debug + PartialOrd,
T: std::borrow::Borrow<P> + fmt::Debug,
P: fmt::Debug + PartialOrd + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
fn eval(&self, variable: &P) -> bool {
match self.op {
OrdOps::LessThan => variable.lt(&self.constant),
OrdOps::LessThanOrEqual => variable.le(&self.constant),
OrdOps::GreaterThanOrEqual => variable.ge(&self.constant),
OrdOps::GreaterThan => variable.gt(&self.constant),
OrdOps::LessThan => variable.lt(self.constant.borrow()),
OrdOps::LessThanOrEqual => variable.le(self.constant.borrow()),
OrdOps::GreaterThanOrEqual => variable.ge(self.constant.borrow()),
OrdOps::GreaterThan => variable.gt(self.constant.borrow()),
}
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand All @@ -209,34 +188,11 @@ where
}
}

impl<'a, T> Predicate<T> for OrdPredicate<&'a T>
where
T: fmt::Debug + PartialOrd + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
match self.op {
OrdOps::LessThan => variable.lt(self.constant),
OrdOps::LessThanOrEqual => variable.le(self.constant),
OrdOps::GreaterThanOrEqual => variable.ge(self.constant),
OrdOps::GreaterThan => variable.gt(self.constant),
}
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<T> reflection::PredicateReflection for OrdPredicate<T> where T: fmt::Debug + PartialOrd {}
impl<T> reflection::PredicateReflection for OrdPredicate<T> where T: fmt::Debug {}

impl<T> fmt::Display for OrdPredicate<T>
where
T: fmt::Debug + PartialOrd,
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let palette = crate::Palette::current();
Expand Down Expand Up @@ -267,6 +223,10 @@ where
/// let predicate_fn = predicate::lt("b");
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("c"));
///
/// let predicate_fn = predicate::lt(String::from("b"));
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("c"));
/// ```
pub fn lt<T>(constant: T) -> OrdPredicate<T>
where
Expand Down

0 comments on commit 4e1d03c

Please sign in to comment.