/
MemoizingMRUCache.cs
313 lines (275 loc) · 11.9 KB
/
MemoizingMRUCache.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
// Copyright (c) 2019 .NET Foundation and Contributors. All rights reserved.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Threading;
namespace Splat
{
/// <summary>
/// This data structure is a representation of a memoizing cache - i.e. a
/// class that will evaluate a function, but keep a cache of recently
/// evaluated parameters.
///
/// Since this is a memoizing cache, it is important that this function be a
/// "pure" function in the mathematical sense - that a key *always* maps to
/// a corresponding return value.
/// </summary>
/// <typeparam name="TParam">The type of the parameter to the calculation function.</typeparam>
/// <typeparam name="TVal">The type of the value returned by the calculation
/// function.</typeparam>
public sealed class MemoizingMRUCache<TParam, TVal>
{
    private readonly object _lockObject = new object();
    private readonly Func<TParam, object, TVal> _calculationFunction;
    private readonly Action<TVal> _releaseFunction;
    private readonly int _maxCacheSize;
    private readonly IEqualityComparer<TParam> _comparer;

    // Keys ordered most-recently-used first; kept in lock-step with _cacheEntries.
    private LinkedList<TParam> _cacheMRUList;

    // Maps each key to its node in _cacheMRUList (for O(1) refresh/removal) and its cached value.
    private Dictionary<TParam, (LinkedListNode<TParam> param, TVal value)> _cacheEntries;

    /// <summary>
    /// Initializes a new instance of the <see cref="MemoizingMRUCache{TParam, TVal}"/> class.
    /// </summary>
    /// <param name="calculationFunc">The function whose results you want to cache,
    /// which is provided the key value, and a Tag object that is
    /// user-defined.</param>
    /// <param name="maxSize">The size of the cache to maintain, after which old
    /// items will start to be thrown out.</param>
    public MemoizingMRUCache(Func<TParam, object, TVal> calculationFunc, int maxSize)
        : this(calculationFunc, maxSize, null, EqualityComparer<TParam>.Default)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="MemoizingMRUCache{TParam, TVal}"/> class.
    /// </summary>
    /// <param name="calculationFunc">The function whose results you want to cache,
    /// which is provided the key value, and a Tag object that is
    /// user-defined.</param>
    /// <param name="maxSize">The size of the cache to maintain, after which old
    /// items will start to be thrown out.</param>
    /// <param name="onRelease">A function to call when a result gets
    /// evicted from the cache (i.e. because Invalidate was called or the
    /// cache is full).</param>
    public MemoizingMRUCache(Func<TParam, object, TVal> calculationFunc, int maxSize, Action<TVal> onRelease)
        : this(calculationFunc, maxSize, onRelease, EqualityComparer<TParam>.Default)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="MemoizingMRUCache{TParam, TVal}"/> class.
    /// </summary>
    /// <param name="calculationFunc">The function whose results you want to cache,
    /// which is provided the key value, and a Tag object that is
    /// user-defined.</param>
    /// <param name="maxSize">The size of the cache to maintain, after which old
    /// items will start to be thrown out.</param>
    /// <param name="paramComparer">A comparer for the parameter.</param>
    public MemoizingMRUCache(Func<TParam, object, TVal> calculationFunc, int maxSize, IEqualityComparer<TParam> paramComparer)
        : this(calculationFunc, maxSize, null, paramComparer)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="MemoizingMRUCache{TParam, TVal}"/> class.
    /// </summary>
    /// <param name="calculationFunc">The function whose results you want to cache,
    /// which is provided the key value, and a Tag object that is
    /// user-defined.</param>
    /// <param name="maxSize">The size of the cache to maintain, after which old
    /// items will start to be thrown out.</param>
    /// <param name="onRelease">A function to call when a result gets
    /// evicted from the cache (i.e. because Invalidate was called or the
    /// cache is full).</param>
    /// <param name="paramComparer">A comparer for the parameter.</param>
    public MemoizingMRUCache(Func<TParam, object, TVal> calculationFunc, int maxSize, Action<TVal> onRelease, IEqualityComparer<TParam> paramComparer)
    {
        Contract.Requires(calculationFunc != null);
        Contract.Requires(maxSize > 0);

        _calculationFunction = calculationFunc;
        _releaseFunction = onRelease;
        _maxCacheSize = maxSize;
        _comparer = paramComparer ?? EqualityComparer<TParam>.Default;

        // InvalidateAll doubles as the initializer for the backing collections
        // (its null-cache branch creates them).
        InvalidateAll();
    }

    /// <summary>
    /// Gets the value from the specified key.
    /// </summary>
    /// <param name="key">The value to pass to the calculation function.</param>
    /// <returns>The value that we have got.</returns>
    public TVal Get(TParam key)
    {
        return Get(key, null);
    }

    /// <summary>
    /// Evaluates the function provided, returning the cached value if possible.
    /// </summary>
    /// <param name="key">The value to pass to the calculation function.</param>
    /// <param name="context">An additional optional user-specific parameter.</param>
    /// <returns>The value that we have got.</returns>
    public TVal Get(TParam key, object context = null)
    {
        Contract.Requires(key != null);

        lock (_lockObject)
        {
            if (_cacheEntries.TryGetValue(key, out var found))
            {
                // Cache hit: promote the entry to most-recently-used.
                RefreshEntry(found.param);
                return found.value;
            }

            // Cache miss: compute under the lock so concurrent callers for the
            // same key do not run the calculation function twice.
            var result = _calculationFunction(key, context);

            var node = new LinkedListNode<TParam>(key);
            _cacheMRUList.AddFirst(node);
            _cacheEntries[key] = (node, result);

            // Evict least-recently-used entries if we went over capacity.
            MaintainCache();

            return result;
        }
    }

