Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KafkaDsn to networks.py and add default ports for HttpUrl #2447

Merged
merged 12 commits into from
Sep 3, 2021
2 changes: 2 additions & 0 deletions changes/2447-MihanixA.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- add `KafkaDsn` type
- `HttpUrl` now has default port 80 for http and 443 for https
1 change: 1 addition & 0 deletions pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
'IPvAnyNetwork',
'PostgresDsn',
'RedisDsn',
'KafkaDsn',
'validate_email',
# parse
'Protocol',
Expand Down
80 changes: 65 additions & 15 deletions pydantic/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,27 @@

if TYPE_CHECKING:
import email_validator
from typing_extensions import TypedDict

from .config import BaseConfig
from .fields import ModelField
from .typing import AnyCallable

CallableGenerator = Generator[AnyCallable, None, None]

class Parts(TypedDict, total=False):
    """
    Typed view of the dict of URL components handled during validation.

    `validate` casts the result of `m.groupdict()` to this type and passes it to
    `hide_parts`, `apply_default_parts`, `validate_parts` and `validate_host`.
    `total=False` means the type checker does not require any key to be present.
    """

    scheme: str
    # credentials
    user: Optional[str]
    password: Optional[str]
    # host alternatives; `validate_host` uses the first non-empty of domain, ipv4, ipv6
    ipv4: Optional[str]
    ipv6: Optional[str]
    domain: Optional[str]
    # kept as a string, not an int
    port: Optional[str]
    path: Optional[str]
    query: Optional[str]
    fragment: Optional[str]


else:
email_validator = None

Expand All @@ -54,6 +69,7 @@
'IPvAnyNetwork',
'PostgresDsn',
'RedisDsn',
'KafkaDsn',
'validate_email',
]

Expand Down Expand Up @@ -109,6 +125,7 @@ class AnyUrl(str):
allowed_schemes: Optional[Set[str]] = None
tld_required: bool = False
user_required: bool = False
hidden_parts: Set[str] = set()

__slots__ = ('scheme', 'user', 'password', 'host', 'tld', 'host_type', 'port', 'path', 'query', 'fragment')

Expand Down Expand Up @@ -155,7 +172,7 @@ def build(
path: Optional[str] = None,
query: Optional[str] = None,
fragment: Optional[str] = None,
**kwargs: str,
**_kwargs: str,
) -> str:
url = scheme + '://'
if user:
Expand All @@ -165,7 +182,7 @@ def build(
if user or password:
url += '@'
url += host
if port:
if port and 'port' not in cls.hidden_parts:
url += ':' + port
if path:
url += path
Expand Down Expand Up @@ -196,7 +213,9 @@ def validate(cls, value: Any, field: 'ModelField', config: 'BaseConfig') -> 'Any
# the regex should always match, if it doesn't please report with details of the URL tried
assert m, 'URL regex failed unexpectedly'

parts = m.groupdict()
original_parts = cast('Parts', m.groupdict())
cls.hide_parts(original_parts)
parts = cls.apply_default_parts(original_parts)
parts = cls.validate_parts(parts)

host, tld, host_type, rebuild = cls.validate_host(parts)
Expand All @@ -219,7 +238,7 @@ def validate(cls, value: Any, field: 'ModelField', config: 'BaseConfig') -> 'Any
)

