Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 14f87bc

Browse files
committedJun 6, 2022
Migrate from aioredis to redis with asyncio support
Add test for redis type Fix imports from wrong module (for tests_sync)
1 parent c2dbcfc commit 14f87bc

15 files changed

+82
-104
lines changed
 

‎aredis_om/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .async_redis import redis # isort:skip
12
from .checks import has_redis_json, has_redisearch
23
from .connections import get_redis_connection
34
from .model.migrations.migrator import MigrationError, Migrator

‎aredis_om/async_redis.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from redis import asyncio as redis

‎aredis_om/connections.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import os
22

3-
import aioredis
3+
from . import redis
44

55

66
URL = os.environ.get("REDIS_OM_URL", None)
77

88

9-
def get_redis_connection(**kwargs) -> aioredis.Redis:
9+
def get_redis_connection(**kwargs) -> redis.Redis:
1010
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
1111
# environment variable, we'll create the Redis client from the URL.
1212
url = kwargs.pop("url", URL)
1313
if url:
14-
return aioredis.Redis.from_url(url, **kwargs)
14+
return redis.Redis.from_url(url, **kwargs)
1515

1616
# Decode from UTF-8 by default
1717
if "decode_responses" not in kwargs:
1818
kwargs["decode_responses"] = True
19-
return aioredis.Redis(**kwargs)
19+
return redis.Redis(**kwargs)

‎aredis_om/model/migrations/migrator.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from enum import Enum
55
from typing import List, Optional
66

7-
from aioredis import Redis, ResponseError
7+
from ... import redis
88

99

1010
log = logging.getLogger(__name__)
@@ -39,18 +39,19 @@ def schema_hash_key(index_name):
3939
return f"{index_name}:hash"
4040

4141

42-
async def create_index(redis: Redis, index_name, schema, current_hash):
43-
db_number = redis.connection_pool.connection_kwargs.get("db")
42+
async def create_index(conn: redis.Redis, index_name, schema, current_hash):
43+
db_number = conn.connection_pool.connection_kwargs.get("db")
4444
if db_number and db_number > 0:
4545
raise MigrationError(
4646
"Creating search indexes is only supported in database 0. "
4747
f"You attempted to create an index in database {db_number}"
4848
)
4949
try:
50-
await redis.execute_command(f"ft.info {index_name}")
51-
except ResponseError:
52-
await redis.execute_command(f"ft.create {index_name} {schema}")
53-
await redis.set(schema_hash_key(index_name), current_hash)
50+
await conn.execute_command(f"ft.info {index_name}")
51+
except redis.ResponseError:
52+
await conn.execute_command(f"ft.create {index_name} {schema}")
53+
# TODO: remove "type: ignore" when type stubs will be fixed
54+
await conn.set(schema_hash_key(index_name), current_hash) # type: ignore
5455
else:
5556
log.info("Index already exists, skipping. Index hash: %s", index_name)
5657

@@ -67,7 +68,7 @@ class IndexMigration:
6768
schema: str
6869
hash: str
6970
action: MigrationAction
70-
redis: Redis
71+
conn: redis.Redis
7172
previous_hash: Optional[str] = None
7273

7374
async def run(self):
@@ -78,14 +79,14 @@ async def run(self):
7879

7980
async def create(self):
8081
try:
81-
await create_index(self.redis, self.index_name, self.schema, self.hash)
82-
except ResponseError:
82+
await create_index(self.conn, self.index_name, self.schema, self.hash)
83+
except redis.ResponseError:
8384
log.info("Index already exists: %s", self.index_name)
8485

8586
async def drop(self):
8687
try:
87-
await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
88-
except ResponseError:
88+
await self.conn.execute_command(f"FT.DROPINDEX {self.index_name}")
89+
except redis.ResponseError:
8990
log.info("Index does not exist: %s", self.index_name)
9091

9192

@@ -105,7 +106,7 @@ async def detect_migrations(self):
105106

