Skip to content

Commit

Permalink
feat: TupleField added in order to deserialize typing.Tuple as tuples…
Browse files Browse the repository at this point in the history
… rather than list. Closes marcosschroh#291 (marcosschroh#296)

Co-authored-by: Marcos Schroh <[email protected]>
  • Loading branch information
marcosschroh and marcosschroh authored Apr 21, 2023
1 parent e368977 commit 7b04717
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 24 deletions.
28 changes: 21 additions & 7 deletions dataclasses_avroschema/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 16,9 @@
import inflect
from faker import Faker
from pytz import utc
from typing_extensions import get_args
from typing_extensions import get_args, get_origin

from dataclasses_avroschema import schema_generator, serialization, types, utils

from . import field_utils
from . import field_utils, schema_generator, serialization, types, utils
from .exceptions import InvalidMap
from .types import CUSTOM_TYPES, JsonDict

Expand Down Expand Up @@ -255,7 253,7 @@ def get_avro_type(self) -> field_utils.PythonImmutableTypes:


@dataclasses.dataclass
class ListField(ContainerField):
class BaseListField(ContainerField):
items_type: typing.Any = None
internal_field: typing.Any = None

Expand Down Expand Up @@ -304,10 302,26 @@ def generate_items_type(self) -> typing.Any:

self.items_type = self.internal_field.get_avro_type()


@dataclasses.dataclass
class ListField(BaseListField):
def fake(self) -> typing.List:
return [self.internal_field.fake()]


@dataclasses.dataclass
class TupleField(BaseListField):
"""
This behaves on the same way as `ListField` with
as in avro schema does not exist `tuples`
The reason to have this is to generate a proper `fake`
"""

def fake(self) -> typing.Tuple:
return (self.internal_field.fake(),)


@dataclasses.dataclass
class DictField(ContainerField):
values_type: typing.Any = None
Expand Down Expand Up @@ -832,7 846,7 @@ def fake(self) -> decimal.Decimal:
}