@classmethod
def validate_parts(cls, parts: Dict[str, str]) -> Dict[str, str]:
def validate_parts(cls, parts: 'Parts') -> 'Parts':
"""
A method used to validate parts of an URL.
Could be overridden to set default values for parts if missing
Expand All @@ -242,10 +261,10 @@ def validate_parts(cls, parts: Dict[str, str]) -> Dict[str, str]:
return parts

@classmethod
def validate_host(cls, parts: Dict[str, str]) -> Tuple[str, Optional[str], str, bool]:
def validate_host(cls, parts: 'Parts') -> Tuple[str, Optional[str], str, bool]:
host, tld, host_type, rebuild = None, None, None, False
for f in ('domain', 'ipv4', 'ipv6'):
host = parts[f]
host = parts[f] # type: ignore[misc]
if host:
host_type = f
break
Expand Down Expand Up @@ -281,6 +300,21 @@ def validate_host(cls, parts: Dict[str, str]) -> Tuple[str, Optional[str], str,

return host, tld, host_type, rebuild # type: ignore

@staticmethod
def get_default_parts(parts: 'Parts') -> 'Parts':
    """
    Hook returning fallback values for URL parts, keyed by part name.

    `apply_default_parts` copies each returned value into `parts` when the corresponding
    part is empty. The base class defines no defaults; subclasses override this and may
    inspect `parts` to decide what to return.
    """
    no_defaults: 'Parts' = {}
    return no_defaults

@classmethod
def hide_parts(cls, original_parts: 'Parts') -> None:
    """
    Reset the set of part names that `build` consults when deciding whether to emit the port.

    `validate` calls this with the raw regex groups before `apply_default_parts` runs.
    The base implementation hides nothing and ignores `original_parts`; subclasses may
    extend it and add names to `cls.hidden_parts`.
    """
    # NOTE(review): this rebinds an attribute on the class, so the outcome of one validation
    # is stored on the class rather than on the URL being validated - confirm that is
    # acceptable when validations run concurrently.
    cls.hidden_parts = set()

@classmethod
def apply_default_parts(cls, parts: 'Parts') -> 'Parts':
    """
    Fill every empty part of `parts` with the default from `get_default_parts`.

    `parts` is modified in place and the same dict is returned. A part counts as empty
    when its current value is falsy (`None` or `''`); parts that already hold a value
    are left untouched.
    """
    defaults = cls.get_default_parts(parts)
    for name in defaults:
        if parts[name]:  # type: ignore[misc]
            # the URL supplied this part itself, keep it
            continue
        parts[name] = defaults[name]  # type: ignore[misc]
    return parts

def __repr__(self) -> str:
    """Return `ClassName('<url>', slot=value, ...)`, listing only the slots that are not None."""
    shown = []
    for slot in self.__slots__:
        slot_value = getattr(self, slot)
        if slot_value is not None:
            shown.append(f'{slot}={slot_value!r}')
    details = ', '.join(shown)
    return f'{self.__class__.__name__}({super().__repr__()}, {details})'
Expand All @@ -290,12 +324,21 @@ class AnyHttpUrl(AnyUrl):
allowed_schemes = {'http', 'https'}


class HttpUrl(AnyUrl):
allowed_schemes = {'http', 'https'}
class HttpUrl(AnyHttpUrl):
tld_required = True
# https://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url-in-different-browsers
max_length = 2083

@staticmethod
def get_default_parts(parts: 'Parts') -> 'Parts':
    """Default the port from the scheme: '80' for http, '443' for any other scheme."""
    if parts['scheme'] == 'http':
        return {'port': '80'}
    return {'port': '443'}

@classmethod
def hide_parts(cls, original_parts: 'Parts') -> None:
    """
    Hide the port from `build` when the URL being validated did not specify one.

    `validate` calls this before `apply_default_parts`, so at this point `original_parts`
    still reflects what the user wrote. A port that is filled in afterwards from the
    scheme default must not appear in a rebuilt URL, whereas an explicitly given port
    must be kept.
    """
    super().hide_parts(original_parts)
    # `original_parts` comes from `re.Match.groupdict()`, which contains a key for every
    # named group (None when the group did not take part in the match). Key membership is
    # therefore always true and would hide explicit ports too - test the value instead.
    if not original_parts.get('port'):
        cls.hidden_parts.add('port')


class PostgresDsn(AnyUrl):
allowed_schemes = {
Expand All @@ -314,17 +357,24 @@ class PostgresDsn(AnyUrl):
class RedisDsn(AnyUrl):
allowed_schemes = {'redis', 'rediss'}

@classmethod
def validate_parts(cls, parts: Dict[str, str]) -> Dict[str, str]:
defaults = {
@staticmethod
def get_default_parts(parts: 'Parts') -> 'Parts':
    """
    Redis defaults: host `localhost`, port `6379`, database path `/0`.

    The domain default is left empty when the URL already names its host by IPv4 or IPv6
    address, so that the IP address remains the host.
    """
    has_ip_host = bool(parts['ipv4'] or parts['ipv6'])
    defaults: 'Parts' = {
        'domain': '' if has_ip_host else 'localhost',
        'port': '6379',
        'path': '/0',
    }
    return defaults
for key, value in defaults.items():
if not parts[key]:
parts[key] = value
return super().validate_parts(parts)


class KafkaDsn(AnyUrl):
    """URL type for Kafka brokers: only the `kafka` scheme is accepted."""

    allowed_schemes = {'kafka'}

    @staticmethod
    def get_default_parts(parts: 'Parts') -> 'Parts':
        """
        Kafka defaults: host `localhost` and port `9092`.

        The `localhost` default is only offered when the URL carries no IPv4/IPv6 host.
        `validate_host` checks `domain` before the IP parts, so an unconditional domain
        default would replace an explicitly given IP address as the reported host.
        """
        has_ip_host = bool(parts.get('ipv4') or parts.get('ipv6'))
        return {
            'domain': '' if has_ip_host else 'localhost',
            'port': '9092',
        }


def stricturl(
Expand Down
74 changes: 73 additions & 1 deletion tests/test_networks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import pytest

from pydantic import AnyUrl, BaseModel, EmailStr, HttpUrl, NameEmail, PostgresDsn, RedisDsn, ValidationError, stricturl
from pydantic import (
AnyUrl,
BaseModel,
EmailStr,
HttpUrl,
KafkaDsn,
NameEmail,
PostgresDsn,
RedisDsn,
ValidationError,
stricturl,
)
from pydantic.networks import validate_email

try:
Expand Down Expand Up @@ -332,6 +343,45 @@ class Model(BaseModel):
assert Model(v=input).v.tld == output


def test_get_default_parts():
    """Parts returned by a custom `get_default_parts` fill in whatever the input URL omits."""

    class MyConnectionString(AnyUrl):
        @staticmethod
        def get_default_parts(parts):
            # overriding this hook lets a custom connection-string type supply its own
            # fallback credentials
            return dict(user='admin', password='123')

    class C(BaseModel):
        connection: MyConnectionString

    model = C(connection='protocol://service:8080')
    connection = model.connection
    assert connection == 'protocol://admin:123@service:8080'
    assert connection.user == 'admin'
    assert connection.password == '123'


@pytest.mark.parametrize(
    'url,port',
    [
        ('https://www.example.com', '443'),
        ('https://www.example.com:443', '443'),
        ('https://www.example.com:8089', '8089'),
        ('http://www.example.com', '80'),
        ('http://www.example.com:80', '80'),
        ('http://www.example.com:8080', '8080'),
    ],
)
def test_http_urls_default_port(url, port):
    """`HttpUrl.port` falls back to the scheme default while the URL string stays as given."""

    class Model(BaseModel):
        v: HttpUrl

    parsed = Model(v=url).v
    assert parsed.port == port
    assert parsed == url


def test_postgres_dsns():
class Model(BaseModel):
a: PostgresDsn
Expand Down Expand Up @@ -388,6 +438,28 @@ class Model(BaseModel):
assert m.a.path == '/0'


def test_kafka_dsns():
    """`KafkaDsn` fills in host and port defaults and rejects non-kafka schemes."""

    class Model(BaseModel):
        a: KafkaDsn

    # nothing but the scheme: host and port both come from the defaults
    bare = Model(a='kafka://').a
    assert bare.scheme == 'kafka'
    assert bare.host == 'localhost'
    assert bare.port == '9092'
    assert bare == 'kafka://localhost:9092'

    # host given, port defaulted
    host_only = Model(a='kafka://kafka1').a
    assert host_only == 'kafka://kafka1:9092'

    # any other scheme is a validation error
    with pytest.raises(ValidationError) as exc_info:
        Model(a='http://example.org')
    first_error = exc_info.value.errors()[0]
    assert first_error['type'] == 'value_error.url.scheme'

    # no credentials in the URL means none on the parsed value
    no_credentials = Model(a='kafka://kafka3:9093').a
    assert no_credentials.user is None
    assert no_credentials.password is None


def test_custom_schemes():
class Model(BaseModel):
v: stricturl(strip_whitespace=False, allowed_schemes={'ws', 'wss'}) # noqa: F821
Expand Down