106107
for name, cls in model_registry.items():
107108
hash_key = schema_hash_key(cls.Meta.index_name)
108-
redis = cls.db()
109+
conn = cls.db()
109110
try:
110111
schema = cls.redisearch_schema()
111112
except NotImplementedError:
@@ -114,21 +115,21 @@ async def detect_migrations(self):
114115
current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec
115116

116117
try:
117-
await redis.execute_command("ft.info", cls.Meta.index_name)
118-
except ResponseError:
118+
await conn.execute_command("ft.info", cls.Meta.index_name)
119+
except redis.ResponseError:
119120
self.migrations.append(
120121
IndexMigration(
121122
name,
122123
cls.Meta.index_name,
123124
schema,
124125
current_hash,
125126
MigrationAction.CREATE,
126-
redis,
127+
conn,
127128
)
128129
)
129130
continue
130131

131-
stored_hash = await redis.get(hash_key)
132+
stored_hash = await conn.get(hash_key)
132133
schema_out_of_date = current_hash != stored_hash
133134

134135
if schema_out_of_date:
@@ -140,7 +141,7 @@ async def detect_migrations(self):
140141
schema,
141142
current_hash,
142143
MigrationAction.DROP,
143-
redis,
144+
conn,
144145
stored_hash,
145146
)
146147
)
@@ -151,7 +152,7 @@ async def detect_migrations(self):
151152
schema,
152153
current_hash,
153154
MigrationAction.CREATE,
154-
redis,
155+
conn,
155156
stored_hash,
156157
)
157158
)

‎aredis_om/model/model.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
no_type_check,
2525
)
2626

27-
import aioredis
28-
from aioredis.client import Pipeline
2927
from pydantic import BaseModel, validator
3028
from pydantic.fields import FieldInfo as PydanticFieldInfo
3129
from pydantic.fields import ModelField, Undefined, UndefinedType
@@ -35,9 +33,10 @@
3533
from typing_extensions import Protocol, get_args, get_origin
3634
from ulid import ULID
3735

36+
from .. import redis
3837
from ..checks import has_redis_json, has_redisearch
3938
from ..connections import get_redis_connection
40-
from ..unasync_util import ASYNC_MODE
39+
from ..util import ASYNC_MODE
4140
from .encoders import jsonable_encoder
4241
from .render_tree import render_tree
4342
from .token_escaper import TokenEscaper
@@ -975,7 +974,7 @@ class BaseMeta(Protocol):
975974
global_key_prefix: str
976975
model_key_prefix: str
977976
primary_key_pattern: str
978-
database: aioredis.Redis
977+
database: redis.Redis
979978
primary_key: PrimaryKey
980979
primary_key_creator_cls: Type[PrimaryKeyCreator]
981980
index_name: str
@@ -994,7 +993,7 @@ class DefaultMeta:
994993
global_key_prefix: Optional[str] = None
995994
model_key_prefix: Optional[str] = None
996995
primary_key_pattern: Optional[str] = None
997-
database: Optional[aioredis.Redis] = None
996+
database: Optional[redis.Redis] = None
998997
primary_key: Optional[PrimaryKey] = None
999998
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
1000999
index_name: Optional[str] = None
@@ -1127,10 +1126,14 @@ async def update(self, **field_values):
11271126
"""Update this model instance with the specified key-value pairs."""
11281127
raise NotImplementedError
11291128

1130-
async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
1129+
async def save(
1130+
self, pipeline: Optional[redis.client.Pipeline] = None
1131+
) -> "RedisModel":
11311132
raise NotImplementedError
11321133

1133-
async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
1134+
async def expire(
1135+
self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None
1136+
):
11341137
if pipeline is None:
11351138
db = self.db()
11361139
else:
@@ -1241,7 +1244,7 @@ def get_annotations(cls):
12411244
async def add(
12421245
cls,
12431246
models: Sequence["RedisModel"],
1244-
pipeline: Optional[Pipeline] = None,
1247+
pipeline: Optional[redis.client.Pipeline] = None,
12451248
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
12461249
) -> Sequence["RedisModel"]:
12471250
if pipeline is None:
@@ -1301,7 +1304,9 @@ def __init_subclass__(cls, **kwargs):
13011304
f"HashModels cannot index dataclass fields. Field: {name}"
13021305
)
13031306