    /// <summary>
    /// Tries to get the value if it's available.
    /// </summary>
    /// <param name="key">The input value of the key to use.</param>
    /// <param name="result">The result if available, otherwise it will be the default value.</param>
    /// <returns>If we were able to retrieve the value or not.</returns>
    public bool TryGet(TParam key, out TVal result)
    {
        Contract.Requires(key != null);

        lock (_lockObject)
        {
            var ret = _cacheEntries.TryGetValue(key, out var output);
            if (ret)
            {
                // A successful lookup counts as a "use" for MRU purposes.
                RefreshEntry(output.param);
                result = output.value;
            }
            else
            {
                result = default(TVal);
            }

            return ret;
        }
    }

    /// <summary>
    /// Ensure that the next time this key is queried, the calculation
    /// function will be called.
    /// </summary>
    /// <param name="key">The key to invalidate the value for.</param>
    public void Invalidate(TParam key)
    {
        Contract.Requires(key != null);

        lock (_lockObject)
        {
            if (!_cacheEntries.TryGetValue(key, out var toRemove))
            {
                return;
            }

            var releaseVar = toRemove.value;
            _cacheMRUList.Remove(toRemove.param);
            _cacheEntries.Remove(key);

            // Release runs after removal so the entry is gone from the cache
            // even if the release call fails.
            _releaseFunction?.Invoke(releaseVar);
        }
    }

    /// <summary>
    /// Invalidate all the items in the cache.
    /// </summary>
    /// <param name="aggregateReleaseExceptions">
    /// Flag to indicate whether Exceptions during the resource Release call should not fail on the first item.
    /// But should try all items then throw an aggregate exception.
    /// </param>
    public void InvalidateAll(bool aggregateReleaseExceptions = false)
    {
        Dictionary<TParam, (LinkedListNode<TParam> param, TVal value)> oldCacheToClear;

        lock (_lockObject)
        {
            if (_releaseFunction == null || _cacheEntries == null)
            {
                // No release callback (or first-time init from the ctor):
                // nothing to release, just reset the collections.
                _cacheMRUList = new LinkedList<TParam>();
                _cacheEntries = new Dictionary<TParam, (LinkedListNode<TParam> param, TVal value)>(_comparer);
                return;
            }

            if (_cacheEntries.Count == 0)
            {
                // Nothing cached, so nothing to release.
                return;
            }

            // Swap the old entries into a local so the lock can be dropped
            // before the (potentially slow or throwing) release callbacks run.
            oldCacheToClear = _cacheEntries;
            _cacheMRUList = new LinkedList<TParam>();
            _cacheEntries = new Dictionary<TParam, (LinkedListNode<TParam> param, TVal value)>(_comparer);
        }

        if (aggregateReleaseExceptions)
        {
            // Attempt every release, collecting failures, so one bad item
            // does not prevent the rest from being released.
            var exceptions = new List<Exception>(oldCacheToClear.Count);
            foreach (var item in oldCacheToClear)
            {
                try
                {
                    _releaseFunction?.Invoke(item.Value.value);
                }
                catch (Exception e)
                {
                    exceptions.Add(e);
                }
            }

            if (exceptions.Count > 0)
            {
                throw new AggregateException("Exceptions throw during MRU Cache Invalidate All Item Release.", exceptions);
            }

            return;
        }

        // Release mechanism that will throw on first failure, but the items
        // have already been removed from the active cache above (the cache
        // field was reassigned), so the cache itself stays consistent.
        foreach (var item in oldCacheToClear)
        {
            _releaseFunction?.Invoke(item.Value.value);
        }
    }

    /// <summary>
    /// Returns all values currently in the cache.
    /// </summary>
    /// <returns>The values in the cache.</returns>
    public IEnumerable<TVal> CachedValues()
    {
        lock (_lockObject)
        {
            // Materialize inside the lock. Returning the lazy LINQ query would
            // defer enumeration of _cacheEntries until after the lock is
            // released, racing with concurrent mutation of the cache.
            return _cacheEntries.Select(x => x.Value.value).ToList();
        }
    }

    // Evicts least-recently-used entries until the cache is within capacity.
    // Must be called while holding _lockObject.
    private void MaintainCache()
    {
        while (_cacheMRUList.Count > _maxCacheSize)
        {
            var toRemove = _cacheMRUList.Last.Value;
            var releaseVar = _cacheEntries[toRemove].value;

            // Remove before releasing so the entry is gone from the cache even
            // if the release callback throws (matching Invalidate's behavior).
            _cacheEntries.Remove(toRemove);
            _cacheMRUList.RemoveLast();

            _releaseFunction?.Invoke(releaseVar);
        }
    }

    // Moves the given node to the front of the MRU list.
    // Must be called while holding _lockObject.
    private void RefreshEntry(LinkedListNode<TParam> item)
    {
        // Only juggle entries if there is more than one of them.
        if (_cacheEntries.Count > 1)
        {
            _cacheMRUList.Remove(item);
            _cacheMRUList.AddFirst(item);
        }
    }

    [ContractInvariantMethod]
    private void Invariants()
    {
        Contract.Invariant(_cacheEntries.Count == _cacheMRUList.Count);
        Contract.Invariant(_cacheEntries.Count <= _maxCacheSize);
    }
}
}