From 0b68f46740176b27c6d3e3f650ed748216c92f3f Mon Sep 17 00:00:00 2001 From: puyj Date: Tue, 21 Jun 2022 18:32:12 +0800 Subject: [PATCH] add _validate_all for embedded messages --- python/protoc_gen_validate/validator.py | 51 ++++++++++++++++++++----- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/python/protoc_gen_validate/validator.py b/python/protoc_gen_validate/validator.py index e7cabcbbd..a84124d8b 100644 --- a/python/protoc_gen_validate/validator.py +++ b/python/protoc_gen_validate/validator.py @@ -66,22 +66,27 @@ def _validate_inner(proto_message: Message): class _Transformer(ast.NodeTransformer): """ - Consider generated functions has the following structure: + Consider the generator function has the following structure: - ``` + # Validates Message def generate_validate(p): ... if rules_stmt: raise ValidationFailed(msg) ... + if _has_field(p, "field_name"): # embedded message + embedded = validate(p.field_name) + if embedded is not None: + return embedded + ... return None - ``` - Transformer made the following three changes: + _Transformer will apply the following four changes to the original AST: - 1. Define a variable `err` that records all ValidationFailed error messages. - 2. Convert all `raise ValidationFailed(error_message)` to `err += error_message`. - 3. When `err` is not an empty string, `raise ValidationFailed(err)`. + 1. Define a variable `err` that records all `ValidationFailed` + 2. Convert all `raise ValidationFailed(error_message)` to `err += error_message` + 3. Convert `validate` to `_validate_all` + 4. return `err` . """ def visit_FunctionDef(self, node: ast.FunctionDef): @@ -95,18 +100,39 @@ def visit_Raise(self, node: ast.Raise): exc_str = " ".join(str(_.value) for _ in node.exc.args) return ast.parse(rf'err += "\n{exc_str}"').body[0] + def visit_If(self, node: ast.If): + self.generic_visit(node) + # if _has_field(p, field_name): + if isinstance(node.test, ast.Call) and getattr(node.test.func, "id", None) == "_has_field": + assign_node, if_node = node.body + new_assign_node = ast.AugAssign( + target=ast.Name(id="err", ctx=ast.Store()), + op=ast.Add(), + value=ast.Call( + func=ast.Name(id="_validate_all", ctx=ast.Load()), + args=assign_node.value.args, + keywords=[] + ) + ) + node.body = [new_assign_node] + return node + def visit_Return(self, node: ast.Return): - return ast.parse("if err:\n raise ValidationFailed(err)").body[0] + if hasattr(node.value, "value") and getattr(node.value, "value") is None: + return ast.parse("return err").body[0] + return node # Cache generated functions with the message descriptor's full_name as the cache key @lru_cache() def _validate_all_inner(proto_message: Message): func = file_template(ValidatingMessage(proto_message)) + comment = func.split("\n")[1] func_ast = ast.parse(func) func_ast = _Transformer().visit(func_ast) func_ast = ast.fix_missing_locations(func_ast) func = ast.unparse(func_ast) + func = comment + " All" + "\n" + func global printer printer += func + "\n" exec(func) @@ -116,10 +142,17 @@ def _validate_all_inner(proto_message: Message): return locals()['generate_validate_all'] -def validate_all(proto_message: Message): +def _validate_all(proto_message: Message) -> str: return _validate_all_inner(ValidatingMessage(proto_message))(proto_message) +# raise ValidationFailed if err +def validate_all(proto_message: Message): + err = _validate_all(proto_message) + if err: + raise ValidationFailed(err) + + def print_validate(): return "".join([s for s in printer.splitlines(True) if s.strip()])