@@ -158,15 +158,17 @@ impl<T> Borrow<[T]> for KeyRef<alloc::vec::Vec<T>> {
 struct LruEntry<K, V> {
     key: mem::MaybeUninit<K>,
     val: mem::MaybeUninit<V>,
+    cost: usize,
     prev: *mut LruEntry<K, V>,
     next: *mut LruEntry<K, V>,
 }

 impl<K, V> LruEntry<K, V> {
-    fn new(key: K, val: V) -> Self {
+    fn new(key: K, val: V, cost: usize) -> Self {
         LruEntry {
             key: mem::MaybeUninit::new(key),
             val: mem::MaybeUninit::new(val),
+            cost,
             prev: ptr::null_mut(),
             next: ptr::null_mut(),
         }
@@ -176,6 +178,7 @@ impl<K, V> LruEntry<K, V> {
         LruEntry {
             key: mem::MaybeUninit::uninit(),
             val: mem::MaybeUninit::uninit(),
+            cost: 0,
             prev: ptr::null_mut(),
             next: ptr::null_mut(),
         }
@@ -190,7 +193,8 @@ pub type DefaultHasher = std::collections::hash_map::RandomState;
 /// An LRU Cache
 pub struct LruCache<K, V, S = DefaultHasher> {
     map: HashMap<KeyRef<K>, Box<LruEntry<K, V>>, S>,
-    cap: NonZeroUsize,
+    cost_cap: NonZeroUsize,
+    cost: usize,

     // head and tail are sigil nodes to facilitate inserting entries
     head: *mut LruEntry<K, V>,
@@ -272,7 +276,8 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
         // declare it as such since we only mutate it inside the unsafe block.
         let cache = LruCache {
             map,
-            cap,
+            cost_cap: cap,
+            cost: 0,
             head: Box::into_raw(Box::new(LruEntry::new_sigil())),
             tail: Box::into_raw(Box::new(LruEntry::new_sigil())),
         };
@@ -303,7 +308,11 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
     /// assert_eq!(cache.get(&2), Some(&"beta"));
     /// ```
     pub fn put(&mut self, k: K, v: V) -> Option<V> {
-        self.capturing_put(k, v, false).map(|(_, v)| v)
+        self.put_with_cost(k, v, 1)
+    }
+
+    pub fn put_with_cost(&mut self, k: K, v: V, cost: usize) -> Option<V> {
+        self.capturing_put(k, v, false, cost).map(|(_, v)| v)
     }

     /// Pushes a key-value pair into the cache. If an entry with key `k` already exists in
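A quick sketch of how the new entry point composes with the existing one, assuming this diff applies as written (the key/value types and the budget of 10 are illustrative; `cost()` is the accessor added further down in this diff):

    use lru::LruCache;
    use std::num::NonZeroUsize;

    // The capacity now acts as a total cost budget rather than an entry count.
    let mut cache: LruCache<&str, String> = LruCache::new(NonZeroUsize::new(10).unwrap());

    cache.put_with_cost("a", String::from("alpha"), 5);
    cache.put("b", String::from("beta")); // plain `put` charges a cost of 1
    assert_eq!(cache.cost(), 6);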
@@ -331,64 +340,101 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
     /// assert_eq!(cache.get(&3), Some(&"alpha"));
     /// ```
     pub fn push(&mut self, k: K, v: V) -> Option<(K, V)> {
-        self.capturing_put(k, v, true)
+        self.push_with_cost(k, v, 1)
+    }
+
+    pub fn push_with_cost(&mut self, k: K, v: V, cost: usize) -> Option<(K, V)> {
+        self.capturing_put(k, v, true, cost)
     }

     // Used internally by `put` and `push` to add a new entry to the lru.
     // Takes ownership of and returns entries replaced due to the cache's capacity
     // when `capture` is true.
-    fn capturing_put(&mut self, k: K, mut v: V, capture: bool) -> Option<(K, V)> {
+    fn capturing_put(&mut self, k: K, mut v: V, capture: bool, cost: usize) -> Option<(K, V)> {
         let node_ref = self.map.get_mut(&KeyRef { k: &k });

         match node_ref {
             Some(node_ref) => {
+                let old_cost = node_ref.cost;
+                node_ref.cost = cost;
+
                 let node_ptr: *mut LruEntry<K, V> = &mut **node_ref;

                 // if the key is already in the cache just update its value and move it to the
                 // front of the list
                 unsafe { mem::swap(&mut v, &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V) }
                 self.detach(node_ptr);
                 self.attach(node_ptr);
+
+                self.cost -= old_cost;
+                self.cost += cost;
+
+                self.shrink_within_cost();
+
                 Some((k, v))
             }
             None => {
-                let (replaced, mut node) = self.replace_or_create_node(k, v);
+                let (replaced, mut node) = self.replace_or_create_node(k, v, cost);

                 let node_ptr: *mut LruEntry<K, V> = &mut *node;
                 self.attach(node_ptr);

                 let keyref = unsafe { (*node_ptr).key.as_ptr() };
                 self.map.insert(KeyRef { k: keyref }, node);

+                self.shrink_within_cost();
+
                 replaced.filter(|_| capture)
             }
         }
     }

+    fn shrink_within_cost(&mut self) {
+        let mut did_shrink = false;
+        while self.cost() > self.cost_cap.get() {
+            self.pop_lru();
+            did_shrink = true;
+        }
+        if did_shrink {
+            self.map.shrink_to_fit();
+        }
+    }
+
     // Used internally to swap out a node if the cache is full or to create a new node if space
     // is available. Shared between `put`, `push`, `get_or_insert`, and `get_or_insert_mut`.
     #[allow(clippy::type_complexity)]
-    fn replace_or_create_node(&mut self, k: K, v: V) -> (Option<(K, V)>, Box<LruEntry<K, V>>) {
-        if self.len() == self.cap().get() {
+    fn replace_or_create_node(
+        &mut self,
+        k: K,
+        v: V,
+        cost: usize,
+    ) -> (Option<(K, V)>, Box<LruEntry<K, V>>) {
+        if self.cost + cost > self.cost_cap.get() && !self.is_empty() {
             // if the cache is full, remove the last entry so we can use it for the new key
             let old_key = KeyRef {
                 k: unsafe { &(*(*(*self.tail).prev).key.as_ptr()) },
             };
             let mut old_node = self.map.remove(&old_key).unwrap();
+            let old_cost = old_node.cost;

             // read out the node's old key and value and then replace it
             let replaced = unsafe { (old_node.key.assume_init(), old_node.val.assume_init()) };

             old_node.key = mem::MaybeUninit::new(k);
             old_node.val = mem::MaybeUninit::new(v);
+            old_node.cost = cost;
+
+            self.cost -= old_cost;
+            self.cost += cost;

             let node_ptr: *mut LruEntry<K, V> = &mut *old_node;
             self.detach(node_ptr);

             (Some(replaced), old_node)
         } else {
             // if the cache is not full allocate a new LruEntry
-            (None, Box::new(LruEntry::new(k, v)))
+            self.cost += cost;
+            (None, Box::new(LruEntry::new(k, v, cost)))
         }
     }

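To trace the eviction behavior these two helpers produce (a sketch under the assumption that the diff applies cleanly): `replace_or_create_node` recycles the LRU node when the incoming entry would overflow the budget, and `shrink_within_cost` then pops LRU entries until the running cost fits.

    use lru::LruCache;
    use std::num::NonZeroUsize;

    let mut cache: LruCache<&str, ()> = LruCache::new(NonZeroUsize::new(100).unwrap());
    cache.put_with_cost("a", (), 60);
    cache.put_with_cost("b", (), 30);

    // 90 + 20 overflows the budget of 100, so the LRU node "a" is recycled
    // for "c" and the running cost drops to 30 + 20 = 50.
    cache.put_with_cost("c", (), 20);
    assert!(cache.get(&"a").is_none());
    assert_eq!(cache.cost(), 50);

    // An entry larger than the whole budget evicts everything, itself included,
    // since shrink_within_cost keeps popping while cost > cost_cap.
    cache.put_with_cost("d", (), 150);
    assert!(cache.is_empty());
    assert_eq!(cache.cost(), 0);

Note the self-eviction in the last case: whether an oversized entry should be admitted at all is a policy choice, and this diff resolves it by admitting the entry and then immediately evicting it.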
@@ -489,6 +535,13 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
     pub fn get_or_insert<'a, F>(&'a mut self, k: K, f: F) -> &'a V
     where
         F: FnOnce() -> V,
+    {
+        self.get_or_insert_cost(k, || (f(), 1))
+    }
+
+    pub fn get_or_insert_cost<'a, F>(&'a mut self, k: K, f: F) -> &'a V
+    where
+        F: FnOnce() -> (V, usize),
     {
         if let Some(node) = self.map.get_mut(&KeyRef { k: &k }) {
             let node_ptr: *mut LruEntry<K, V> = &mut **node;
@@ -498,14 +551,17 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {

             unsafe { &(*(*node_ptr).val.as_ptr()) as &V }
         } else {
-            let v = f();
-            let (_, mut node) = self.replace_or_create_node(k, v);
+            let (v, cost) = f();
+            let (_, mut node) = self.replace_or_create_node(k, v, cost);

             let node_ptr: *mut LruEntry<K, V> = &mut *node;
             self.attach(node_ptr);

             let keyref = unsafe { (*node_ptr).key.as_ptr() };
             self.map.insert(KeyRef { k: keyref }, node);
+
+            self.shrink_within_cost();
+
             unsafe { &(*(*node_ptr).val.as_ptr()) as &V }
         }
     }
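Usage would then look like the following (a sketch; the key and budget are illustrative). The closure runs only on a miss and returns the value together with its cost, so the cost can be derived from the freshly built value:

    use lru::LruCache;
    use std::num::NonZeroUsize;

    let mut cache: LruCache<&str, String> = LruCache::new(NonZeroUsize::new(64).unwrap());

    // On a miss, build the value and charge it by its length.
    let v = cache.get_or_insert_cost("greeting", || {
        let s = String::from("hello");
        let cost = s.len();
        (s, cost)
    });
    assert_eq!(v, "hello");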
@@ -535,6 +591,13 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
     pub fn get_or_insert_mut<'a, F>(&'a mut self, k: K, f: F) -> &'a mut V
     where
         F: FnOnce() -> V,
+    {
+        self.get_or_insert_cost_mut(k, || (f(), 1))
+    }
+
+    pub fn get_or_insert_cost_mut<'a, F>(&'a mut self, k: K, f: F) -> &'a mut V
+    where
+        F: FnOnce() -> (V, usize),
     {
         if let Some(node) = self.map.get_mut(&KeyRef { k: &k }) {
             let node_ptr: *mut LruEntry<K, V> = &mut **node;
@@ -544,14 +607,17 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {

             unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V }
         } else {
-            let v = f();
-            let (_, mut node) = self.replace_or_create_node(k, v);
+            let (v, cost) = f();
+            let (_, mut node) = self.replace_or_create_node(k, v, cost);

             let node_ptr: *mut LruEntry<K, V> = &mut *node;
             self.attach(node_ptr);

             let keyref = unsafe { (*node_ptr).key.as_ptr() };
             self.map.insert(KeyRef { k: keyref }, node);
+
+            self.shrink_within_cost();
+
             unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V }
         }
     }
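The `_mut` variant mirrors this; a sketch of the in-place update path it enables (names illustrative):

    use lru::LruCache;
    use std::num::NonZeroUsize;

    let mut cache: LruCache<&str, usize> = LruCache::new(NonZeroUsize::new(8).unwrap());

    // Insert a zero counter at cost 1 on the first call, then bump it in place.
    let counter = cache.get_or_insert_cost_mut("hits", || (0, 1));
    *counter += 1;
    assert_eq!(cache.get(&"hits"), Some(&1));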
@@ -693,6 +759,7 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
         match self.map.remove(k) {
             None => None,
             Some(mut old_node) => {
+                self.cost -= old_node.cost;
                 unsafe {
                     ptr::drop_in_place(old_node.key.as_mut_ptr());
                 }
@@ -730,6 +797,7 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
         match self.map.remove(k) {
             None => None,
             Some(mut old_node) => {
+                self.cost -= old_node.cost;
                 let node_ptr: *mut LruEntry<K, V> = &mut *old_node;
                 self.detach(node_ptr);
                 unsafe { Some((old_node.key.assume_init(), old_node.val.assume_init())) }
@@ -759,6 +827,7 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
     /// ```
     pub fn pop_lru(&mut self) -> Option<(K, V)> {
         let node = self.remove_last()?;
+        self.cost -= node.cost;
         // N.B.: Can't destructure directly because of https://github.com/rust-lang/rust/issues/28536
         let node = *node;
         let LruEntry { key, val, .. } = node;
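All three removal paths (`pop`, `pop_entry`, `pop_lru`) now release the removed entry's cost; a sketch of the bookkeeping, assuming the diff as written:

    use lru::LruCache;
    use std::num::NonZeroUsize;

    let mut cache: LruCache<&str, &str> = LruCache::new(NonZeroUsize::new(10).unwrap());
    cache.put_with_cost("a", "alpha", 7);
    cache.put_with_cost("b", "beta", 2);

    assert_eq!(cache.pop(&"a"), Some("alpha"));
    assert_eq!(cache.cost(), 2); // "a"'s 7 units are released

    assert_eq!(cache.pop_lru(), Some(("b", "beta")));
    assert_eq!(cache.cost(), 0);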
@@ -858,6 +927,10 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
         self.map.len()
     }

+    pub fn cost(&self) -> usize {
+        self.cost
+    }
+
     /// Returns a bool indicating whether the cache is empty or not.
     ///
     /// # Example
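`len()` and the new `cost()` accessor now measure different things and can diverge; a quick illustration (values arbitrary):

    use lru::LruCache;
    use std::num::NonZeroUsize;

    let mut cache: LruCache<&str, ()> = LruCache::new(NonZeroUsize::new(10).unwrap());
    cache.put_with_cost("a", (), 4);
    cache.put_with_cost("b", (), 5);
    assert_eq!(cache.len(), 2);  // number of entries
    assert_eq!(cache.cost(), 9); // sum of their costs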
@@ -883,10 +956,10 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
     /// use lru::LruCache;
     /// use std::num::NonZeroUsize;
     /// let mut cache: LruCache<isize, &str> = LruCache::new(NonZeroUsize::new(2).unwrap());
-    /// assert_eq!(cache.cap().get(), 2);
+    /// assert_eq!(cache.cost_cap().get(), 2);
     /// ```
-    pub fn cap(&self) -> NonZeroUsize {
-        self.cap
+    pub fn cost_cap(&self) -> NonZeroUsize {
+        self.cost_cap
     }

     /// Resizes the cache. If the new capacity is smaller than the size of the current
@@ -913,16 +986,12 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
     /// ```
     pub fn resize(&mut self, cap: NonZeroUsize) {
         // return early if capacity doesn't change
-        if cap == self.cap {
+        if cap == self.cost_cap {
             return;
         }

-        while self.map.len() > cap.get() {
-            self.pop_lru();
-        }
-        self.map.shrink_to_fit();
-
-        self.cap = cap;
+        self.cost_cap = cap;
+        self.shrink_within_cost()
     }

     /// Clears the contents of the cache.
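`resize` now delegates to `shrink_within_cost`, so shrinking the budget evicts from the LRU end by cost rather than by entry count; a sketch:

    use lru::LruCache;
    use std::num::NonZeroUsize;

    let mut cache: LruCache<&str, ()> = LruCache::new(NonZeroUsize::new(10).unwrap());
    cache.put_with_cost("a", (), 4);
    cache.put_with_cost("b", (), 5);

    // The budget drops below the running cost of 9: "a" (LRU, cost 4) is
    // evicted, leaving cost 5, which fits within the new budget of 6.
    cache.resize(NonZeroUsize::new(6).unwrap());
    assert!(cache.get(&"a").is_none());
    assert_eq!(cache.cost(), 5);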
@@ -1097,7 +1166,7 @@ impl<K: Hash + Eq, V> fmt::Debug for LruCache<K, V> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_struct("LruCache")
             .field("len", &self.len())
-            .field("cap", &self.cap())
+            .field("cost_cap", &self.cost_cap())
             .finish()
     }
 }
@@ -1361,7 +1430,7 @@ mod tests {
         assert_eq!(cache.put("apple", "red"), None);
         assert_eq!(cache.put("banana", "yellow"), None);

-        assert_eq!(cache.cap().get(), 2);
+        assert_eq!(cache.cost_cap().get(), 2);
         assert_eq!(cache.len(), 2);
         assert!(!cache.is_empty());
         assert_opt_eq(cache.get(&"apple"), "red");
@@ -1376,7 +1445,7 @@ mod tests {
         assert_eq!(cache.put("apple", "red"), None);
         assert_eq!(cache.put("banana", "yellow"), None);

-        assert_eq!(cache.cap().get(), 2);
+        assert_eq!(cache.cost_cap().get(), 2);
         assert_eq!(cache.len(), 2);
         assert!(!cache.is_empty());
         assert_eq!(cache.get_or_insert("apple", || "orange"), &"red");
@@ -1393,7 +1462,7 @@ mod tests {
         assert_eq!(cache.put("apple", "red"), None);
         assert_eq!(cache.put("banana", "yellow"), None);

-        assert_eq!(cache.cap().get(), 2);
+        assert_eq!(cache.cost_cap().get(), 2);
         assert_eq!(cache.len(), 2);

         let v = cache.get_or_insert_mut("apple", || "orange");
@@ -1413,7 +1482,7 @@ mod tests {
         cache.put("apple", "red");
         cache.put("banana", "yellow");

-        assert_eq!(cache.cap().get(), 2);
+        assert_eq!(cache.cost_cap().get(), 2);
         assert_eq!(cache.len(), 2);
         assert_opt_eq_mut(cache.get_mut(&"apple"), "red");
         assert_opt_eq_mut(cache.get_mut(&"banana"), "yellow");
@@ -1431,7 +1500,7 @@ mod tests {
             *v = 4;
         }

-        assert_eq!(cache.cap().get(), 2);
+        assert_eq!(cache.cost_cap().get(), 2);
         assert_eq!(cache.len(), 2);
         assert_opt_eq_mut(cache.get_mut(&"apple"), 4);
         assert_opt_eq_mut(cache.get_mut(&"banana"), 3);
@@ -2006,15 +2075,9 @@ mod tests {
     fn test_no_memory_leaks_with_pop() {
         static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

-        #[derive(Hash, Eq)]
+        #[derive(Hash, PartialEq, Eq)]
         struct KeyDropCounter(usize);

-        impl PartialEq for KeyDropCounter {
-            fn eq(&self, other: &Self) -> bool {
-                self.0.eq(&other.0)
-            }
-        }
-
         impl Drop for KeyDropCounter {
             fn drop(&mut self) {
                 DROP_COUNT.fetch_add(1, Ordering::SeqCst);