From 2c1652b7ff8f40d7faa23542071a42718d2c28ad Mon Sep 17 00:00:00 2001 From: puyj Date: Tue, 21 Jun 2022 16:03:10 +0800 Subject: [PATCH 1/6] [Python] Implement validate_all function Signed-off-by: puyj --- python/README.md | 4 +- python/protoc_gen_validate/validator.py | 57 +++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/python/README.md b/python/README.md index 017e68eca..e90c1d79c 100644 --- a/python/README.md +++ b/python/README.md @@ -8,11 +8,13 @@ in `validate.proto`. Implemented Python annotations are listed in the [rules com ### Example ```python3 from entities_pb2 import Person -from protoc_gen_validate.validator import validate, ValidationFailed +from protoc_gen_validate.validator import validate, ValidationFailed, validate_all p = Person(first_name="Foo", last_name="Bar", age=42) try: validate(p) + # you can also validate all rules + # validate_all(p) except ValidationFailed as err: print(err) ``` diff --git a/python/protoc_gen_validate/validator.py b/python/protoc_gen_validate/validator.py index d7be0b3d1..b2f5033eb 100644 --- a/python/protoc_gen_validate/validator.py +++ b/python/protoc_gen_validate/validator.py @@ -1,3 +1,4 @@ +import ast import re import struct import sys @@ -63,6 +64,62 @@ def _validate_inner(proto_message: Message): return locals()['generate_validate'] +class _Transformer(ast.NodeTransformer): + """ + Consider generated functions has the following structure: + + ``` + def generate_validate(p): + ... + if rules_stmt: + raise ValidationFailed(msg) + ... + return None + ``` + + Transformer made the following three changes: + + 1. Define a variable `err` that records all ValidationFailed error messages. + 2. Convert all `raise ValidationFailed(error_message)` to `err += error_message`. + 3. When `err` is not an empty string, `raise ValidationFailed(err)`. 
+ """ + + def visit_FunctionDef(self, node: ast.FunctionDef): + self.generic_visit(node) + # add a suffix to the function name + node.name = node.name + "_all" + node.body.insert(0, ast.parse("err = ''").body[0]) + return node + + def visit_Raise(self, node: ast.Raise): + exc_str = " ".join(str(_.value) for _ in node.exc.args) + return ast.parse(rf'err += "\n{exc_str}"').body[0] + + def visit_Return(self, node: ast.Return): + return ast.parse("if err:\n raise ValidationFailed(err)").body[0] + + +# Cache generated functions with the message descriptor's full_name as the cache key +@lru_cache() +def _validate_all_inner(proto_message: Message): + func = file_template(ValidatingMessage(proto_message)) + func_ast = ast.parse(func) + func_ast = _Transformer().visit(func_ast) + func_ast = ast.fix_missing_locations(func_ast) + func = ast.unparse(func_ast) + global printer + printer += func + "\n" + exec(func) + try: + return generate_validate_all + except NameError: + return locals()['generate_validate_all'] + + +def validate_all(proto_message: Message): + return _validate_all_inner(ValidatingMessage(proto_message))(proto_message) + + def print_validate(): return "".join([s for s in printer.splitlines(True) if s.strip()]) From 4ac34bf571541aac9be4ad646ce4cbf8c0746a76 Mon Sep 17 00:00:00 2001 From: puyj Date: Tue, 21 Jun 2022 18:32:12 +0800 Subject: [PATCH 2/6] add _validate_all for embedded messages Signed-off-by: puyj --- python/protoc_gen_validate/validator.py | 51 ++++++++++++++++++++----- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/python/protoc_gen_validate/validator.py b/python/protoc_gen_validate/validator.py index b2f5033eb..8130ca5d1 100644 --- a/python/protoc_gen_validate/validator.py +++ b/python/protoc_gen_validate/validator.py @@ -66,22 +66,27 @@ def _validate_inner(proto_message: Message): class _Transformer(ast.NodeTransformer): """ - Consider generated functions has the following structure: + Consider the generator function has the 
following structure: - ``` + # Validates Message def generate_validate(p): ... if rules_stmt: raise ValidationFailed(msg) ... + if _has_field(p, "field_name"): # embedded message + embedded = validate(p.field_name) + if embedded is not None: + return embedded + ... return None - ``` - Transformer made the following three changes: + _Transformer will apply the following four changes to the original AST: - 1. Define a variable `err` that records all ValidationFailed error messages. - 2. Convert all `raise ValidationFailed(error_message)` to `err += error_message`. - 3. When `err` is not an empty string, `raise ValidationFailed(err)`. + 1. Define a variable `err` that records all `ValidationFailed` + 2. Convert all `raise ValidationFailed(error_message)` to `err += error_message` + 3. Convert `validate` to `_validate_all` + 4. return `err` . """ def visit_FunctionDef(self, node: ast.FunctionDef): @@ -95,18 +100,39 @@ def visit_Raise(self, node: ast.Raise): exc_str = " ".join(str(_.value) for _ in node.exc.args) return ast.parse(rf'err += "\n{exc_str}"').body[0] + def visit_If(self, node: ast.If): + self.generic_visit(node) + # if _has_field(p, field_name): + if isinstance(node.test, ast.Call) and getattr(node.test.func, "id", None) == "_has_field": + assign_node, if_node = node.body + new_assign_node = ast.AugAssign( + target=ast.Name(id="err", ctx=ast.Store()), + op=ast.Add(), + value=ast.Call( + func=ast.Name(id="_validate_all", ctx=ast.Load()), + args=assign_node.value.args, + keywords=[] + ) + ) + node.body = [new_assign_node] + return node + def visit_Return(self, node: ast.Return): - return ast.parse("if err:\n raise ValidationFailed(err)").body[0] + if hasattr(node.value, "value") and getattr(node.value, "value") is None: + return ast.parse("return err").body[0] + return node # Cache generated functions with the message descriptor's full_name as the cache key @lru_cache() def _validate_all_inner(proto_message: Message): func = 
file_template(ValidatingMessage(proto_message)) + comment = func.split("\n")[1] func_ast = ast.parse(func) func_ast = _Transformer().visit(func_ast) func_ast = ast.fix_missing_locations(func_ast) func = ast.unparse(func_ast) + func = comment + " All" + "\n" + func global printer printer += func + "\n" exec(func) @@ -116,10 +142,17 @@ def _validate_all_inner(proto_message: Message): return locals()['generate_validate_all'] -def validate_all(proto_message: Message): +def _validate_all(proto_message: Message) -> str: return _validate_all_inner(ValidatingMessage(proto_message))(proto_message) +# raise ValidationFailed if err +def validate_all(proto_message: Message): + err = _validate_all(proto_message) + if err: + raise ValidationFailed(err) + + def print_validate(): return "".join([s for s in printer.splitlines(True) if s.strip()]) From 7827b1d641c57ab154644d2fe63abcbf47090996 Mon Sep 17 00:00:00 2001 From: puyj Date: Thu, 6 Oct 2022 11:35:03 +0800 Subject: [PATCH 3/6] Refactoring _Transformer * Refactor _Transformer to make the code structure more intuitive Split `_Transformer` into 7 sub-Transformers (in execution order). When the python version is lower than 3.9, use the `unparse` function of the third-party library `astunparse` to replace the `unparse` function of the standard library `ast` --- python/protoc_gen_validate/validator.py | 182 +++++++++++++++++------- python/requirements.in | 1 + python/setup.cfg | 1 + 3 files changed, 133 insertions(+), 51 deletions(-) diff --git a/python/protoc_gen_validate/validator.py b/python/protoc_gen_validate/validator.py index 8130ca5d1..3458fc6b0 100644 --- a/python/protoc_gen_validate/validator.py +++ b/python/protoc_gen_validate/validator.py @@ -12,6 +12,13 @@ from jinja2 import Template from validate_email import validate_email + +if sys.version_info > (3, 9): + unparse = ast.unparse +else: + import astunparse + unparse = astunparse.unparse + printer = "" # Well known regex mapping. 
@@ -64,74 +71,147 @@ def _validate_inner(proto_message: Message): return locals()['generate_validate'] -class _Transformer(ast.NodeTransformer): - """ - Consider the generator function has the following structure: - - # Validates Message - def generate_validate(p): - ... - if rules_stmt: - raise ValidationFailed(msg) - ... - if _has_field(p, "field_name"): # embedded message - embedded = validate(p.field_name) - if embedded is not None: - return embedded - ... - return None - - _Transformer will apply the following four changes to the original AST: - - 1. Define a variable `err` that records all `ValidationFailed` - 2. Convert all `raise ValidationFailed(error_message)` to `err += error_message` - 3. Convert `validate` to `_validate_all` - 4. return `err` . - """ - +class ChangeFuncName(ast.NodeTransformer): def visit_FunctionDef(self, node: ast.FunctionDef): - self.generic_visit(node) - # add a suffix to the function name - node.name = node.name + "_all" - node.body.insert(0, ast.parse("err = ''").body[0]) + node.name = node.name + "_all" # add a suffix to the function name return node - def visit_Raise(self, node: ast.Raise): - exc_str = " ".join(str(_.value) for _ in node.exc.args) - return ast.parse(rf'err += "\n{exc_str}"').body[0] - def visit_If(self, node: ast.If): - self.generic_visit(node) - # if _has_field(p, field_name): - if isinstance(node.test, ast.Call) and getattr(node.test.func, "id", None) == "_has_field": - assign_node, if_node = node.body - new_assign_node = ast.AugAssign( - target=ast.Name(id="err", ctx=ast.Store()), - op=ast.Add(), - value=ast.Call( - func=ast.Name(id="_validate_all", ctx=ast.Load()), - args=assign_node.value.args, - keywords=[] - ) - ) - node.body = [new_assign_node] +class InitErr(ast.NodeTransformer): + def visit_FunctionDef(self, node: ast.FunctionDef): + node.body.insert(0, ast.parse("err = []").body[0]) return node + +class ReturnErr(ast.NodeTransformer): def visit_Return(self, node: ast.Return): + # Change the return 
value of the function from None to err if hasattr(node.value, "value") and getattr(node.value, "value") is None: return ast.parse("return err").body[0] return node +class ChangeInnerCall(ast.NodeTransformer): + def visit_Call(self, node: ast.Call): + """Changed the validation function of nested messages from `validate` to + `_validate_all`""" + if isinstance(node.func, ast.Name) and node.func.id == "validate": + node.func.id = "_validate_all" + return node + + +class ChangeRaise(ast.NodeTransformer): + def visit_Raise(self, node: ast.Raise): + """ + before: + raise ValidationFailed(reason) + after: + err.append(reason) + """ + # According to the content in the template, the exception object of all `raise` + # statements is `ValidationFailed`. + if not isinstance(node.exc, ast.Call): + return node + return ast.Expr( + value=ast.Call( + args=node.exc.args, + keywords=node.exc.keywords, + func=ast.Attribute( + attr="append", ctx=ast.Load(), value=ast.Name(id="err", ctx=ast.Load()) + ), + ) + ) + + +class ChangeEmbedded(ast.NodeTransformer): + """For embedded messages, there is a special structure in the template as follows: + + if _has_field(p, \"{{ name.split('.')[-1] }}\"): + embedded = validate(p.{{ name }}) + if embedded is not None: + return embedded + + We need to convert this code into the following form: + + if _has_field(p, \"{{ name.split('.')[-1] }}\"): + err += _validate_all(p.{{ name }}) + + """ + @staticmethod + def _is_embedded_node(node: ast.Assign): + """Check if substructures match + + pattern: + embedded = validate(p.{{ name }}) + """ + if not isinstance(node, ast.Assign): + return False + if len(node.targets) != 1: + return False + target = node.targets[0] + value = node.value + if not (isinstance(target, ast.Name) and isinstance(value, ast.Call)): + return False + if not target.id == "embedded": + return False + return True + + def visit_If(self, node: ast.If): + self.generic_visit(node) + for child in ast.iter_child_nodes(node): + if 
self._is_embedded_node(child): + new_node = ast.AugAssign( + target=ast.Name(id="err", ctx=ast.Store()), op=ast.Add(), value=child.value + ) # err += _validate_all(p.{{ name }}) + node.body = [new_node] + return node + return node + + +class ChangeExpr(ast.NodeTransformer): + + """If there is a pure `_validate_all` function call in the template function, + its return value needs to be recorded in err + + before: + _validate_all(item) + + after: + err += _validate_all(item) + + """ + + def visit_Expr(self, node: ast.Expr): + if not isinstance(node.value, ast.Call): + return node + call_node = node.value + if not isinstance(call_node.func, ast.Name): + return node + if not call_node.func.id == "_validate_all": + return node + return ast.AugAssign( + target=ast.Name(id="err", ctx=ast.Store()), op=ast.Add(), value=call_node + ) # err += _validate_all(item) + + # Cache generated functions with the message descriptor's full_name as the cache key @lru_cache() def _validate_all_inner(proto_message: Message): func = file_template(ValidatingMessage(proto_message)) comment = func.split("\n")[1] - func_ast = ast.parse(func) - func_ast = _Transformer().visit(func_ast) + func_ast = ast.parse(rf"{func}") + for transformer in [ + ChangeFuncName, + InitErr, + ReturnErr, + ChangeInnerCall, + ChangeRaise, + ChangeEmbedded, + ChangeExpr, + ]: # order is important! 
+ func_ast = ast.fix_missing_locations(transformer().visit(func_ast)) func_ast = ast.fix_missing_locations(func_ast) - func = ast.unparse(func_ast) + func = unparse(func_ast) func = comment + " All" + "\n" + func global printer printer += func + "\n" @@ -150,7 +230,7 @@ def _validate_all(proto_message: Message) -> str: def validate_all(proto_message: Message): err = _validate_all(proto_message) if err: - raise ValidationFailed(err) + raise ValidationFailed('\n'.join(err)) def print_validate(): diff --git a/python/requirements.in b/python/requirements.in index 4fe65f8b0..f0fc9389e 100644 --- a/python/requirements.in +++ b/python/requirements.in @@ -4,3 +4,4 @@ validate-email>=1.3 Jinja2>=2.11.1 protobuf>=3.6.1 +astunparse>=1.6.3 diff --git a/python/setup.cfg b/python/setup.cfg index 2fef6dd1e..de52c2dd1 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -21,6 +21,7 @@ install_requires = validate-email>=1.3 Jinja2>=2.11.1 protobuf>=3.6.1 + astunparse>=1.6.3 python_requires = >=3.6 [options.data_files] From a19b16914511c8251ca79d216f14c64b070daf76 Mon Sep 17 00:00:00 2001 From: puyj Date: Fri, 7 Oct 2022 10:15:39 +0800 Subject: [PATCH 4/6] Add test for validate_all to Python harness * Add test for validate_all to Python harness The result of `validate` is guaranteed to be the same as the result of `validate_all`. When validation fails, the error message of `validate` is guaranteed to be the same as the first error message of `validate_all`. 
--- tests/harness/python/harness.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/harness/python/harness.py b/tests/harness/python/harness.py index 8ca0254f2..ae35598f3 100644 --- a/tests/harness/python/harness.py +++ b/tests/harness/python/harness.py @@ -1,7 +1,7 @@ import sys import inspect -from python.protoc_gen_validate.validator import validate, ValidationFailed +from python.protoc_gen_validate.validator import validate, validate_all, ValidationFailed from tests.harness.harness_pb2 import TestCase, TestResult from tests.harness.cases.bool_pb2 import * @@ -21,7 +21,6 @@ from tests.harness.cases.wkt_timestamp_pb2 import * from tests.harness.cases.kitchen_sink_pb2 import * - message_classes = {} for k, v in inspect.getmembers(sys.modules[__name__], inspect.isclass): if 'DESCRIPTOR' in dir(v): @@ -46,5 +45,25 @@ result.Valid = False result.Reasons[:] = [repr(e)] + try: + result_all = TestResult() + valid = validate_all(test_msg) + result_all.Valid = True + except ValidationFailed as e: + result_all.Valid = False + result_all.Reasons[:] = [repr(e)] + + if result.Valid != result_all.Valid: + raise ValueError(f"validation results mismatch, validate: {result.Valid}, " + f"validate_all: {result_all.Valid}") + if not result.Valid: + reason = list(result.Reasons)[0] # ValidationFailed("reason") + reason = reason[18:-2] # reason + reason_all = list(result_all.Reasons)[0] # ValidationFailed("reason1\nreason2\n...reason") + reason_all = reason_all[18:-2] # reason1\nreason2\n...reason + if not reason_all.startswith(reason): + raise ValueError(f"different first message, validate: {reason}, " + f"validate_all: {reason_all}") + sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf8') sys.stdout.write(result.SerializeToString().decode("utf-8")) From 0984e7e2255786d64622a441adc179e1c5e97d6c Mon Sep 17 00:00:00 2001 From: puyj Date: Fri, 7 Oct 2022 10:42:53 +0800 Subject: [PATCH 5/6] Update example --- python/README.md | 
15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/README.md b/python/README.md index e90c1d79c..deb2b833f 100644 --- a/python/README.md +++ b/python/README.md @@ -10,13 +10,20 @@ in `validate.proto`. Implemented Python annotations are listed in the [rules com from entities_pb2 import Person from protoc_gen_validate.validator import validate, ValidationFailed, validate_all -p = Person(first_name="Foo", last_name="Bar", age=42) +p = Person(name="Foo") try: validate(p) - # you can also validate all rules - # validate_all(p) except ValidationFailed as err: - print(err) + print(err) # p.id is not greater than 999 + +try: + validate_all(p) +except ValidationFailed as err: + print(err) + # p.id is not greater than 999 + # p.email is not a valid email + # p.name pattern does not match ^[^[0-9]A-Za-z]+( [^[0-9]A-Za-z]+)*$ + # home is required. ``` [pgv-home]: https://github.com/envoyproxy/protoc-gen-validate From ef4f7b2df19820241ee29074f85817f33c7c7bd9 Mon Sep 17 00:00:00 2001 From: puyj Date: Wed, 19 Oct 2022 15:01:49 +0800 Subject: [PATCH 6/6] fix lint --- python/protoc_gen_validate/validator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/protoc_gen_validate/validator.py b/python/protoc_gen_validate/validator.py index 3458fc6b0..a6874ab72 100644 --- a/python/protoc_gen_validate/validator.py +++ b/python/protoc_gen_validate/validator.py @@ -12,7 +12,6 @@ from jinja2 import Template from validate_email import validate_email - if sys.version_info > (3, 9): unparse = ast.unparse else: