Skip to content

Commit

Permalink
Un-pin and update Marshmallow and webargs
Browse files Browse the repository at this point in the history
As expected, these updates ended up requiring quite a few changes. I was
initially going to update only Marshmallow, but the older version of
webargs couldn't work with an updated Marshmallow, so I ended up needing
to do both at the same time.

The main changes required were:

* Schemas don't need to be specified as "strict" any more, so that could
  be removed from constructors and Meta classes.
* .validate() now returns a dict of errors (if any) instead of raising a
  ValidationError if anything goes wrong. This meant that I either need
  to check the returned dict, or switch to .load() to still get raised
  errors.
* Marshmallow doesn't support loading from two different field names as
  easily (and changed the name of that to data_key), so all the routes
  using "group_path" needed to be changed to just "path".
* Some of the Field methods and some decorated schema ones like
  @pre_load receive new arguments and needed to be updated to handle
  them and/or pass them on.
* webargs will no longer send a keyword argument for any fields that
  aren't specified and don't have a default value when using
  @use_kwargs. Because of this, I added missing= values for most
  optional fields, but still required some special treatment for the
  order query variable in a couple of topic listing views.

And finally, there is some strange behavior in webargs by default when a
form doesn't send any data for a field (due to the input not being
included or similar). When it doesn't find the field in form data, it
tries to fall back to checking for JSON data, but then crashes because
the request doesn't have any JSON data attached. I had to specify only
to look in the form data in a few places to fix this, but I've also
registered an issue against webargs related to it:
marshmallow-code/webargs#444
  • Loading branch information
Deimos committed Dec 8, 2019
1 parent 4b92096 commit 110a930
Show file tree
Hide file tree
Showing 39 changed files with 216 additions and 190 deletions.
4 changes: 2 additions & 2 deletions tildes/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jinja2==2.10.3 # via pyramid-jinja2
lazy-object-proxy==1.4.3 # via astroid
mako==1.1.0 # via alembic, pyramid-mako
markupsafe==1.1.1 # via jinja2, mako, pyramid-jinja2
marshmallow==2.20.5
marshmallow==3.3.0
mccabe==0.6.1 # via prospector, pylint
more-itertools==8.0.2 # via pytest, zipp
mypy-extensions==0.4.3 # via mypy
Expand Down Expand Up @@ -108,7 +108,7 @@ urllib3==1.25.7 # via requests, sentry-sdk
venusian==3.0.0 # via cornice, pyramid
waitress==1.3.1 # via webtest
wcwidth==0.1.7 # via prompt-toolkit, pytest
webargs==4.4.1
webargs==5.5.2
webassets==0.12.1 # via pyramid-webassets
webencodings==0.5.1 # via bleach, html5lib
webob==1.8.5 # via pyramid, webtest
Expand Down
4 changes: 2 additions & 2 deletions tildes/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ cornice
gunicorn
html5lib
ipython
marshmallow<3.0 # 3.0+ requires significant updates
marshmallow
Pillow
pip-tools
prometheus-client
Expand All @@ -35,6 +35,6 @@ SQLAlchemy
SQLAlchemy-Utils
stripe
titlecase
webargs<5.0 # 5.0.0 breaks many views, will require significant updates
webargs
wrapt
zope.sqlalchemy
4 changes: 2 additions & 2 deletions tildes/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jedi==0.15.1 # via ipython
jinja2==2.10.3 # via pyramid-jinja2
mako==1.1.0 # via alembic
markupsafe==1.1.1 # via jinja2, mako, pyramid-jinja2
marshmallow==2.20.5
marshmallow==3.3.0
parso==0.5.1 # via jedi
pastedeploy==2.0.1 # via plaster-pastedeploy
pexpect==4.7.0 # via ipython
Expand Down Expand Up @@ -65,7 +65,7 @@ translationstring==1.3 # via pyramid
urllib3==1.25.7 # via requests, sentry-sdk
venusian==3.0.0 # via cornice, pyramid
wcwidth==0.1.7 # via prompt-toolkit
webargs==4.4.1
webargs==5.5.2
webassets==0.12.1 # via pyramid-webassets
webencodings==0.5.1 # via bleach, html5lib
webob==1.8.5 # via pyramid
Expand Down
6 changes: 3 additions & 3 deletions tildes/tests/test_markdown_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class MarkdownFieldTestSchema(Schema):

