Skip to content

Commit

Permalink
[Python] Implement validate_all function
Browse files Browse the repository at this point in the history
  • Loading branch information
HaloWorld committed Jun 21, 2022
1 parent 0e25aab commit 4ef9671
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ in `validate.proto`. Implemented Python annotations are listed in the [rules com
### Example
```python3
from entities_pb2 import Person
from protoc_gen_validate.validator import validate, ValidationFailed
from protoc_gen_validate.validator import validate, ValidationFailed, validate_all

p = Person(first_name="Foo", last_name="Bar", age=42)
try:
validate(p)
# you can also validate all rules
# validate_all(p)
except ValidationFailed as err:
print(err)
```
Expand Down
57 changes: 57 additions & 0 deletions python/protoc_gen_validate/validator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import re
import struct
import sys
Expand Down Expand Up @@ -63,6 +64,62 @@ def _validate_inner(proto_message: Message):
return locals()['generate_validate']


class _Transformer(ast.NodeTransformer):
"""
Consider generated functions has the following structure:
```
def generate_validate(p):
...
if rules_stmt:
raise ValidationFailed(msg)
...
return None
```
Transformer made the following three changes:
1. Define a variable `err` that records all ValidationFailed error messages.
2. Convert all `raise ValidationFailed(error_message)` to `err += error_message`.
3. When `err` is not an empty string, `raise ValidationFailed(err)`.
"""

def visit_FunctionDef(self, node: ast.FunctionDef):
self.generic_visit(node)
# add a suffix to the function name
node.name = node.name + "_all"
node.body.insert(0, ast.parse("err = ''").body[0])
return node

def visit_Raise(self, node: ast.Raise):
exc_str = " ".join(str(_.value) for _ in node.exc.args)
return ast.parse(rf'err += "\n{exc_str}"').body[0]

def visit_Return(self, node: ast.Return):
return ast.parse("if err:\n raise ValidationFailed(err)").body[0]


# Cache generated functions with the message descriptor's full_name as the cache key
@lru_cache()
def _validate_all_inner(proto_message: Message):
func = file_template(ValidatingMessage(proto_message))
func_ast = ast.parse(func)
func_ast = _Transformer().visit(func_ast)
func_ast = ast.fix_missing_locations(func_ast)
func = ast.unparse(func_ast)
global printer
printer += func + "\n"
exec(func)
try:
return generate_validate_all
except NameError:
return locals()['generate_validate_all']


def validate_all(proto_message: Message):
return _validate_all_inner(ValidatingMessage(proto_message))(proto_message)


def print_validate():
return "".join([s for s in printer.splitlines(True) if s.strip()])

Expand Down

0 comments on commit 4ef9671

Please sign in to comment.