Skip to content

Commit

Permalink
add _validate_all for embedded messages
Browse files Browse the repository at this point in the history
  • Loading branch information
HaloWorld committed Jun 21, 2022
1 parent 4ef9671 commit 0b68f46
Showing 1 changed file with 42 additions and 9 deletions.
51 changes: 42 additions & 9 deletions python/protoc_gen_validate/validator.py
Expand Up @@ -66,22 +66,27 @@ def _validate_inner(proto_message: Message):

class _Transformer(ast.NodeTransformer):
"""
Consider generated functions has the following structure:
Consider the generator function has the following structure:
```
# Validates Message
def generate_validate(p):
...
if rules_stmt:
raise ValidationFailed(msg)
...
if _has_field(p, "field_name"): # embedded message
embedded = validate(p.field_name)
if embedded is not None:
return embedded
...
return None
```
Transformer made the following three changes:
_Transformer will apply the following four changes to the original AST:
1. Define a variable `err` that records all ValidationFailed error messages.
2. Convert all `raise ValidationFailed(error_message)` to `err += error_message`.
3. When `err` is not an empty string, `raise ValidationFailed(err)`.
1. Define a variable `err` that records all `ValidationFailed`
2. Convert all `raise ValidationFailed(error_message)` to `err += error_message`
3. Convert `validate` to `_validate_all`
4. return `err` .
"""

def visit_FunctionDef(self, node: ast.FunctionDef):
Expand All @@ -95,18 +100,39 @@ def visit_Raise(self, node: ast.Raise):
exc_str = " ".join(str(_.value) for _ in node.exc.args)
return ast.parse(rf'err += "\n{exc_str}"').body[0]

def visit_If(self, node: ast.If):
self.generic_visit(node)
# if _has_field(p, field_name):
if isinstance(node.test, ast.Call) and getattr(node.test.func, "id", None) == "_has_field":
assign_node, if_node = node.body
new_assign_node = ast.AugAssign(
target=ast.Name(id="err", ctx=ast.Store()),
op=ast.Add(),
value=ast.Call(
func=ast.Name(id="_validate_all", ctx=ast.Load()),
args=assign_node.value.args,
keywords=[]
)
)
node.body = [new_assign_node]
return node

def visit_Return(self, node: ast.Return):
return ast.parse("if err:\n raise ValidationFailed(err)").body[0]
if hasattr(node.value, "value") and getattr(node.value, "value") is None:
return ast.parse("return err").body[0]
return node


# Cache generated functions with the message descriptor's full_name as the cache key
@lru_cache()
def _validate_all_inner(proto_message: Message):
func = file_template(ValidatingMessage(proto_message))
comment = func.split("\n")[1]
func_ast = ast.parse(func)
func_ast = _Transformer().visit(func_ast)
func_ast = ast.fix_missing_locations(func_ast)
func = ast.unparse(func_ast)
func = comment + " All" + "\n" + func
global printer
printer += func + "\n"
exec(func)
Expand All @@ -116,10 +142,17 @@ def _validate_all_inner(proto_message: Message):
return locals()['generate_validate_all']


def validate_all(proto_message: Message):
def _validate_all(proto_message: Message) -> str:
return _validate_all_inner(ValidatingMessage(proto_message))(proto_message)


# raise ValidationFailed if err
def validate_all(proto_message: Message):
err = _validate_all(proto_message)
if err:
raise ValidationFailed(err)


def print_validate():
return "".join([s for s in printer.splitlines(True) if s.strip()])

Expand Down

0 comments on commit 0b68f46

Please sign in to comment.