diff --git a/pulsar-client-cpp/python/pulsar/schema/definition.py b/pulsar-client-cpp/python/pulsar/schema/definition.py index 41c094dcd215f..6db71d86fde9f 100644 --- a/pulsar-client-cpp/python/pulsar/schema/definition.py +++ b/pulsar-client-cpp/python/pulsar/schema/definition.py @@ -97,15 +97,23 @@ def __init__(self, default=None, required_default=False, required=False, *args, @classmethod def schema(cls): + return cls.schema_info(set()) + + @classmethod + def schema_info(cls, defined_names): + if cls.__name__ in defined_names: + return cls.__name__ + + defined_names.add(cls.__name__) schema = { 'name': str(cls.__name__), 'type': 'record', 'fields': [] } - for name in sorted(cls._fields.keys()): field = cls._fields[name] - field_type = field.schema() if field._required else ['null', field.schema()] + field_type = field.schema_info(defined_names) \ + if field._required else ['null', field.schema_info(defined_names)] schema['fields'].append({ 'name': name, 'type': field_type, @@ -198,6 +206,9 @@ def schema(self): # For primitive types, the schema would just be the type itself return self.type() + def schema_info(self, defined_names): + return self.type() + def default(self): return self._default @@ -347,6 +358,9 @@ def python_type(self): return self.enum_type def validate_type(self, name, val): + if val is None: + return None + if type(val) is str: # The enum was passed as a string, we need to check it against the possible values if val in self.enum_type.__members__: @@ -367,6 +381,12 @@ def validate_type(self, name, val): return val def schema(self): + return self.schema_info(set()) + + def schema_info(self, defined_names): + if self.enum_type.__name__ in defined_names: + return self.enum_type.__name__ + defined_names.add(self.enum_type.__name__) return { 'type': self.type(), 'name': self.enum_type.__name__, @@ -393,6 +413,9 @@ def python_type(self): return list def validate_type(self, name, val): + if val is None: + return None + super(Array, self).validate_type(name, val) for x in val: @@ -402,9 +425,12 @@ def validate_type(self, name, val): return val def schema(self): + return self.schema_info(set()) + + def schema_info(self, defined_names): return { 'type': self.type(), - 'items': self.array_type.schema() if isinstance(self.array_type, (Array, Map, Record)) + 'items': self.array_type.schema_info(defined_names) if isinstance(self.array_type, (Array, Map, Record)) else self.array_type.type() } @@ -428,6 +454,9 @@ def python_type(self): return dict def validate_type(self, name, val): + if val is None: + return None + super(Map, self).validate_type(name, val) for k, v in val.items(): @@ -440,9 +469,12 @@ def validate_type(self, name, val): return val def schema(self): + return self.schema_info(set()) + + def schema_info(self, defined_names): return { 'type': self.type(), - 'values': self.value_type.schema() if isinstance(self.value_type, (Array, Map, Record)) + 'values': self.value_type.schema_info(defined_names) if isinstance(self.value_type, (Array, Map, Record)) else self.value_type.type() } diff --git a/pulsar-client-cpp/python/schema_test.py b/pulsar-client-cpp/python/schema_test.py index 7ec0c9a0499bd..35d9316c983c8 100755 --- a/pulsar-client-cpp/python/schema_test.py +++ b/pulsar-client-cpp/python/schema_test.py @@ -19,6 +19,8 @@ # from unittest import TestCase, main + +import fastavro import pulsar from pulsar.schema import * from enum import Enum @@ -46,6 +48,7 @@ class Example(Record): h = Bytes() i = Map(String()) + fastavro.parse_schema(Example.schema()) self.assertEqual(Example.schema(), { "name": "Example", "type": "record", @@ -84,6 +87,7 @@ class Example(Record): sub = MySubRecord # Test with class sub2 = MySubRecord() # Test with instance + fastavro.parse_schema(Example.schema()) self.assertEqual(Example.schema(), { "name": "Example", "type": "record", @@ -99,13 +103,7 @@ class Example(Record): }] }, {"name": "sub2", - "type": ["null", { - "name": "MySubRecord", - "type": "record", - "fields": [{"name": "x", "type": ["null", "int"]}, - {"name": "y", "type": ["null", "long"]}, - {"name": "z", "type": ["null", "string"]}] - }] + "type": ["null", 'MySubRecord'] } ] }) @@ -896,12 +894,22 @@ class NestedObj4(Record): na4 = String() nb4 = Integer() + class Color(Enum): + red = 1 + green = 2 + blue = 3 + class ComplexRecord(Record): a = Integer() b = Integer() + color = Color + color2 = Color nested = NestedObj2() + nested2 = NestedObj2() mapNested = Map(NestedObj3()) + mapNested2 = Map(NestedObj3()) arrayNested = Array(NestedObj4()) + arrayNested2 = Array(NestedObj4()) print('complex schema: ', ComplexRecord.schema()) self.assertEqual(ComplexRecord.schema(), { @@ -909,18 +917,23 @@ class ComplexRecord(Record): "type": "record", "fields": [ {"name": "a", "type": ["null", "int"]}, - {'name': 'arrayNested', 'type': ['null', - {'type': 'array', 'items': {'name': 'NestedObj4', 'type': 'record', 'fields': [ + {'name': 'arrayNested', 'type': ['null', {'type': 'array', 'items': + {'name': 'NestedObj4', 'type': 'record', 'fields': [ {'name': 'na4', 'type': ['null', 'string']}, {'name': 'nb4', 'type': ['null', 'int']} ]}} ]}, + {'name': 'arrayNested2', 'type': ['null', {'type': 'array', 'items': 'NestedObj4'}]}, {"name": "b", "type": ["null", "int"]}, + {'name': 'color', 'type': ['null', {'type': 'enum', 'name': 'Color', 'symbols': [ + 'red', 'green', 'blue']}]}, + {'name': 'color2', 'type': ['null', 'Color']}, {'name': 'mapNested', 'type': ['null', {'type': 'map', 'values': {'name': 'NestedObj3', 'type': 'record', 'fields': [ {'name': 'na3', 'type': ['null', 'int']} ]}} ]}, + {'name': 'mapNested2', 'type': ['null', {'type': 'map', 'values': 'NestedObj3'}]}, {"name": "nested", "type": ['null', {'name': 'NestedObj2', 'type': 'record', 'fields': [ {'name': 'na2', 'type': ['null', 'int']}, {'name': 'nb2', 'type': ['null', 'boolean']}, @@ -928,7 +941,8 @@ class ComplexRecord(Record): {'name': 'na1', 'type': ['null', 'string']}, {'name': 'nb1', 'type': ['null', 'double']} ]}]} - ]}]} + ]}]}, + {"name": "nested2", "type": ['null', 'NestedObj2']} ] }) @@ -939,13 +953,22 @@ def encode_and_decode(schema_type): nested_obj1 = NestedObj1(na1='na1 value', nb1=20.5) nested_obj2 = NestedObj2(na2=22, nb2=True, nc2=nested_obj1) - r = ComplexRecord(a=1, b=2, nested=nested_obj2, mapNested={ + r = ComplexRecord(a=1, b=2, color=Color.red, color2=Color.blue, + nested=nested_obj2, nested2=nested_obj2, + mapNested={ 'a': NestedObj3(na3=1), 'b': NestedObj3(na3=2), 'c': NestedObj3(na3=3) + }, mapNested2={ + 'd': NestedObj3(na3=4), + 'e': NestedObj3(na3=5), + 'f': NestedObj3(na3=6) }, arrayNested=[ NestedObj4(na4='value na4 1', nb4=100), NestedObj4(na4='value na4 2', nb4=200) + ], arrayNested2=[ + NestedObj4(na4='value na4 3', nb4=300), + NestedObj4(na4='value na4 4', nb4=400) ]) data_encode = data_schema.encode(r) @@ -954,17 +977,30 @@ def encode_and_decode(schema_type): self.assertEqual(data_decode, r) self.assertEqual(data_decode.a, 1) self.assertEqual(data_decode.b, 2) + self.assertEqual(data_decode.color, Color.red) + self.assertEqual(data_decode.color2, Color.blue) self.assertEqual(data_decode.nested.na2, 22) self.assertEqual(data_decode.nested.nb2, True) self.assertEqual(data_decode.nested.nc2.na1, 'na1 value') self.assertEqual(data_decode.nested.nc2.nb1, 20.5) + self.assertEqual(data_decode.nested2.na2, 22) + self.assertEqual(data_decode.nested2.nb2, True) + self.assertEqual(data_decode.nested2.nc2.na1, 'na1 value') + self.assertEqual(data_decode.nested2.nc2.nb1, 20.5) self.assertEqual(data_decode.mapNested['a'].na3, 1) self.assertEqual(data_decode.mapNested['b'].na3, 2) self.assertEqual(data_decode.mapNested['c'].na3, 3) + self.assertEqual(data_decode.mapNested2['d'].na3, 4) + self.assertEqual(data_decode.mapNested2['e'].na3, 5) + self.assertEqual(data_decode.mapNested2['f'].na3, 6) self.assertEqual(data_decode.arrayNested[0].na4, 'value na4 1') self.assertEqual(data_decode.arrayNested[0].nb4, 100) self.assertEqual(data_decode.arrayNested[1].na4, 'value na4 2') self.assertEqual(data_decode.arrayNested[1].nb4, 200) + self.assertEqual(data_decode.arrayNested2[0].na4, 'value na4 3') + self.assertEqual(data_decode.arrayNested2[0].nb4, 300) + self.assertEqual(data_decode.arrayNested2[1].na4, 'value na4 4') + self.assertEqual(data_decode.arrayNested2[1].nb4, 400) print('Encode and decode complex schema finish. schema_type: ', schema_type) encode_and_decode('avro')