Skip to content

add dataclass arguments #2437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/datamodel_code_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def generate( # noqa: PLR0912, PLR0913, PLR0914, PLR0915
no_alias: bool = False,
formatters: list[Formatter] = DEFAULT_FORMATTERS,
parent_scoped_naming: bool = False,
dataclass_arguments: dict[str, Any] | None = None,
) -> None:
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
if isinstance(input_, str):
Expand All @@ -304,6 +305,10 @@ def generate( # noqa: PLR0912, PLR0913, PLR0914, PLR0915
else:
input_text = None

dataclass_arguments = dict(dataclass_arguments or {})
dataclass_arguments.setdefault("frozen", frozen_dataclasses)
dataclass_arguments.setdefault("kw_only", keyword_only)

if isinstance(input_, Path) and not input_.is_absolute():
input_ = input_.expanduser().resolve()
if input_file_type == InputFileType.Auto:
Expand Down Expand Up @@ -489,6 +494,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
formatters=formatters,
encoding=encoding,
parent_scoped_naming=parent_scoped_naming,
dataclass_arguments=dataclass_arguments,
**kwargs,
)

Expand Down
2 changes: 2 additions & 0 deletions src/datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def validate_root(cls, values: Any) -> Any: # noqa: N805
output_datetime_class: Optional[DatetimeClassType] = None # noqa: UP045
keyword_only: bool = False
frozen_dataclasses: bool = False
dataclass_arguments: dict[str, Any] | None = None
no_alias: bool = False
formatters: list[Formatter] = DEFAULT_FORMATTERS
parent_scoped_naming: bool = False
Expand Down Expand Up @@ -531,6 +532,7 @@ def main(args: Sequence[str] | None = None) -> Exit: # noqa: PLR0911, PLR0912,
no_alias=config.no_alias,
formatters=config.formatters,
parent_scoped_naming=config.parent_scoped_naming,
dataclass_arguments=config.dataclass_arguments,
)
except InvalidClassNameError as e:
print(f"{e} You have to set `--class-name` option", file=sys.stderr) # noqa: T201
Expand Down
11 changes: 11 additions & 0 deletions src/datamodel_code_generator/arguments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import locale
from argparse import ArgumentParser, FileType, HelpFormatter, Namespace
from operator import attrgetter
Expand Down Expand Up @@ -166,6 +167,16 @@ def start_section(self, heading: str | None) -> None:
action="store_true",
default=None,
)
model_options.add_argument(
"--dataclass-arguments",
type=json.loads,
default=None,
help=(
"Custom dataclass arguments as a JSON dictionary, "
'e.g. \'{"frozen": true, "kw_only": true}\'. '
"Overrides --frozen-dataclasses and similar flags."
),
)
model_options.add_argument(
"--reuse-model",
help="Reuse models on the field when a module has the model with the same content",
Expand Down
9 changes: 6 additions & 3 deletions src/datamodel_code_generator/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,11 @@ def __init__( # noqa: PLR0913
keyword_only: bool = False,
frozen: bool = False,
treat_dot_as_module: bool = False,
dataclass_arguments: dict[str, Any] | None = None,
) -> None:
self.keyword_only = keyword_only
self.frozen = frozen
self.dataclass_arguments = dataclass_arguments if dataclass_arguments is not None else {}
if not self.TEMPLATE_FILE_PATH:
msg = "TEMPLATE_FILE_PATH is undefined"
raise Exception(msg) # noqa: TRY002
Expand Down Expand Up @@ -428,14 +430,15 @@ def path(self) -> str:
return self.reference.path

def render(self, *, class_name: str | None = None) -> str:
return self._render(
context: dict[str, Any] = dict(
class_name=class_name or self.class_name,
fields=self.fields,
decorators=self.decorators,
base_class=self.base_class,
methods=self.methods,
description=self.description,
keyword_only=self.keyword_only,
frozen=self.frozen,
dataclass_arguments=self.dataclass_arguments,
**self.extra_template_data,
)

return self._render(**context)
9 changes: 9 additions & 0 deletions src/datamodel_code_generator/model/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__( # noqa: PLR0913
keyword_only: bool = False,
frozen: bool = False,
treat_dot_as_module: bool = False,
dataclass_arguments: dict[str, Any] | None = None,
) -> None:
super().__init__(
reference=reference,
Expand All @@ -73,6 +74,14 @@ def __init__( # noqa: PLR0913
frozen=frozen,
treat_dot_as_module=treat_dot_as_module,
)
if dataclass_arguments is not None:
self.dataclass_arguments = dataclass_arguments
else:
self.dataclass_arguments = {}
if frozen:
self.dataclass_arguments["frozen"] = True
if keyword_only:
self.dataclass_arguments["kw_only"] = True


class DataModelField(DataModelFieldBase):
Expand Down
15 changes: 9 additions & 6 deletions src/datamodel_code_generator/model/template/dataclass.jinja2
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
{%- set args = [] %}
{%- for k, v in (dataclass_arguments or {}).items() %}
{%- if v is not none and v is not false %}
{%- set _ = args.append(k ~ '=' ~ (v|pprint)) %}
{%- endif %}
{%- endfor %}
{%- if args %}
@dataclass({{ args | join(', ') }})
{%- else %}
@dataclass
{%- if keyword_only or frozen -%}
(
{%- if keyword_only -%}kw_only=True{%- endif -%}
{%- if keyword_only and frozen -%}, {% endif -%}
{%- if frozen -%}frozen=True{%- endif -%}
)
{%- endif %}
{%- if base_class %}
class {{ class_name }}({{ base_class }}):
Expand Down
2 changes: 2 additions & 0 deletions src/datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def __init__( # noqa: PLR0913, PLR0915
no_alias: bool = False,
formatters: list[Formatter] = DEFAULT_FORMATTERS,
parent_scoped_naming: bool = False,
dataclass_arguments: dict[str, Any] | None = None,
) -> None:
self.keyword_only = keyword_only
self.frozen_dataclasses = frozen_dataclasses
Expand Down Expand Up @@ -435,6 +436,7 @@ def __init__( # noqa: PLR0913, PLR0915
self.use_title_as_name: bool = use_title_as_name
self.use_operation_id_as_name: bool = use_operation_id_as_name
self.use_unique_items_as_set: bool = use_unique_items_as_set
self.dataclass_arguments = dataclass_arguments or {}

if base_path:
self.base_path = base_path
Expand Down
18 changes: 16 additions & 2 deletions src/datamodel_code_generator/parser/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def __init__( # noqa: PLR0913
no_alias: bool = False,
formatters: list[Formatter] = DEFAULT_FORMATTERS,
parent_scoped_naming: bool = False,
dataclass_arguments: dict[str, Any] | None = None,
) -> None:
super().__init__(
source=source,
Expand Down Expand Up @@ -235,6 +236,7 @@ def __init__( # noqa: PLR0913
no_alias=no_alias,
formatters=formatters,
parent_scoped_naming=parent_scoped_naming,
dataclass_arguments=dataclass_arguments,
)

self.data_model_scalar_type = data_model_scalar_type
Expand Down Expand Up @@ -289,10 +291,21 @@ def _resolve_types(self, paths: list[str], schema: graphql.GraphQLSchema) -> Non
self.support_graphql_types[resolved_type].append(type_)

def _create_data_model(self, model_type: type[DataModel] | None = None, **kwargs: Any) -> DataModel:
"""Create data model instance with conditional frozen parameter for DataClass."""
data_model_class = model_type or self.data_model_type
if issubclass(data_model_class, DataClass):
kwargs["frozen"] = self.frozen_dataclasses
dataclass_arguments = {}
if hasattr(self, "frozen_dataclasses"):
dataclass_arguments["frozen"] = self.frozen_dataclasses
if hasattr(self, "keyword_only"):
dataclass_arguments["kw_only"] = self.keyword_only
existing = kwargs.pop("dataclass_arguments", None)
if existing:
dataclass_arguments.update(existing)
kwargs["dataclass_arguments"] = dataclass_arguments
kwargs.pop("frozen", None)
kwargs.pop("keyword_only", None)
else:
kwargs.pop("dataclass_arguments", None)
return data_model_class(**kwargs)

def _typename_field(self, name: str) -> DataModelFieldBase:
Expand Down Expand Up @@ -468,6 +481,7 @@ def parse_object_like(
description=obj.description,
keyword_only=self.keyword_only,
treat_dot_as_module=self.treat_dot_as_module,
dataclass_arguments=self.dataclass_arguments,
)
self.results.append(data_model_type)

Expand Down
19 changes: 17 additions & 2 deletions src/datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def __init__( # noqa: PLR0913
no_alias: bool = False,
formatters: list[Formatter] = DEFAULT_FORMATTERS,
parent_scoped_naming: bool = False,
dataclass_arguments: dict[str, Any] | None = None,
) -> None:
super().__init__(
source=source,
Expand Down Expand Up @@ -514,6 +515,7 @@ def __init__( # noqa: PLR0913
no_alias=no_alias,
formatters=formatters,
parent_scoped_naming=parent_scoped_naming,
dataclass_arguments=dataclass_arguments,
)

self.remote_object_cache: DefaultPutDict[str, dict[str, Any]] = DefaultPutDict()
Expand Down Expand Up @@ -702,10 +704,21 @@ def parse_one_of(self, name: str, obj: JsonSchemaObject, path: list[str]) -> lis
return self.parse_combined_schema(name, obj, path, "oneOf")

def _create_data_model(self, model_type: type[DataModel] | None = None, **kwargs: Any) -> DataModel:
"""Create data model instance with conditional frozen parameter for DataClass."""
data_model_class = model_type or self.data_model_type
if issubclass(data_model_class, DataClass):
kwargs["frozen"] = self.frozen_dataclasses
dataclass_arguments = {}
if hasattr(self, "frozen_dataclasses"):
dataclass_arguments["frozen"] = self.frozen_dataclasses
if hasattr(self, "keyword_only"):
dataclass_arguments["kw_only"] = self.keyword_only
existing = kwargs.pop("dataclass_arguments", None)
if existing:
dataclass_arguments.update(existing)
kwargs["dataclass_arguments"] = dataclass_arguments
kwargs.pop("frozen", None)
kwargs.pop("keyword_only", None)
else:
kwargs.pop("dataclass_arguments", None)
return data_model_class(**kwargs)

def _parse_object_common_part( # noqa: PLR0913, PLR0917
Expand Down Expand Up @@ -767,6 +780,7 @@ def _parse_object_common_part( # noqa: PLR0913, PLR0917
description=obj.description if self.use_schema_description else None,
keyword_only=self.keyword_only,
treat_dot_as_module=self.treat_dot_as_module,
dataclass_arguments=self.dataclass_arguments,
)
self.results.append(data_model_type)

Expand Down Expand Up @@ -1008,6 +1022,7 @@ def parse_object(
nullable=obj.type_has_null,
keyword_only=self.keyword_only,
treat_dot_as_module=self.treat_dot_as_module,
dataclass_arguments=self.dataclass_arguments,
)
self.results.append(data_model_type)
return self.data_type(reference=reference)
Expand Down
3 changes: 3 additions & 0 deletions src/datamodel_code_generator/parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def __init__( # noqa: PLR0913
no_alias: bool = False,
formatters: list[Formatter] = DEFAULT_FORMATTERS,
parent_scoped_naming: bool = False,
dataclass_arguments: dict[str, Any] | None = None,
) -> None:
super().__init__(
source=source,
Expand Down Expand Up @@ -297,6 +298,7 @@ def __init__( # noqa: PLR0913
no_alias=no_alias,
formatters=formatters,
parent_scoped_naming=parent_scoped_naming,
dataclass_arguments=dataclass_arguments,
)
self.open_api_scopes: list[OpenAPIScope] = openapi_scopes or [OpenAPIScope.Schemas]

Expand Down Expand Up @@ -498,6 +500,7 @@ def parse_all_parameters(
custom_template_dir=self.custom_template_dir,
keyword_only=self.keyword_only,
treat_dot_as_module=self.treat_dot_as_module,
dataclass_arguments=self.dataclass_arguments,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Optional


@dataclass(kw_only=True, frozen=True)
@dataclass(frozen=True, kw_only=True)
class User:
name: str
age: int
Expand Down
83 changes: 83 additions & 0 deletions tests/model/dataclass/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,86 @@ def test_dataclass_kw_only_true_only() -> None:
# Verify frozen attribute is False (default)
assert dataclass.frozen is False
assert dataclass.keyword_only is True


def test_dataclass_legacy_keyword_only() -> None:
"""Test that legacy 'frozen' argument is supported if dataclass_arguments is not set."""
reference = Reference(path="TestModel", name="TestModel")
field = DataModelField(
name="field1",
data_type=DataType(type=Types.string),
required=True,
)

dataclass = DataClass(
reference=reference,
fields=[field],
keyword_only=True,
)

rendered = dataclass.render()
assert "@dataclass(kw_only=True)" in rendered


def test_dataclass_legacy_frozen() -> None:
"""Test that legacy 'frozen' argument is supported if dataclass_arguments is not set."""
reference = Reference(path="TestModel", name="TestModel")
field = DataModelField(
name="field1",
data_type=DataType(type=Types.string),
required=True,
)

dataclass = DataClass(
reference=reference,
fields=[field],
frozen=True,
)

rendered = dataclass.render()
assert "@dataclass(frozen=True)" in rendered


def test_dataclass_with_custom_dataclass_arguments() -> None:
"""Test that custom dataclass_arguments are rendered correctly."""
reference = Reference(path="TestModel", name="TestModel")
field = DataModelField(
name="field1",
data_type=DataType(type=Types.string),
required=True,
)

dataclass = DataClass(
reference=reference,
fields=[field],
dataclass_arguments={"slots": True, "repr": False, "order": True},
)

rendered = dataclass.render()
assert "@dataclass(slots=True, order=True)" in rendered
assert "repr=False" not in rendered


def test_dataclass_both_legacy_and_dataclass_arguments() -> None:
"""Test that dataclass_arguments take precedence over legacy flags."""
reference = Reference(path="TestModel", name="TestModel")
field = DataModelField(
name="field1",
data_type=DataType(type=Types.string),
required=True,
)

dataclass = DataClass(
reference=reference,
fields=[field],
frozen=True, # legacy flag
keyword_only=True, # legacy flag
dataclass_arguments={"frozen": False, "order": True},
)

rendered = dataclass.render()
assert "@dataclass(order=True)" in rendered
assert "@dataclass(frozen=False)" not in rendered
assert "@dataclass(frozen=True)" not in rendered
assert "@dataclass(kw_only=False)" not in rendered
assert "@dataclass(kw_only=True)" not in rendered
Loading