def validate_string(string):
"""Validate a string against a standard Markdown field."""
MarkdownFieldTestSchema(strict=True).validate({"markdown": string})
MarkdownFieldTestSchema().load({"markdown": string})


def test_normal_text_validates():
Expand Down Expand Up @@ -61,7 +61,7 @@ def test_carriage_returns_stripped():
"""Ensure loading a value strips out carriage returns from the string."""
test_string = "some\r\nreturns\r\nin\nhere"

schema = MarkdownFieldTestSchema(strict=True)
schema = MarkdownFieldTestSchema()
result = schema.load({"markdown": test_string})

assert "\r" not in result.data["markdown"]
assert "\r" not in result["markdown"]
4 changes: 2 additions & 2 deletions tildes/tests/test_simplestring_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def process_string(string):
This also works for testing validation since .load() will raise a ValidationError if
an invalid string is attempted.
"""
schema = SimpleStringTestSchema(strict=True)
schema = SimpleStringTestSchema()
result = schema.load({"subject": string})

return result.data["subject"]
return result["subject"]


def test_changing_max_length():
Expand Down
16 changes: 8 additions & 8 deletions tildes/tests/test_title.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2018 Tildes contributors <code@tildes.net>
# SPDX-License-Identifier: AGPL-3.0-or-later

from marshmallow.exceptions import ValidationError
from marshmallow import ValidationError
from pytest import fixture, raises

from tildes.schemas.topic import TITLE_MAX_LENGTH, TopicSchema
Expand All @@ -23,44 +23,44 @@ def test_too_long_title_invalid(title_schema):
"""Ensure a too-long title is invalid."""
title = "x" * (TITLE_MAX_LENGTH + 1)
with raises(ValidationError):
title_schema.validate({"title": title})
title_schema.load({"title": title})


def test_empty_title_invalid(title_schema):
"""Ensure an empty title is invalid."""
with raises(ValidationError):
title_schema.validate({"title": ""})
title_schema.load({"title": ""})


def test_whitespace_only_title_invalid(title_schema):
"""Ensure a whitespace-only title is invalid."""
with raises(ValidationError):
title_schema.validate({"title": " \n "})
title_schema.load({"title": " \n "})


def test_whitespace_trimmed(title_schema):
"""Ensure leading/trailing whitespace on a title is removed."""
title = " actual title "
result = title_schema.load({"title": title})
assert result.data["title"] == "actual title"
assert result["title"] == "actual title"


def test_consecutive_whitespace_removed(title_schema):
"""Ensure consecutive whitespace in a title is compressed."""
title = "sure are \n a lot of spaces"
result = title_schema.load({"title": title})
assert result.data["title"] == "sure are a lot of spaces"
assert result["title"] == "sure are a lot of spaces"


def test_unicode_spaces_normalized(title_schema):
"""Test that some unicode space characters are converted to normal ones."""
title = "some\u2009weird\u00a0spaces\u205fin\u00a0here"
result = title_schema.load({"title": title})
assert result.data["title"] == "some weird spaces in here"
assert result["title"] == "some weird spaces in here"


def test_unicode_control_chars_removed(title_schema):
"""Test that some unicode control characters are stripped from titles."""
title = "nothing\u0000strange\u0085going\u009con\u007fhere"
result = title_schema.load({"title": title})
assert result.data["title"] == "nothingstrangegoingonhere"
assert result["title"] == "nothingstrangegoingonhere"
6 changes: 3 additions & 3 deletions tildes/tests/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@


def test_creation_validates_schema(mocker):
"""Ensure that model creation goes through schema validation."""
mocker.spy(UserSchema, "validate")
"""Ensure that model creation goes through schema validation (via load())."""
mocker.spy(UserSchema, "load")
User("testing", "testpassword")
call_args = [call[0] for call in UserSchema.validate.call_args_list]
call_args = [call[0] for call in UserSchema.load.call_args_list]
expected_args = {"username": "testing", "password": "testpassword"}
assert any(expected_args in call for call in call_args)

Expand Down
2 changes: 1 addition & 1 deletion tildes/tildes/models/database_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _validate_new_value(self, attribute: str, value: Any) -> Any:
return value

result = self.schema.load({attribute: value})
return result.data[attribute]
return result[attribute]


DatabaseModel = declarative_base(
Expand Down
2 changes: 1 addition & 1 deletion tildes/tildes/models/user/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def password(self) -> NoReturn:
def password(self, value: str) -> None:
# need to do manual validation since some password checks depend on checking the
# username at the same time (for similarity)
self.schema.validate({"username": self.username, "password": value})
self.schema.load({"username": self.username, "password": value})

self.password_hash = hash_string(value)

Expand Down
6 changes: 3 additions & 3 deletions tildes/tildes/resources/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
locations=("matchdict",),
)
def group_by_path(request: Request, path: str) -> Group:
"""Get a group specified by {group_path} in the route (or 404)."""
"""Get a group specified by {path} in the route (or 404)."""
# If loading the specified group path into the GroupSchema changed it, do a 301
# redirect to the resulting group path. This will happen in cases like the original
# url including capital letters in the group path, where we want to redirect to the
# proper all-lowercase path instead.
if path != request.matchdict["group_path"]:
request.matchdict["group_path"] = path
if path != request.matchdict["path"]:
request.matchdict["path"] = path
proper_url = request.route_url(request.matched_route.name, **request.matchdict)

raise HTTPMovedPermanently(location=proper_url)
Expand Down
4 changes: 2 additions & 2 deletions tildes/tildes/resources/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def topic_by_id36(request: Request, topic_id36: str) -> Topic:

# if there's also a group specified in the route, check that it's the same group as
# the topic was posted in, otherwise redirect to correct group
if "group_path" in request.matchdict:
path_from_route = request.matchdict["group_path"].lower()
if "path" in request.matchdict:
path_from_route = request.matchdict["path"].lower()
if path_from_route != topic.group.path:
raise HTTPFound(topic.permalink)

Expand Down
8 changes: 4 additions & 4 deletions tildes/tildes/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def includeme(config: Configurator) -> None:

config.add_route("register", "/register")

config.add_route("group", "/~{group_path}", factory=group_by_path)
with config.route_prefix_context("/~{group_path}"):
config.add_route("group", "/~{path}", factory=group_by_path)
with config.route_prefix_context("/~{path}"):
config.add_route("new_topic", "/new_topic", factory=group_by_path)

config.add_route("group_topics", "/topics", factory=group_by_path)
Expand Down Expand Up @@ -123,7 +123,7 @@ def includeme(config: Configurator) -> None:

# Add routes for the link-shortener under the /shortener path
with config.route_prefix_context("/shortener"):
config.add_route("shortener_group", "/~{group_path}", factory=group_by_path)
config.add_route("shortener_group", "/~{path}", factory=group_by_path)
config.add_route("shortener_topic", "/{topic_id36}", factory=topic_by_id36)


Expand All @@ -135,7 +135,7 @@ def add_ic_route(name: str, path: str, **kwargs: Any) -> None:
name = "ic_" + name
config.add_route(name, path, header="X-IC-Request:true", **kwargs)

with config.route_prefix_context("/group/{group_path}"):
with config.route_prefix_context("/group/{path}"):
add_ic_route("group_subscribe", "/subscribe", factory=group_by_path)
add_ic_route("group_user_settings", "/user_settings", factory=group_by_path)

Expand Down
11 changes: 1 addition & 10 deletions tildes/tildes/schemas/comment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,13 @@
class CommentSchema(Schema):
"""Marshmallow schema for comments."""

comment_id36 = ID36()
markdown = Markdown()
parent_comment_id36 = ID36()

class Meta:
"""Always use strict checking so error handlers are invoked."""

strict = True


class CommentLabelSchema(Schema):
"""Marshmallow schema for comment labels."""

name = Enum(CommentLabelOption)
reason = SimpleString(max_length=1000, missing=None)

class Meta:
"""Always use strict checking so error handlers are invoked."""

strict = True
42 changes: 29 additions & 13 deletions tildes/tildes/schemas/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""Custom schema field definitions."""

