Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Support for validate_all function #606

Merged
merged 6 commits into from Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 12 additions & 3 deletions python/README.md
Expand Up @@ -8,13 +8,22 @@ in `validate.proto`. Implemented Python annotations are listed in the [rules com
### Example
```python3
from entities_pb2 import Person
from protoc_gen_validate.validator import validate, ValidationFailed
from protoc_gen_validate.validator import validate, ValidationFailed, validate_all

p = Person(first_name="Foo", last_name="Bar", age=42)
p = Person(name="Foo")
try:
validate(p)
except ValidationFailed as err:
print(err)
print(err) # p.id is not greater than 999

try:
validate_all(p)
except ValidationFailed as err:
print(err)
# p.id is not greater than 999
# p.email is not a valid email
# p.name pattern does not match ^[^[0-9]A-Za-z]+( [^[0-9]A-Za-z]+)*$
# home is required.
```

[pgv-home]: https://github.com/envoyproxy/protoc-gen-validate
Expand Down
169 changes: 169 additions & 0 deletions python/protoc_gen_validate/validator.py
@@ -1,3 +1,4 @@
import ast
import re
import struct
import sys
Expand All @@ -11,6 +12,12 @@
from jinja2 import Template
from validate_email import validate_email

# Pick the function used to turn an AST back into source text: the standard
# library's ``ast.unparse`` on Python 3.9 and newer, otherwise the third-party
# ``astunparse`` package.
if sys.version_info >= (3, 9):
    unparse = ast.unparse
else:
    import astunparse

    unparse = astunparse.unparse

# Accumulated source text of the generated validator functions. Generators
# append to it via ``global printer``; print_validate() returns it with blank
# lines removed.
printer = ""

# Well known regex mapping.
Expand Down Expand Up @@ -63,6 +70,168 @@ def _validate_inner(proto_message: Message):
return locals()['generate_validate']


class ChangeFuncName(ast.NodeTransformer):
    """Append ``_all`` to the name of each function definition it visits."""

    _SUFFIX = "_all"

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Rename ``node`` in place and return it.

        The node's children are not visited, so functions nested inside it
        keep their original names.
        """
        node.name = f"{node.name}{self._SUFFIX}"
        return node


class InitErr(ast.NodeTransformer):
    """Prepend ``err = []`` to the body of each function definition it visits."""

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Insert the ``err`` initialisation as the first statement of ``node``.

        The node's children are not visited, so nested function definitions
        are left unchanged.
        """
        init_statements = ast.parse("err = []").body
        # Slice assignment inserts at the front of the existing list in place.
        node.body[:0] = init_statements
        return node


class ReturnErr(ast.NodeTransformer):
    """Rewrite ``return None`` statements into ``return err``."""

    # Marker meaning "the returned expression has no ``value`` attribute".
    _NO_VALUE = object()

    def visit_Return(self, node: ast.Return):
        """Return a ``return err`` statement when ``node`` returns a literal None.

        The check is duck-typed: the returned expression must expose a
        ``value`` attribute that is ``None``. A bare ``return`` (no expression)
        and every other return statement are passed through unchanged.
        """
        returned = getattr(node.value, "value", self._NO_VALUE)
        if returned is not None:
            return node
        return ast.parse("return err").body[0]


class ChangeInnerCall(ast.NodeTransformer):
    """Redirect calls to ``validate`` so that they call ``_validate_all``."""

    def visit_Call(self, node: ast.Call):
        """Rename a plain ``validate(...)`` call to ``_validate_all(...)``.

        Only calls whose callee is the bare name ``validate`` are renamed;
        attribute calls such as ``obj.validate(...)`` and other names such as
        ``validate_email`` are left alone. The node is modified in place and
        returned.
        """
        # A NodeTransformer does not descend into a node handled by a custom
        # visit_* method unless generic_visit is called, so visit the children
        # first; otherwise a ``validate`` call nested inside another call's
        # arguments or callee would never be renamed.
        self.generic_visit(node)
        if isinstance(node.func, ast.Name) and node.func.id == "validate":
            node.func.id = "_validate_all"
        return node


class ChangeRaise(ast.NodeTransformer):
    """Replace ``raise <call>(...)`` statements with ``err.append(...)``."""

    def visit_Raise(self, node: ast.Raise):
        """Convert a raise of a call expression into an append onto ``err``.

        before:
            raise ValidationFailed(reason)
        after:
            err.append(reason)

        The positional and keyword arguments of the raised call are passed to
        ``err.append`` unchanged. The name of the raised callable is not
        checked; the original author notes that every ``raise`` in the
        template raises ``ValidationFailed``. A ``raise`` whose operand is not
        a call (or is absent) is returned untouched.
        """
        raised = node.exc
        if isinstance(raised, ast.Call):
            append_attr = ast.Attribute(
                value=ast.Name(id="err", ctx=ast.Load()),
                attr="append",
                ctx=ast.Load(),
            )
            append_call = ast.Call(
                func=append_attr,
                args=raised.args,
                keywords=raised.keywords,
            )
            return ast.Expr(value=append_call)
        return node


class ChangeEmbedded(ast.NodeTransformer):
    """Collapse the template's embedded-message check into an ``err +=``.

    For embedded messages the template contains this structure:

        if _has_field(p, "{{ name.split('.')[-1] }}"):
            embedded = validate(p.{{ name }})
            if embedded is not None:
                return embedded

    which is converted into:

        if _has_field(p, "{{ name.split('.')[-1] }}"):
            err += _validate_all(p.{{ name }})

    """

    @staticmethod
    def _is_embedded_node(node: ast.Assign):
        """Return True when ``node`` is ``embedded = <call>(...)``.

        That is: an assignment with exactly one target, the target being the
        bare name ``embedded`` and the assigned value being a call expression.
        """
        if not isinstance(node, ast.Assign) or len(node.targets) != 1:
            return False
        target = node.targets[0]
        return (
            isinstance(target, ast.Name)
            and isinstance(node.value, ast.Call)
            and target.id == "embedded"
        )

    def visit_If(self, node: ast.If):
        """Replace the body of an ``if`` that directly holds the assignment.

        Nested nodes are transformed first. If any direct child of ``node``
        matches ``_is_embedded_node``, the whole ``if`` body becomes the single
        statement ``err += <that call>``; otherwise ``node`` is returned as is.
        """
        self.generic_visit(node)
        match = next(
            (kid for kid in ast.iter_child_nodes(node) if self._is_embedded_node(kid)),
            None,
        )
        if match is not None:
            accumulate = ast.AugAssign(
                target=ast.Name(id="err", ctx=ast.Store()),
                op=ast.Add(),
                value=match.value,
            )  # err += _validate_all(p.{{ name }})
            node.body = [accumulate]
        return node


class ChangeExpr(ast.NodeTransformer):
    """Record the result of stand-alone ``_validate_all`` calls in ``err``.

    before:
        _validate_all(item)

    after:
        err += _validate_all(item)

    """

    def visit_Expr(self, node: ast.Expr):
        """Turn an expression statement calling ``_validate_all`` into ``err +=``.

        Only a call whose callee is the bare name ``_validate_all`` qualifies;
        every other expression statement is returned unchanged.
        """
        call = node.value
        is_bare_validate_all = (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Name)
            and call.func.id == "_validate_all"
        )
        if not is_bare_validate_all:
            return node
        return ast.AugAssign(
            target=ast.Name(id="err", ctx=ast.Store()),
            op=ast.Add(),
            value=call,
        )  # err += _validate_all(item)


