Skip to content

Commit 0e030e5

Browse files
first attempt at an arbitrary-cost-sized cache (weight each entry by a caller-supplied cost instead of counting entries)
1 parent cf063f6 commit 0e030e5

File tree

1 file changed

+100
-37
lines changed

1 file changed

+100
-37
lines changed

src/lib.rs

Lines changed: 100 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,17 @@ impl<T> Borrow<[T]> for KeyRef<alloc::vec::Vec<T>> {
158158
struct LruEntry<K, V> {
159159
key: mem::MaybeUninit<K>,
160160
val: mem::MaybeUninit<V>,
161+
cost: usize,
161162
prev: *mut LruEntry<K, V>,
162163
next: *mut LruEntry<K, V>,
163164
}
164165

165166
impl<K, V> LruEntry<K, V> {
166-
fn new(key: K, val: V) -> Self {
167+
fn new(key: K, val: V, cost: usize) -> Self {
167168
LruEntry {
168169
key: mem::MaybeUninit::new(key),
169170
val: mem::MaybeUninit::new(val),
171+
cost,
170172
prev: ptr::null_mut(),
171173
next: ptr::null_mut(),
172174
}
@@ -176,6 +178,7 @@ impl<K, V> LruEntry<K, V> {
176178
LruEntry {
177179
key: mem::MaybeUninit::uninit(),
178180
val: mem::MaybeUninit::uninit(),
181+
cost: 0,
179182
prev: ptr::null_mut(),
180183
next: ptr::null_mut(),
181184
}
@@ -190,7 +193,8 @@ pub type DefaultHasher = std::collections::hash_map::RandomState;
190193
/// An LRU Cache
191194
pub struct LruCache<K, V, S = DefaultHasher> {
192195
map: HashMap<KeyRef<K>, Box<LruEntry<K, V>>, S>,
193-
cap: NonZeroUsize,
196+
cost_cap: NonZeroUsize,
197+
cost: usize,
194198

195199
// head and tail are sigil nodes to facilitate inserting entries
196200
head: *mut LruEntry<K, V>,
@@ -272,7 +276,8 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
272276
// declare it as such since we only mutate it inside the unsafe block.
273277
let cache = LruCache {
274278
map,
275-
cap,
279+
cost_cap: cap,
280+
cost: 0,
276281
head: Box::into_raw(Box::new(LruEntry::new_sigil())),
277282
tail: Box::into_raw(Box::new(LruEntry::new_sigil())),
278283
};
@@ -303,7 +308,11 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
303308
/// assert_eq!(cache.get(&2), Some(&"beta"));
304309
/// ```
305310
pub fn put(&mut self, k: K, v: V) -> Option<V> {
306-
self.capturing_put(k, v, false).map(|(_, v)| v)
311+
self.put_with_cost(k, v, 1)
312+
}
313+
314+
pub fn put_with_cost(&mut self, k: K, v: V, cost: usize) -> Option<V> {
315+
self.capturing_put(k, v, false, cost).map(|(_, v)| v)
307316
}
308317

309318
/// Pushes a key-value pair into the cache. If an entry with key `k` already exists in
@@ -331,64 +340,101 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
331340
/// assert_eq!(cache.get(&3), Some(&"alpha"));
332341
/// ```
333342
pub fn push(&mut self, k: K, v: V) -> Option<(K, V)> {
334-
self.capturing_put(k, v, true)
343+
self.push_with_cost(k, v, 1)
344+
}
345+
346+
pub fn push_with_cost(&mut self, k: K, v: V, cost: usize) -> Option<(K, V)> {
347+
self.capturing_put(k, v, true, cost)
335348
}
336349

337350
// Used internally by `put` and `push` to add a new entry to the lru.
338351
// Takes ownership of and returns entries replaced due to the cache's capacity
339352
// when `capture` is true.
340-
fn capturing_put(&mut self, k: K, mut v: V, capture: bool) -> Option<(K, V)> {
353+
fn capturing_put(&mut self, k: K, mut v: V, capture: bool, cost: usize) -> Option<(K, V)> {
341354
let node_ref = self.map.get_mut(&KeyRef { k: &k });
342355

343356
match node_ref {
344357
Some(node_ref) => {
358+
let old_cost = node_ref.cost;
359+
node_ref.cost = cost;
360+
345361
let node_ptr: *mut LruEntry<K, V> = &mut **node_ref;
346362

347363
// if the key is already in the cache just update its value and move it to the
348364
// front of the list
349365
unsafe { mem::swap(&mut v, &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V) }
350366
self.detach(node_ptr);
351367
self.attach(node_ptr);
368+
369+
self.cost -= old_cost;
370+
self.cost += cost;
371+
372+
self.shrink_within_cost();
373+
352374
Some((k, v))
353375
}
354376
None => {
355-
let (replaced, mut node) = self.replace_or_create_node(k, v);
377+
let (replaced, mut node) = self.replace_or_create_node(k, v, cost);
356378

357379
let node_ptr: *mut LruEntry<K, V> = &mut *node;
358380
self.attach(node_ptr);
359381

360382
let keyref = unsafe { (*node_ptr).key.as_ptr() };
361383
self.map.insert(KeyRef { k: keyref }, node);
362384

385+
self.shrink_within_cost();
386+
363387
replaced.filter(|_| capture)
364388
}
365389
}
366390
}
367391

392+
fn shrink_within_cost(&mut self) {
393+
let mut did_shrink = false;
394+
while self.cost() > self.cost_cap.get() {
395+
self.pop_lru();
396+
did_shrink = true;
397+
}
398+
if did_shrink {
399+
self.map.shrink_to_fit();
400+
}
401+
}
402+
368403
// Used internally to swap out a node if the cache is full or to create a new node if space
369404
// is available. Shared between `put`, `push`, `get_or_insert`, and `get_or_insert_mut`.
370405
#[allow(clippy::type_complexity)]
371-
fn replace_or_create_node(&mut self, k: K, v: V) -> (Option<(K, V)>, Box<LruEntry<K, V>>) {
372-
if self.len() == self.cap().get() {
406+
fn replace_or_create_node(
407+
&mut self,
408+
k: K,
409+
v: V,
410+
cost: usize,
411+
) -> (Option<(K, V)>, Box<LruEntry<K, V>>) {
412+
if self.cost + cost > self.cost_cap.get() && !self.is_empty() {
373413
// if the cache is full, remove the last entry so we can use it for the new key
374414
let old_key = KeyRef {
375415
k: unsafe { &(*(*(*self.tail).prev).key.as_ptr()) },
376416
};
377417
let mut old_node = self.map.remove(&old_key).unwrap();
418+
let old_cost = old_node.cost;
378419

379420
// read out the node's old key and value and then replace it
380421
let replaced = unsafe { (old_node.key.assume_init(), old_node.val.assume_init()) };
381422

382423
old_node.key = mem::MaybeUninit::new(k);
383424
old_node.val = mem::MaybeUninit::new(v);
425+
old_node.cost = cost;
426+
427+
self.cost -= old_cost;
428+
self.cost += cost;
384429

385430
let node_ptr: *mut LruEntry<K, V> = &mut *old_node;
386431
self.detach(node_ptr);
387432

388433
(Some(replaced), old_node)
389434
} else {
390435
// if the cache is not full allocate a new LruEntry
391-
(None, Box::new(LruEntry::new(k, v)))
436+
self.cost += cost;
437+
(None, Box::new(LruEntry::new(k, v, cost)))
392438
}
393439
}
394440

@@ -489,6 +535,13 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
489535
pub fn get_or_insert<'a, F>(&'a mut self, k: K, f: F) -> &'a V
490536
where
491537
F: FnOnce() -> V,
538+
{
539+
self.get_or_insert_cost(k, || (f(), 1))
540+
}
541+
542+
pub fn get_or_insert_cost<'a, F>(&'a mut self, k: K, f: F) -> &'a V
543+
where
544+
F: FnOnce() -> (V, usize),
492545
{
493546
if let Some(node) = self.map.get_mut(&KeyRef { k: &k }) {
494547
let node_ptr: *mut LruEntry<K, V> = &mut **node;
@@ -498,14 +551,17 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
498551

499552
unsafe { &(*(*node_ptr).val.as_ptr()) as &V }
500553
} else {
501-
let v = f();
502-
let (_, mut node) = self.replace_or_create_node(k, v);
554+
let (v, cost) = f();
555+
let (_, mut node) = self.replace_or_create_node(k, v, cost);
503556

504557
let node_ptr: *mut LruEntry<K, V> = &mut *node;
505558
self.attach(node_ptr);
506559

507560
let keyref = unsafe { (*node_ptr).key.as_ptr() };
508561
self.map.insert(KeyRef { k: keyref }, node);
562+
563+
self.shrink_within_cost();
564+
509565
unsafe { &(*(*node_ptr).val.as_ptr()) as &V }
510566
}
511567
}
@@ -535,6 +591,13 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
535591
pub fn get_or_insert_mut<'a, F>(&'a mut self, k: K, f: F) -> &'a mut V
536592
where
537593
F: FnOnce() -> V,
594+
{
595+
self.get_or_insert_cost_mut(k, || (f(), 1))
596+
}
597+
598+
pub fn get_or_insert_cost_mut<'a, F>(&'a mut self, k: K, f: F) -> &'a mut V
599+
where
600+
F: FnOnce() -> (V, usize),
538601
{
539602
if let Some(node) = self.map.get_mut(&KeyRef { k: &k }) {
540603
let node_ptr: *mut LruEntry<K, V> = &mut **node;
@@ -544,14 +607,17 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
544607

545608
unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V }
546609
} else {
547-
let v = f();
548-
let (_, mut node) = self.replace_or_create_node(k, v);
610+
let (v, cost) = f();
611+
let (_, mut node) = self.replace_or_create_node(k, v, cost);
549612

550613
let node_ptr: *mut LruEntry<K, V> = &mut *node;
551614
self.attach(node_ptr);
552615

553616
let keyref = unsafe { (*node_ptr).key.as_ptr() };
554617
self.map.insert(KeyRef { k: keyref }, node);
618+
619+
self.shrink_within_cost();
620+
555621
unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V }
556622
}
557623
}
@@ -693,6 +759,7 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
693759
match self.map.remove(k) {
694760
None => None,
695761
Some(mut old_node) => {
762+
self.cost -= old_node.cost;
696763
unsafe {
697764
ptr::drop_in_place(old_node.key.as_mut_ptr());
698765
}
@@ -730,6 +797,7 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
730797
match self.map.remove(k) {
731798
None => None,
732799
Some(mut old_node) => {
800+
self.cost -= old_node.cost;
733801
let node_ptr: *mut LruEntry<K, V> = &mut *old_node;
734802
self.detach(node_ptr);
735803
unsafe { Some((old_node.key.assume_init(), old_node.val.assume_init())) }
@@ -759,6 +827,7 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
759827
/// ```
760828
pub fn pop_lru(&mut self) -> Option<(K, V)> {
761829
let node = self.remove_last()?;
830+
self.cost -= node.cost;
762831
// N.B.: Can't destructure directly because of https://github.com/rust-lang/rust/issues/28536
763832
let node = *node;
764833
let LruEntry { key, val, .. } = node;
@@ -858,6 +927,10 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
858927
self.map.len()
859928
}
860929

930+
pub fn cost(&self) -> usize {
931+
self.cost
932+
}
933+
861934
/// Returns a bool indicating whether the cache is empty or not.
862935
///
863936
/// # Example
@@ -883,10 +956,10 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
883956
/// use lru::LruCache;
884957
/// use std::num::NonZeroUsize;
885958
/// let mut cache: LruCache<isize, &str> = LruCache::new(NonZeroUsize::new(2).unwrap());
886-
/// assert_eq!(cache.cap().get(), 2);
959+
/// assert_eq!(cache.cost_cap().get(), 2);
887960
/// ```
888-
pub fn cap(&self) -> NonZeroUsize {
889-
self.cap
961+
pub fn cost_cap(&self) -> NonZeroUsize {
962+
self.cost_cap
890963
}
891964

892965
/// Resizes the cache. If the new capacity is smaller than the size of the current
@@ -913,16 +986,12 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
913986
/// ```
914987
pub fn resize(&mut self, cap: NonZeroUsize) {
915988
// return early if capacity doesn't change
916-
if cap == self.cap {
989+
if cap == self.cost_cap {
917990
return;
918991
}
919992

920-
while self.map.len() > cap.get() {
921-
self.pop_lru();
922-
}
923-
self.map.shrink_to_fit();
924-
925-
self.cap = cap;
993+
self.cost_cap = cap;
994+
self.shrink_within_cost()
926995
}
927996

928997
/// Clears the contents of the cache.
@@ -1097,7 +1166,7 @@ impl<K: Hash + Eq, V> fmt::Debug for LruCache<K, V> {
10971166
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
10981167
f.debug_struct("LruCache")
10991168
.field("len", &self.len())
1100-
.field("cap", &self.cap())
1169+
.field("cost_cap", &self.cost_cap())
11011170
.finish()
11021171
}
11031172
}
@@ -1361,7 +1430,7 @@ mod tests {
13611430
assert_eq!(cache.put("apple", "red"), None);
13621431
assert_eq!(cache.put("banana", "yellow"), None);
13631432

1364-
assert_eq!(cache.cap().get(), 2);
1433+
assert_eq!(cache.cost_cap().get(), 2);
13651434
assert_eq!(cache.len(), 2);
13661435
assert!(!cache.is_empty());
13671436
assert_opt_eq(cache.get(&"apple"), "red");
@@ -1376,7 +1445,7 @@ mod tests {
13761445
assert_eq!(cache.put("apple", "red"), None);
13771446
assert_eq!(cache.put("banana", "yellow"), None);
13781447

1379-
assert_eq!(cache.cap().get(), 2);
1448+
assert_eq!(cache.cost_cap().get(), 2);
13801449
assert_eq!(cache.len(), 2);
13811450
assert!(!cache.is_empty());
13821451
assert_eq!(cache.get_or_insert("apple", || "orange"), &"red");
@@ -1393,7 +1462,7 @@ mod tests {
13931462
assert_eq!(cache.put("apple", "red"), None);
13941463
assert_eq!(cache.put("banana", "yellow"), None);
13951464

1396-
assert_eq!(cache.cap().get(), 2);
1465+
assert_eq!(cache.cost_cap().get(), 2);
13971466
assert_eq!(cache.len(), 2);
13981467

13991468
let v = cache.get_or_insert_mut("apple", || "orange");
@@ -1413,7 +1482,7 @@ mod tests {
14131482
cache.put("apple", "red");
14141483
cache.put("banana", "yellow");
14151484

1416-
assert_eq!(cache.cap().get(), 2);
1485+
assert_eq!(cache.cost_cap().get(), 2);
14171486
assert_eq!(cache.len(), 2);
14181487
assert_opt_eq_mut(cache.get_mut(&"apple"), "red");
14191488
assert_opt_eq_mut(cache.get_mut(&"banana"), "yellow");
@@ -1431,7 +1500,7 @@ mod tests {
14311500
*v = 4;
14321501
}
14331502

1434-
assert_eq!(cache.cap().get(), 2);
1503+
assert_eq!(cache.cost_cap().get(), 2);
14351504
assert_eq!(cache.len(), 2);
14361505
assert_opt_eq_mut(cache.get_mut(&"apple"), 4);
14371506
assert_opt_eq_mut(cache.get_mut(&"banana"), 3);
@@ -2006,15 +2075,9 @@ mod tests {
20062075
fn test_no_memory_leaks_with_pop() {
20072076
static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);
20082077

2009-
#[derive(Hash, Eq)]
2078+
#[derive(Hash, PartialEq, Eq)]
20102079
struct KeyDropCounter(usize);
20112080

2012-
impl PartialEq for KeyDropCounter {
2013-
fn eq(&self, other: &Self) -> bool {
2014-
self.0.eq(&other.0)
2015-
}
2016-
}
2017-
20182081
impl Drop for KeyDropCounter {
20192082
fn drop(&mut self) {
20202083
DROP_COUNT.fetch_add(1, Ordering::SeqCst);

0 commit comments

Comments
 (0)