CONTAINER_FIELDS_CLASSES = {
tuple: ListField,
tuple: TupleField,
list: ListField,
collections.abc.Sequence: ListField,
collections.abc.MutableSequence: ListField,
Expand Down Expand Up @@ -958,7 972,7 @@ def field_factory(
parent=parent,
)
elif isinstance(native_type, GenericAlias): # type: ignore
origin = native_type.__origin__
origin = get_origin(native_type)

if origin not in (
tuple,
Expand Down
3 changes: 3 additions & 0 deletions dataclasses_avroschema/schema_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 60,12 @@ def generate_documentation(self) -> typing.Optional[str]:
@dataclasses.dataclass
class AvroSchemaDefinition(BaseSchemaDefinition):
fields: typing.List[FieldType] = dataclasses.field(default_factory=list)
# mapping of field_name: FieldType
fields_map: typing.Dict[str, FieldType] = dataclasses.field(default_factory=dict)

def __post_init__(self) -> None:
self.fields = self.parse_dataclasses_fields()
self.fields_map = {field.name: field for field in self.fields}

def parse_dataclasses_fields(self) -> typing.List[FieldType]:
if utils.is_faust_model(self.klass):
Expand Down
42 changes: 30 additions & 12 deletions dataclasses_avroschema/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 8,8 @@
from dacite import Config, from_dict
from fastavro.validation import validate

from . import case
from .fields import EnumField, FieldType, RecordField, UnionField
from . import case, fields, serialization
from .schema_definition import AvroSchemaDefinition
from .serialization import deserialize, serialize, to_json
from .types import Fixed, JsonDict
from .utils import SchemaMetadata, is_dataclass_or_pydantic_model

Expand Down Expand Up @@ -97,7 95,7 @@ def avro_schema_to_python(
return json.loads(json.dumps(avro_schema))

@classmethod
def get_fields(cls: Type[CT]) -> List[FieldType]:
def get_fields(cls: Type[CT]) -> List[fields.FieldType]:
if cls.schema_def is None:
cls.generate_schema()
return cls.schema_def.fields # type: ignore
Expand All @@ -106,23 104,41 @@ def get_fields(cls: Type[CT]) -> List[FieldType]:
def _get_enum_type_map(cls: Type[CT]) -> Dict[str, enum.EnumMeta]:
enum_types = {}
for field_type in cls.get_fields():
if isinstance(field_type, EnumField):
if isinstance(field_type, fields.EnumField):
enum_types[field_type.name] = field_type.type
elif isinstance(field_type, UnionField):
elif isinstance(field_type, fields.UnionField):
for sub_type in field_type.type.__args__:
if inspect.isclass(sub_type) and issubclass(sub_type, enum.Enum):
enum_types[field_type.name] = sub_type
elif isinstance(field_type, RecordField):
elif isinstance(field_type, fields.RecordField):
enum_types.update(field_type.type._get_enum_type_map())
return enum_types

@classmethod
def _deserialize_list_type(cls: Type[CT], *, field_name: str, payload: Any) -> List:
data: List = []
for value in payload:
if isinstance(value, dict):
data.append(cls._deserialize_complex_types(value))
elif isinstance(value, list):
data.append(cls._deserialize_list_type(field_name=field_name, payload=value))
else:
data.append(value)
return data

@classmethod
def _deserialize_complex_types(cls: Type[CT], payload: Dict[str, Any]) -> Dict:
output = {}
output: Dict[str, Any] = {}
enum_type_map = cls._get_enum_type_map()
for field, value in payload.items():
if isinstance(value, dict):
output[field] = cls._deserialize_complex_types(value)
elif isinstance(value, list):
data = cls._deserialize_list_type(field_name=field, payload=value)
if isinstance(cls.schema_def.fields_map[field], fields.TupleField): # type: ignore
output[field] = tuple(data)
else:
output[field] = data
elif field in enum_type_map and isinstance(value, str):
try:
enum_field = enum_type_map[field]
Expand All @@ -148,8 164,10 @@ def standardize_custom_type(value: Any) -> Any:
return value.default
elif isinstance(value, dict):
return {k: AvroModel.standardize_custom_type(v) for k, v in value.items()}
elif isinstance(value, (list, tuple)):
elif isinstance(value, list):
return [AvroModel.standardize_custom_type(v) for v in value]
elif isinstance(value, tuple):
return tuple(AvroModel.standardize_custom_type(v) for v in value)
elif issubclass(type(value), enum.Enum):
return value.value
return value
Expand All @@ -162,7 180,7 @@ def asdict(self) -> JsonDict:
def serialize(self, serialization_type: str = AVRO) -> bytes:
schema = self.avro_schema_to_python()

return serialize(self.asdict(), schema, serialization_type=serialization_type)
return serialization.serialize(self.asdict(), schema, serialization_type=serialization_type)

@classmethod
def deserialize(
Expand All @@ -177,7 195,7 @@ def deserialize(
writer_schema: JsonDict = writer_schema.avro_schema_to_python() # type: ignore

schema = cls.avro_schema_to_python()
payload = deserialize(
payload = serialization.deserialize(
data, schema, serialization_type=serialization_type, writer_schema=writer_schema # type: ignore
)
output = cls._deserialize_complex_types(payload)
Expand All @@ -200,7 218,7 @@ def to_dict(self) -> JsonDict:
return self.asdict()

def to_json(self) -> str:
data = to_json(self.asdict())
data = serialization.to_json(self.asdict())
return json.dumps(data)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions dataclasses_avroschema/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 138,8 @@ def serialize_value(*, value: typing.Any) -> typing.Any:
value = str(value)
elif isinstance(value, dict):
value = to_json(value)
elif isinstance(value, list):
value = [serialize_value(value=item) for item in value]
elif isinstance(value, (list, tuple)):
value = type(value)(serialize_value(value=item) for item in value)

return value

Expand Down
21 changes: 20 additions & 1 deletion tests/fake/test_fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 113,27 @@ class User(AvroModel):
age: int
addresses: typing.List[Address]

assert isinstance(User.fake(), User)
user = User.fake()
assert isinstance(user, User)
assert User.avro_schema()


def test_fake_one_to_many_with_tuples() -> None:
"""
Test schema relationship one-to-many
"""

class Address(AvroModel):
street: str
street_number: int

class User(AvroModel):
addresses: typing.Tuple[Address, ...]

user = User.fake()
assert isinstance(user, User)
assert User.avro_schema()
assert isinstance(user.addresses, tuple)


def test_fake_one_to_many_map_relationship() -> None:
Expand Down
14 changes: 12 additions & 2 deletions tests/serialization/test_nested_schema_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,26 72,36 @@ class User(AvroModel):
name: str
age: int
addresses: typing.List[Address]
addresses_as_tuple: typing.Tuple[Address, ...]

created_at = datetime.datetime(2019, 10, 12, 17, 57, 42, tzinfo=datetime.timezone.utc)
address_data = {"street": "test", "street_number": 10, "created_at": created_at}

address = Address(**address_data)
second_address = Address(**address_data)

data_user = {
"name": "john",
"age": 20,
"addresses": [address],
"addresses_as_tuple": (
address,
second_address,
),
}

user = User(**data_user)

avro_binary = b"\x08john(\x02\x08test\x14\xe0\xd7\xf3\x91\xb8[\x00"
avro_json_binary = b'{"name": "john", "age": 20, "addresses": [{"street": "test", "street_number": 10, "created_at": 1570903062000}]}' # noqa
avro_binary = b"\x08john(\x02\x08test\x14\xe0\xd7\xf3\x91\xb8[\x00\x04\x08test\x14\xe0\xd7\xf3\x91\xb8[\x08test\x14\xe0\xd7\xf3\x91\xb8[\x00" # noqa
avro_json_binary = b'{"name": "john", "age": 20, "addresses": [{"street": "test", "street_number": 10, "created_at": 1570903062000}], "addresses_as_tuple": [{"street": "test", "street_number": 10, "created_at": 1570903062000}, {"street": "test", "street_number": 10, "created_at": 1570903062000}]}' # noqa
expected = {
"name": "john",
"age": 20,
"addresses": [{"street": "test", "street_number": 10, "created_at": created_at}],
"addresses_as_tuple": (
{"street": "test", "street_number": 10, "created_at": created_at},
{"street": "test", "street_number": 10, "created_at": created_at},
),
}

assert user.serialize() == avro_binary
Expand Down

0 comments on commit 7b04717

Please sign in to comment.