# Cache generated functions with the message descriptor's full_name as the cache key
# NOTE(review): the effective cache key is whatever the argument hashes to;
# confirm against ValidatingMessage, which is defined outside this section.
@lru_cache()
def _validate_all_inner(proto_message: Message):
    """Generate and return the collect-all-errors validator function.

    The validator source is rendered with ``file_template``, parsed into an
    AST, rewritten by a fixed sequence of transformers, converted back to
    source and executed. The resulting source is also appended to the
    module-level ``printer`` string.

    Args:
        proto_message: The value handed to ``ValidatingMessage`` to render the
            template; it is also the ``lru_cache`` key.

    Returns:
        The ``generate_validate_all`` function defined by the generated source.

    Raises:
        KeyError: If the generated source does not define
            ``generate_validate_all``.
    """
    global printer
    func = file_template(ValidatingMessage(proto_message))
    # The second line of the rendered template is re-attached (with " All"
    # appended) in front of the regenerated source below; presumably it is a
    # header comment, which does not survive the AST round trip.
    comment = func.split("\n")[1]
    func_ast = ast.parse(func)
    for transformer in [
        ChangeFuncName,
        InitErr,
        ReturnErr,
        ChangeInnerCall,
        ChangeRaise,
        ChangeEmbedded,
        ChangeExpr,
    ]:  # order is important!
        # Newly created nodes carry no line/column information, so locations
        # are filled in after every pass.
        func_ast = ast.fix_missing_locations(transformer().visit(func_ast))
    func = unparse(func_ast)
    func = comment + " All" + "\n" + func
    printer += func + "\n"
    # Execute into an explicit namespace. Relying on exec() to populate this
    # function's implicit locals and reading them back through locals() is
    # documented as unsupported and no longer works on Python 3.13+ (PEP 667).
    # Module globals are still supplied so the generated function can resolve
    # the helpers it calls.
    namespace = {}
    exec(func, globals(), namespace)
    return namespace["generate_validate_all"]


