Skip to content

Commit 746dfb2

Browse files
alierenoalexmojakisamuelcolvin
authored
Implement parsing SQLAlchemy objects (#94)
* Implement parsing SQLAlchemy objects * Update devtools/prettier.py Co-authored-by: Alex Hall <[email protected]> * new MetaClass for checking sqlalchemy and some code improvements * removed json.dumps from _format_sqlalchemy_class function * Update devtools/prettier.py Co-authored-by: Alex Hall <[email protected]> * Update devtools/prettier.py Co-authored-by: Samuel Colvin <[email protected]> * add missing import and apply format * small optimisation to `_format_fields` Co-authored-by: Alex Hall <[email protected]> Co-authored-by: Samuel Colvin <[email protected]>
1 parent 66eaa80 commit 746dfb2

File tree

5 files changed

+71
-19
lines changed

5 files changed

+71
-19
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ jobs:
6868
env_vars: EXTRAS,PYTHON,OS
6969

7070
- name: uninstall extras
71-
run: pip uninstall -y multidict numpy pydantic asyncpg
71+
run: pip uninstall -y multidict numpy pydantic asyncpg sqlalchemy
7272

7373
- name: test without extras
7474
run: |

devtools/prettier.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from collections import OrderedDict
44
from collections.abc import Generator
55

6-
from .utils import DataClassType, LaxMapping, env_true, isatty
6+
from .utils import DataClassType, LaxMapping, SQLAlchemyClassType, env_true, isatty
77

88
__all__ = 'PrettyFormat', 'pformat', 'pprint'
99
MYPY = False
1010
if MYPY:
11-
from typing import Any, Iterable, Union
11+
from typing import Any, Iterable, Tuple, Union
1212

1313
PARENTHESES_LOOKUP = [
1414
(list, '[', ']'),
@@ -69,6 +69,7 @@ def __init__(
6969
# put this last as the check can be slow
7070
(LaxMapping, self._format_dict),
7171
(DataClassType, self._format_dataclass),
72+
(SQLAlchemyClassType, self._format_sqlalchemy_class),
7273
]
7374

7475
def __call__(self, value: 'Any', *, indent: int = 0, indent_first: bool = False, highlight: bool = False):
@@ -169,15 +170,7 @@ def _format_tuples(self, value: tuple, value_repr: str, indent_current: int, ind
169170
fields = getattr(value, '_fields', None)
170171
if fields:
171172
# named tuple
172-
self._stream.write(value.__class__.__name__ + '(\n')
173-
for field, v in zip(fields, value):
174-
self._stream.write(indent_new * self._c)
175-
if field: # field is falsy sometimes for odd things like call_args
176-
self._stream.write(str(field))
177-
self._stream.write('=')
178-
self._format(v, indent_new, False)
179-
self._stream.write(',\n')
180-
self._stream.write(indent_current * self._c + ')')
173+
self._format_fields(value, zip(fields, value), indent_current, indent_new)
181174
else:
182175
# normal tuples are just like other similar iterables
183176
return self._format_list_like(value, value_repr, indent_current, indent_new)
@@ -231,13 +224,15 @@ def _format_bytearray(self, value: 'Any', _: str, indent_current: int, indent_ne
231224
def _format_dataclass(self, value: 'Any', _: str, indent_current: int, indent_new: int):
232225
from dataclasses import asdict
233226

234-
before_ = indent_new * self._c
235-
self._stream.write(f'{value.__class__.__name__}(\n')
236-
for k, v in asdict(value).items():
237-
self._stream.write(f'{before_}{k}=')
238-
self._format(v, indent_new, False)
239-
self._stream.write(',\n')
240-
self._stream.write(indent_current * self._c + ')')
227+
self._format_fields(value, asdict(value).items(), indent_current, indent_new)
228+
229+
def _format_sqlalchemy_class(self, value: 'Any', _: str, indent_current: int, indent_new: int):
230+
fields = [
231+
(field, getattr(value, field))
232+
for field in dir(value)
233+
if not (field.startswith('_') or field in ['metadata', 'registry'])
234+
]
235+
self._format_fields(value, fields, indent_current, indent_new)
241236

242237
def _format_raw(self, _: 'Any', value_repr: str, indent_current: int, indent_new: int):
243238
lines = value_repr.splitlines(True)
@@ -256,6 +251,18 @@ def _format_raw(self, _: 'Any', value_repr: str, indent_current: int, indent_new
256251
else:
257252
self._stream.write(value_repr)
258253

254+
def _format_fields(
255+
self, value: 'Any', fields: 'Iterable[Tuple[str, Any]]', indent_current: int, indent_new: int
256+
) -> None:
257+
self._stream.write(f'{value.__class__.__name__}(\n')
258+
for field, v in fields:
259+
self._stream.write(indent_new * self._c)
260+
if field: # field is falsy sometimes for odd things like call_args
261+
self._stream.write(f'{field}=')
262+
self._format(v, indent_new, False)
263+
self._stream.write(',\n')
264+
self._stream.write(indent_current * self._c + ')')
265+
259266

260267
pformat = PrettyFormat()
261268
force_highlight = env_true('PY_DEVTOOLS_HIGHLIGHT', None)

devtools/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,17 @@ def __instancecheck__(self, instance: 'Any') -> bool:
135135

136136
class DataClassType(metaclass=MetaDataClassType):
137137
pass
138+
139+
140+
class MetaSQLAlchemyClassType(type):
141+
def __instancecheck__(self, instance: 'Any') -> bool:
142+
try:
143+
from sqlalchemy.ext.declarative import DeclarativeMeta
144+
except ImportError:
145+
return False
146+
else:
147+
return isinstance(instance.__class__, DeclarativeMeta)
148+
149+
150+
class SQLAlchemyClassType(metaclass=MetaSQLAlchemyClassType):
151+
pass

tests/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ pydantic
99
asyncpg
1010
numpy
1111
multidict
12+
sqlalchemy

tests/test_prettier.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@
2727
except ImportError:
2828
Record = None
2929

30+
try:
31+
from sqlalchemy import Column, Integer, String
32+
from sqlalchemy.ext.declarative import declarative_base
33+
SQLAlchemyBase = declarative_base()
34+
except ImportError:
35+
SQLAlchemyBase = None
36+
3037

3138
def test_dict():
3239
v = pformat({1: 2, 3: 4})
@@ -424,3 +431,26 @@ def test_asyncpg_record():
424431

425432
def test_dict_type():
426433
assert pformat(type({1: 2})) == "<class 'dict'>"
434+
435+
436+
@pytest.mark.skipif(SQLAlchemyBase is None, reason='sqlalchemy not installed')
437+
def test_sqlalchemy_object():
438+
class User(SQLAlchemyBase):
439+
__tablename__ = 'users'
440+
id = Column(Integer, primary_key=True)
441+
name = Column(String)
442+
fullname = Column(String)
443+
nickname = Column(String)
444+
user = User()
445+
user.id = 1
446+
user.name = "Test"
447+
user.fullname = "Test For SQLAlchemy"
448+
user.nickname = "test"
449+
assert pformat(user) == (
450+
"User(\n"
451+
" fullname='Test For SQLAlchemy',\n"
452+
" id=1,\n"
453+
" name='Test',\n"
454+
" nickname='test',\n"
455+
")"
456+
)

0 commit comments

Comments
 (0)