import enum
from typing import Any, Optional, Type
from typing import Any, Mapping, Optional, Type

import sqlalchemy_utils
from marshmallow.exceptions import ValidationError
Expand All @@ -16,6 +16,10 @@
from tildes.lib.string import simplify_string


# type alias for the data argument passed to _deserialize methods
DataType = Optional[Mapping[str, Any]]


class Enum(Field):
"""Field for a native Python Enum (or subclasses)."""

Expand All @@ -25,11 +29,15 @@ def __init__(self, enum_class: Optional[Type] = None, *args: Any, **kwargs: Any)
super().__init__(*args, **kwargs)
self._enum_class = enum_class

def _serialize(self, value: enum.Enum, attr: str, obj: object) -> str:
def _serialize(
self, value: enum.Enum, attr: str, obj: object, **kwargs: Any
) -> str:
"""Serialize the enum value - lowercase version of its name."""
return value.name.lower()

def _deserialize(self, value: str, attr: str, data: dict) -> enum.Enum:
def _deserialize(
self, value: str, attr: Optional[str], data: DataType, **kwargs: Any,
) -> enum.Enum:
"""Deserialize a string to the enum member with that name."""
if not self._enum_class:
raise ValidationError("Cannot deserialize with no enum class.")
Expand All @@ -43,9 +51,9 @@ def _deserialize(self, value: str, attr: str, data: dict) -> enum.Enum:
class ID36(String):
"""Field for a base-36 ID."""

