diff --git a/python/README.md b/python/README.md index 017e68eca..deb2b833f 100644 --- a/python/README.md +++ b/python/README.md @@ -8,13 +8,22 @@ in `validate.proto`. Implemented Python annotations are listed in the [rules com ### Example ```python3 from entities_pb2 import Person -from protoc_gen_validate.validator import validate, ValidationFailed +from protoc_gen_validate.validator import validate, ValidationFailed, validate_all -p = Person(first_name="Foo", last_name="Bar", age=42) +p = Person(name="Foo") try: validate(p) except ValidationFailed as err: - print(err) + print(err) # p.id is not greater than 999 + +try: + validate_all(p) +except ValidationFailed as err: + print(err) + # p.id is not greater than 999 + # p.email is not a valid email + # p.name pattern does not match ^[^[0-9]A-Za-z]+( [^[0-9]A-Za-z]+)*$ + # home is required. ``` [pgv-home]: https://github.com/envoyproxy/protoc-gen-validate diff --git a/python/protoc_gen_validate/validator.py b/python/protoc_gen_validate/validator.py index d7be0b3d1..a6874ab72 100644 --- a/python/protoc_gen_validate/validator.py +++ b/python/protoc_gen_validate/validator.py @@ -1,3 +1,4 @@ +import ast import re import struct import sys @@ -11,6 +12,12 @@ from jinja2 import Template from validate_email import validate_email +if sys.version_info > (3, 9): + unparse = ast.unparse +else: + import astunparse + unparse = astunparse.unparse + printer = "" # Well known regex mapping. @@ -63,6 +70,168 @@ def _validate_inner(proto_message: Message): return locals()['generate_validate'] +class ChangeFuncName(ast.NodeTransformer): + def visit_FunctionDef(self, node: ast.FunctionDef): + node.name = node.name + "_all" # add a suffix to the function name + return node + + +class InitErr(ast.NodeTransformer): + def visit_FunctionDef(self, node: ast.FunctionDef): + node.body.insert(0, ast.parse("err = []").body[0]) + return node + + +class ReturnErr(ast.NodeTransformer): + def visit_Return(self, node: ast.Return): + # Change the return value of the function from None to err + if hasattr(node.value, "value") and getattr(node.value, "value") is None: + return ast.parse("return err").body[0] + return node + + +class ChangeInnerCall(ast.NodeTransformer): + def visit_Call(self, node: ast.Call): + """Changed the validation function of nested messages from `validate` to + `_validate_all`""" + if isinstance(node.func, ast.Name) and node.func.id == "validate": + node.func.id = "_validate_all" + return node + + +class ChangeRaise(ast.NodeTransformer): + def visit_Raise(self, node: ast.Raise): + """ + before: + raise ValidationFailed(reason) + after: + err.append(reason) + """ + # According to the content in the template, the exception object of all `raise` + # statements is `ValidationFailed`. + if not isinstance(node.exc, ast.Call): + return node + return ast.Expr( + value=ast.Call( + args=node.exc.args, + keywords=node.exc.keywords, + func=ast.Attribute( + attr="append", ctx=ast.Load(), value=ast.Name(id="err", ctx=ast.Load()) + ), + ) + ) + + +class ChangeEmbedded(ast.NodeTransformer): + """For embedded messages, there is a special structure in the template as follows: + + if _has_field(p, \"{{ name.split('.')[-1] }}\"): + embedded = validate(p.{{ name }}) + if embedded is not None: + return embedded + + We need to convert this code into the following form: + + if _has_field(p, \"{{ name.split('.')[-1] }}\"): + err += _validate_all(p.{{ name }} + + """ + @staticmethod + def _is_embedded_node(node: ast.Assign): + """Check if substructures match + + pattern: + embedded = validate(p.{{ name }}) + """ + if not isinstance(node, ast.Assign): + return False + if len(node.targets) != 1: + return False + target = node.targets[0] + value = node.value + if not (isinstance(target, ast.Name) and isinstance(value, ast.Call)): + return False + if not target.id == "embedded": + return False + return True + + def visit_If(self, node: ast.If): + self.generic_visit(node) + for child in ast.iter_child_nodes(node): + if self._is_embedded_node(child): + new_node = ast.AugAssign( + target=ast.Name(id="err", ctx=ast.Store()), op=ast.Add(), value=child.value + ) # err += _validate_all(p.{{ name }} + node.body = [new_node] + return node + return node + + +class ChangeExpr(ast.NodeTransformer): + + """If there is a pure `_validate_all` function call in the template function, + its return value needs to be recorded in err + + before: + _validate_all(item) + + after: + err += _validate_all(item} + + """ + + def visit_Expr(self, node: ast.Expr): + if not isinstance(node.value, ast.Call): + return node + call_node = node.value + if not isinstance(call_node.func, ast.Name): + return node + if not call_node.func.id == "_validate_all": + return node + return ast.AugAssign( + target=ast.Name(id="err", ctx=ast.Store()), op=ast.Add(), value=call_node + ) # err += _validate_all(item} + + +# Cache generated functions with the message descriptor's full_name as the cache key +@lru_cache() +def _validate_all_inner(proto_message: Message): + func = file_template(ValidatingMessage(proto_message)) + comment = func.split("\n")[1] + func_ast = ast.parse(rf"{func}") + for transformer in [ + ChangeFuncName, + InitErr, + ReturnErr, + ChangeInnerCall, + ChangeRaise, + ChangeEmbedded, + ChangeExpr, + ]: # order is important! + func_ast = ast.fix_missing_locations(transformer().visit(func_ast)) + func_ast = ast.fix_missing_locations(func_ast) + func = unparse(func_ast) + func = comment + " All" + "\n" + func + global printer + printer += func + "\n" + exec(func) + try: + return generate_validate_all + except NameError: + return locals()['generate_validate_all'] + + +def _validate_all(proto_message: Message) -> str: + return _validate_all_inner(ValidatingMessage(proto_message))(proto_message) + + +# raise ValidationFailed if err +def validate_all(proto_message: Message): + err = _validate_all(proto_message) + if err: + raise ValidationFailed('\n'.join(err)) + + def print_validate(): return "".join([s for s in printer.splitlines(True) if s.strip()]) diff --git a/python/requirements.in b/python/requirements.in index 4fe65f8b0..f0fc9389e 100644 --- a/python/requirements.in +++ b/python/requirements.in @@ -4,3 +4,4 @@ validate-email>=1.3 Jinja2>=2.11.1 protobuf>=3.6.1 +astunparse>=1.6.3 diff --git a/python/setup.cfg b/python/setup.cfg index 2fef6dd1e..de52c2dd1 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -21,6 +21,7 @@ install_requires = validate-email>=1.3 Jinja2>=2.11.1 protobuf>=3.6.1 + astunparse>=1.6.3 python_requires = >=3.6 [options.data_files] diff --git a/tests/harness/python/harness.py b/tests/harness/python/harness.py index 8ca0254f2..ae35598f3 100644 --- a/tests/harness/python/harness.py +++ b/tests/harness/python/harness.py @@ -1,7 +1,7 @@ import sys import inspect -from python.protoc_gen_validate.validator import validate, ValidationFailed +from python.protoc_gen_validate.validator import validate, validate_all, ValidationFailed from tests.harness.harness_pb2 import TestCase, TestResult from tests.harness.cases.bool_pb2 import * @@ -21,7 +21,6 @@ from tests.harness.cases.wkt_timestamp_pb2 import * from tests.harness.cases.kitchen_sink_pb2 import * - message_classes = {} for k, v in inspect.getmembers(sys.modules[__name__], inspect.isclass): if 'DESCRIPTOR' in dir(v): @@ -46,5 +45,25 @@ result.Valid = False result.Reasons[:] = [repr(e)] + try: + result_all = TestResult() + valid = validate_all(test_msg) + result_all.Valid = True + except ValidationFailed as e: + result_all.Valid = False + result_all.Reasons[:] = [repr(e)] + + if result.Valid != result_all.Valid: + raise ValueError(f"validation results mismatch, validate: {result.Valid}, " + f"validate_all: {result_all.Valid}") + if not result.Valid: + reason = list(result.Reasons)[0] # ValidationFailed("reason") + reason = reason[18:-2] # reason + reason_all = list(result_all.Reasons)[0] # ValidationFailed("reason1\nreason2\n...reason") + reason_all = reason_all[18:-2] # reason1\nreason2\n...reason + if not reason_all.startswith(reason): + raise ValueError(f"different first message, validate: {reason}, " + f"validate_all: {reason_all}") + sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf8') sys.stdout.write(result.SerializeToString().decode("utf-8"))