Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_mut_or_insert #151

Merged
merged 3 commits into from Sep 4, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
79 changes: 73 additions & 6 deletions src/lib.rs
Expand Up @@ -192,7 +192,7 @@ pub struct LruCache<K, V, S = DefaultHasher> {
map: HashMap<KeyRef<K>, Box<LruEntry<K, V>>, S>,
cap: NonZeroUsize,

// head and tail are sigil nodes to faciliate inserting entries
// head and tail are sigil nodes to facilitate inserting entries
head: *mut LruEntry<K, V>,
tail: *mut LruEntry<K, V>,
}
Expand Down Expand Up @@ -366,7 +366,7 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
}

// Used internally to swap out a node if the cache is full or to create a new node if space
// is available. Shared between `put`, `push`, and `get_or_insert`.
// is available. Shared between `put`, `push`, `get_or_insert`, and `get_or_insert_mut`.
#[allow(clippy::type_complexity)]
fn replace_or_create_node(&mut self, k: K, v: V) -> (Option<(K, V)>, Box<LruEntry<K, V>>) {
if self.len() == self.cap().get() {
Expand Down Expand Up @@ -510,6 +510,52 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
}
}

/// Returns a mutable reference to the value of the key in the cache if it is
/// present in the cache and moves the key to the head of the LRU list.
/// If the key does not exist the provided `FnOnce` is used to populate
/// the list and a mutable reference is returned.
///
/// # Example
///
/// ```
/// use lru::LruCache;
/// use std::num::NonZeroUsize;
/// let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());
///
/// cache.put(1, "a");
/// cache.put(2, "b");
///
/// let v = cache.get_or_insert_mut(2, ||"c");
/// assert_eq!(v, &"b");
/// *v = "d";
/// assert_eq!(cache.get_or_insert_mut(2, ||"e"), &mut "d");
/// assert_eq!(cache.get_or_insert_mut(3, ||"f"), &mut "f");
/// assert_eq!(cache.get_or_insert_mut(3, ||"e"), &mut "f");
/// ```
pub fn get_or_insert_mut<'a, F>(&mut self, k: K, f: F) -> &'a mut V
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pub fn get_or_insert_mut<'a, F>(&mut self, k: K, f: F) -> &'a mut V
pub fn get_or_insert_mut<'a, F>(&'a mut self, k: K, f: F) -> &'a mut V

#153

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matklad fixed.

where
F: FnOnce() -> V,
{
if let Some(node) = self.map.get_mut(&KeyRef { k: &k }) {
let node_ptr: *mut LruEntry<K, V> = &mut **node;

self.detach(node_ptr);
self.attach(node_ptr);

unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V }
} else {
let v = f();
let (_, mut node) = self.replace_or_create_node(k, v);

let node_ptr: *mut LruEntry<K, V> = &mut *node;
self.attach(node_ptr);

let keyref = unsafe { (*node_ptr).key.as_ptr() };
self.map.insert(KeyRef { k: keyref }, node);
unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V }
}
}

/// Returns a reference to the value corresponding to the key in the cache or `None` if it is
/// not present in the cache. Unlike `get`, `peek` does not update the LRU list so the key's
/// position will be unchanged.
Expand Down Expand Up @@ -1252,10 +1298,31 @@ mod tests {
assert_eq!(cache.cap().get(), 2);
assert_eq!(cache.len(), 2);
assert!(!cache.is_empty());
assert_eq!(cache.get_or_insert(&"apple", || "orange"), &"red");
assert_eq!(cache.get_or_insert(&"banana", || "orange"), &"yellow");
assert_eq!(cache.get_or_insert(&"lemon", || "orange"), &"orange");
assert_eq!(cache.get_or_insert(&"lemon", || "red"), &"orange");
assert_eq!(cache.get_or_insert("apple", || "orange"), &"red");
assert_eq!(cache.get_or_insert("banana", || "orange"), &"yellow");
assert_eq!(cache.get_or_insert("lemon", || "orange"), &"orange");
assert_eq!(cache.get_or_insert("lemon", || "red"), &"orange");
}

#[test]
fn test_put_and_get_or_insert_mut() {
let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());
assert!(cache.is_empty());

assert_eq!(cache.put("apple", "red"), None);
assert_eq!(cache.put("banana", "yellow"), None);

assert_eq!(cache.cap().get(), 2);
assert_eq!(cache.len(), 2);

let v = cache.get_or_insert_mut("apple", || "orange");
assert_eq!(v, &"red");
*v = "blue";

assert_eq!(cache.get_or_insert_mut("apple", || "orange"), &"blue");
assert_eq!(cache.get_or_insert_mut("banana", || "orange"), &"yellow");
assert_eq!(cache.get_or_insert_mut("lemon", || "orange"), &"orange");
assert_eq!(cache.get_or_insert_mut("lemon", || "red"), &"orange");
}

#[test]
Expand Down