def _validate_all(proto_message: Message) -> list:
    """Run the collect-all validator for ``proto_message`` and return its result.

    The generated validator is obtained from ``_validate_all_inner`` and then
    called with the message itself.

    Args:
        proto_message: The message to validate.

    Returns:
        The value returned by the generated ``generate_validate_all`` function.
        Callers treat it as a list of failure reasons: it is combined with
        ``err +=`` in generated code and joined with newlines by
        ``validate_all``.
    """
    validator = _validate_all_inner(ValidatingMessage(proto_message))
    return validator(proto_message)


# Raises ValidationFailed when any reasons were collected.
def validate_all(proto_message: Message):
    """Validate ``proto_message`` and report every collected failure at once.

    Args:
        proto_message: The message to validate.

    Raises:
        ValidationFailed: If ``_validate_all`` returns a non-empty result; the
            exception message is the collected reasons joined by newlines.
    """
    reasons = _validate_all(proto_message)
    if not reasons:
        return
    raise ValidationFailed('\n'.join(reasons))


def print_validate():
    """Return the text accumulated in ``printer`` with blank lines removed.

    A line is dropped when it is empty or contains only whitespace; the line
    endings of the remaining lines are kept as they were.
    """
    non_blank = filter(str.strip, printer.splitlines(True))
    return "".join(non_blank)

Expand Down
1 change: 1 addition & 0 deletions python/requirements.in
Expand Up @@ -4,3 +4,4 @@
validate-email>=1.3
Jinja2>=2.11.1
protobuf>=3.6.1
astunparse>=1.6.3
1 change: 1 addition & 0 deletions python/setup.cfg
Expand Up @@ -21,6 +21,7 @@ install_requires =
validate-email>=1.3
Jinja2>=2.11.1
protobuf>=3.6.1
astunparse>=1.6.3
python_requires = >=3.6

[options.data_files]
Expand Down
23 changes: 21 additions & 2 deletions tests/harness/python/harness.py
@@ -1,7 +1,7 @@
import sys
import inspect

from python.protoc_gen_validate.validator import validate, ValidationFailed
from python.protoc_gen_validate.validator import validate, validate_all, ValidationFailed

from tests.harness.harness_pb2 import TestCase, TestResult
from tests.harness.cases.bool_pb2 import *
Expand All @@ -21,7 +21,6 @@
from tests.harness.cases.wkt_timestamp_pb2 import *
from tests.harness.cases.kitchen_sink_pb2 import *


message_classes = {}
for k, v in inspect.getmembers(sys.modules[__name__], inspect.isclass):
if 'DESCRIPTOR' in dir(v):
Expand All @@ -46,5 +45,25 @@
result.Valid = False
result.Reasons[:] = [repr(e)]

# Run the same message through validate_all and record the outcome in a
# separate TestResult so it can be compared with the validate() result above.
try:
    result_all = TestResult()
    valid = validate_all(test_msg)
    result_all.Valid = True
except ValidationFailed as e:
    result_all.Valid = False
    result_all.Reasons[:] = [repr(e)]

# Both entry points must agree on whether the message is valid.
if result.Valid != result_all.Valid:
    raise ValueError(f"validation results mismatch, validate: {result.Valid}, "
                     f"validate_all: {result_all.Valid}")
# On failure, the single reason reported by validate() must be the first
# reason reported by validate_all(). Reasons are stored as repr(exception), so
# the slices below drop the 18-character 'ValidationFailed("' prefix and the
# 2-character '")' suffix before comparing.
if not result.Valid:
    reason = list(result.Reasons)[0]  # ValidationFailed("reason")
    reason = reason[18:-2]  # reason
    reason_all = list(result_all.Reasons)[0]  # ValidationFailed("reason1\nreason2\n...reason")
    reason_all = reason_all[18:-2]  # reason1\nreason2\n...reason
    if not reason_all.startswith(reason):
        raise ValueError(f"different first message, validate: {reason}, "
                         f"validate_all: {reason_all}")

# Re-open stdout as a UTF-8 text stream, then emit the serialized result of the
# validate() run (not the validate_all one).
sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf8')
sys.stdout.write(result.SerializeToString().decode("utf-8"))