Skip to content

Add cached_method decorator for per-instance method caches #5570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion cirq-core/cirq/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import traceback
import warnings
from types import ModuleType
from typing import Any, Callable, Optional, Dict, Tuple, Type, Set
from typing import Any, Callable, Dict, Optional, overload, Set, Tuple, Type, TypeVar

import numpy as np
import pandas as pd
Expand All @@ -39,6 +39,54 @@
from backports.cached_property import cached_property # type: ignore[no-redef]


TFunc = TypeVar('TFunc', bound=Callable)


@overload
def cached_method(__func: TFunc) -> TFunc:
...


@overload
def cached_method(*, maxsize: int = 128) -> Callable[[TFunc], TFunc]:
...


def cached_method(method: Optional[TFunc] = None, *, maxsize: int = 128) -> Any:
"""Decorator that adds a per-instance LRU cache for a method.

Can be applied with or without parameters to customize the underlying cache:

@cached_method
def foo(self, name: str) -> int:
...

@cached_method(maxsize=1000)
def bar(self, name: str) -> int:
...
"""

def decorator(func):
cache_name = f'_{func.__name__}_cache'

@functools.wraps(func)
def wrapped(self, *args, **kwargs):
cached = getattr(self, cache_name, None)
if cached is None:

@functools.lru_cache(maxsize=maxsize)
def cached_func(*args, **kwargs):
return func(self, *args, **kwargs)

object.__setattr__(self, cache_name, cached_func)
cached = cached_func
return cached(*args, **kwargs)

return wrapped

return decorator if method is None else decorator(method)


def proper_repr(value: Any) -> str:
"""Overrides sympy and numpy returning repr strings that don't parse."""

Expand Down
34 changes: 33 additions & 1 deletion cirq-core/cirq/_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import dataclasses
import importlib
import logging
Expand All @@ -21,7 +22,7 @@
import types
import warnings
from types import ModuleType
from typing import Any, Callable, Optional
from typing import Any, Callable, Dict, Optional, Tuple
from importlib.machinery import ModuleSpec
from unittest import mock

Expand All @@ -35,6 +36,7 @@
import cirq.testing
from cirq._compat import (
block_overlapping_deprecation,
cached_method,
cached_property,
proper_repr,
dataclass_repr,
Expand Down Expand Up @@ -985,3 +987,33 @@ def bar(self):
bar2 = foo.bar
assert bar2 is bar
assert foo.bar_calls == 1


class Bar:
def __init__(self):
self.foo_calls: Dict[int, int] = collections.Counter()
self.bar_calls: Dict[int, int] = collections.Counter()

@cached_method
def foo(self, n: int) -> Tuple[int, int]:
self.foo_calls[n] += 1
return (id(self), n)

@cached_method(maxsize=1)
def bar(self, n: int) -> Tuple[int, int]:
self.bar_calls[n] += 1
return (id(self), 2 * n)


def test_cached_method():
b = Bar()
assert b.foo(123) == b.foo(123) == b.foo(123) == (id(b), 123)
assert b.foo(234) == b.foo(234) == b.foo(234) == (id(b), 234)
assert b.foo_calls == {123: 1, 234: 1}

assert b.bar(123) == b.bar(123) == (id(b), 123 * 2)
assert b.bar_calls == {123: 1}
assert b.bar(234) == b.bar(234) == (id(b), 234 * 2)
assert b.bar_calls == {123: 1, 234: 1}
assert b.bar(123) == b.bar(123) == (id(b), 123 * 2)
assert b.bar_calls == {123: 2, 234: 1}