def __init__(self) -> None:
def __init__(self, **kwargs: Any) -> None:
"""Initialize the field with a regex validator."""
super().__init__(validate=Regexp(ID36_REGEX))
super().__init__(validate=Regexp(ID36_REGEX), **kwargs)


class ShortTimePeriod(Field):
Expand All @@ -55,7 +63,7 @@ class ShortTimePeriod(Field):
"""

def _deserialize(
self, value: str, attr: str, data: dict
self, value: str, attr: Optional[str], data: DataType, **kwargs: Any,
) -> Optional[SimpleHoursPeriod]:
"""Deserialize to a SimpleHoursPeriod object."""
if value == "all":
Expand All @@ -67,7 +75,7 @@ def _deserialize(
raise ValidationError("Invalid time period")

def _serialize(
self, value: Optional[SimpleHoursPeriod], attr: str, obj: object
self, value: Optional[SimpleHoursPeriod], attr: str, obj: object, **kwargs: Any,
) -> Optional[str]:
"""Serialize the value to the "short form" string."""
if not value:
Expand Down Expand Up @@ -95,13 +103,15 @@ def _validate(self, value: str) -> None:
if value.isspace():
raise ValidationError("Cannot be entirely whitespace.")

def _deserialize(self, value: str, attr: str, data: dict) -> str:
def _deserialize(
self, value: str, attr: Optional[str], data: DataType, **kwargs: Any,
) -> str:
"""Deserialize the string, removing carriage returns in the process."""
value = value.replace("\r", "")

return value

def _serialize(self, value: str, attr: str, obj: object) -> str:
def _serialize(self, value: str, attr: str, obj: object, **kwargs: Any) -> str:
"""Serialize the value (no-op in this case)."""
return value

Expand All @@ -126,23 +136,29 @@ def __init__(self, max_length: Optional[int] = None, **kwargs: Any):

super().__init__(validate=Length(min=1, max=max_length), **kwargs)

def _deserialize(self, value: str, attr: str, data: dict) -> str:
def _deserialize(
self, value: str, attr: Optional[str], data: DataType, **kwargs: Any,
) -> str:
"""Deserialize the string, removing/replacing as necessary."""
return simplify_string(value)

def _serialize(self, value: str, attr: str, obj: object) -> str:
def _serialize(self, value: str, attr: str, obj: object, **kwargs: Any) -> str:
"""Serialize the value (no-op in this case)."""
return value


class Ltree(Field):
"""Field for postgresql ltree type."""

def _serialize(self, value: sqlalchemy_utils.Ltree, attr: str, obj: object) -> str:
def _serialize(
self, value: sqlalchemy_utils.Ltree, attr: str, obj: object, **kwargs: Any
) -> str:
"""Serialize the Ltree value - use the (string) path."""
return value.path

def _deserialize(self, value: str, attr: str, data: dict) -> sqlalchemy_utils.Ltree:
def _deserialize(
self, value: str, attr: Optional[str], data: DataType, **kwargs: Any,
) -> sqlalchemy_utils.Ltree:
"""Deserialize a string path to an Ltree object."""
# convert to lowercase and replace spaces with underscores
value = value.lower().replace(" ", "_")
Expand Down

0 comments on commit 110a930

Please sign in to comment.