1304-
async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
1307+
async def save(
1308+
self, pipeline: Optional[redis.client.Pipeline] = None
1309+
) -> "HashModel":
13051310
self.check()
13061311
if pipeline is None:
13071312
db = self.db()
@@ -1473,7 +1478,9 @@ def __init__(self, *args, **kwargs):
14731478
)
14741479
super().__init__(*args, **kwargs)
14751480

1476-
async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
1481+
async def save(
1482+
self, pipeline: Optional[redis.client.Pipeline] = None
1483+
) -> "JsonModel":
14771484
self.check()
14781485
if pipeline is None:
14791486
db = self.db()

‎aredis_om/sync_redis.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import redis

‎aredis_om/unasync_util.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

‎aredis_om/util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import inspect
2+
3+
4+
def is_async_mode():
5+
async def f():
6+
"""Unasync transforms async functions in sync functions"""
7+
return None
8+
9+
return inspect.iscoroutinefunction(f)
10+
11+
12+
ASYNC_MODE = is_async_mode()

‎make_sync.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
ADDITIONAL_REPLACEMENTS = {
77
"aredis_om": "redis_om",
8-
"aioredis": "redis",
8+
"async_redis": "sync_redis",
99
":tests.": ":tests_sync.",
1010
"pytest_asyncio": "pytest",
1111
"py_test_mark_asyncio": "py_test_mark_sync",
@@ -26,11 +26,12 @@ def main():
2626
),
2727
]
2828
filepaths = []
29-
for root, _, filenames in os.walk(
30-
Path(__file__).absolute().parent
31-
):
29+
for root, _, filenames in os.walk(Path(__file__).absolute().parent):
3230
for filename in filenames:
33-
if filename.rpartition(".")[-1] in ("py", "pyi",):
31+
if filename.rpartition(".")[-1] in (
32+
"py",
33+
"pyi",
34+
):
3435
filepaths.append(os.path.join(root, filename))
3536

3637
unasync.unasync_files(filepaths, rules)

‎poetry.lock

Lines changed: 1 addition & 20 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ include=[
2323

2424
[tool.poetry.dependencies]
2525
python = "^3.7"
26-
redis = ">=3.5.3,<5.0.0"
27-
aioredis = "^2.0.0"
26+
redis = ">=4.2.0,<5.0.0"
2827
pydantic = "^1.8.2"
2928
click = "^8.0.1"
3029
six = "^1.16.0"

‎tests/test_hash_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
# We need to run this check as sync code (during tests) even in async mode
2424
# because we call it in the top-level module scope.
2525
from redis_om import has_redisearch
26-
from tests.conftest import py_test_mark_asyncio
26+
27+
from .conftest import py_test_mark_asyncio
28+
2729

2830
if not has_redisearch():
2931
pytestmark = pytest.mark.skip

‎tests/test_json_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
# We need to run this check as sync code (during tests) even in async mode
2626
# because we call it in the top-level module scope.
2727
from redis_om import has_redis_json
28-
from tests.conftest import py_test_mark_asyncio
28+
29+
from .conftest import py_test_mark_asyncio
30+
2931

3032
if not has_redis_json():
3133
pytestmark = pytest.mark.skip

‎tests/test_oss_redis_features.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from pydantic import ValidationError
1010

1111
from aredis_om import HashModel, Migrator, NotFoundError, RedisModelError
12-
from tests.conftest import py_test_mark_asyncio
12+
13+
from .conftest import py_test_mark_asyncio
1314

1415

1516
today = datetime.date.today()

‎tests/test_redis_type.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from aredis_om import redis
2+
from aredis_om.util import ASYNC_MODE
3+
4+
5+
def test_redis_type():
6+
import redis as sync_redis_module
7+
import redis.asyncio as async_redis_module
8+
9+
mapping = {True: async_redis_module, False: sync_redis_module}
10+
assert mapping[ASYNC_MODE] is redis

0 commit comments

Comments
 (0)
Please sign in to comment.