Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mbeliaev/resp types #535

Merged
merged 1 commit into from Apr 11, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
104 changes: 62 additions & 42 deletions responses/__init__.py
Expand Up @@ -13,8 +13,11 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import Union
from warnings import warn

Expand Down Expand Up @@ -54,6 +57,7 @@

# Block of type annotations
_Body = Union[str, BaseException, "Response", BufferedReader, bytes]
_MatcherIterable = Iterable[Callable[[Any], Callable[..., Any]]]

Call = namedtuple("Call", ["request", "response"])
_real_send = HTTPAdapter.send
Expand Down Expand Up @@ -275,7 +279,7 @@ def _handle_body(

data = BytesIO(body)

def is_closed():
def is_closed() -> bool:
"""
Real Response uses HTTPResponse as body object.
Thus, when method is_closed is called first to check if there is any more
Expand Down Expand Up @@ -305,23 +309,29 @@ def is_closed():


class BaseResponse(object):
passthrough = False
content_type = None
headers = None
stream = False
passthrough: bool = False
content_type: Optional[str] = None
headers: Optional[Mapping[str, str]] = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this also be the list of tuples? Or has that been normalized by this point.

headers=(('X-CSRF', 'random-string'), ('Content-Length', 4),)

is the format I'm thinking of.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@markstory
thing becomes a bit more complecated since we use HTTPHeaderDict

it has type:
class HTTPHeaderDict(MutableMapping[str, str]):, see
https://github.com/urllib3/urllib3/blob/e16beb210c03c6f5ce4e0908bddb6556442b6a37/src/urllib3/_collections.py#L211

then in our implementatino here, we extend it:

headers.extend(self.headers)

and extend takes either HTTPHeaderDict, Mapping or Iterable:
https://github.com/urllib3/urllib3/blob/e16beb210c03c6f5ce4e0908bddb6556442b6a37/src/urllib3/_collections.py#L334

so, generally list of tuples should be supportes and standard dict as well

stream: bool = False

def __init__(self, method, url, match_querystring=None, match=()):
self.method = method
def __init__(
self,
method: str,
url: "Union[Pattern[str], str]",
match_querystring: Union[bool, object] = None,
match: "_MatcherIterable" = (),
) -> None:
self.method: str = method
# ensure the url has a default path set if the url is a string
self.url = _ensure_url_default_path(url)
self.url: "Union[Pattern[str], str]" = _ensure_url_default_path(url)

if self._should_match_querystring(match_querystring):
match = tuple(match) + (_query_string_matcher(urlsplit(self.url).query),)

self.match = match
self.call_count = 0
self.match: "_MatcherIterable" = match
self.call_count: int = 0

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, BaseResponse):
return False

Expand All @@ -336,10 +346,12 @@ def __eq__(self, other):

return self_url == other_url

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

def _should_match_querystring(self, match_querystring_argument):
def _should_match_querystring(
self, match_querystring_argument: Union[bool, object]
) -> Union[bool, object]:
if isinstance(self.url, Pattern):
# the old default from <= 0.9.0
return False
Expand All @@ -358,7 +370,7 @@ def _should_match_querystring(self, match_querystring_argument):

return bool(urlsplit(self.url).query)

def _url_matches(self, url, other):
def _url_matches(self, url: "Union[Pattern[str], str]", other: str) -> bool:
if isinstance(url, str):
if _has_unicode(url):
url = _clean_unicode(url)
Expand All @@ -372,26 +384,28 @@ def _url_matches(self, url, other):
return False

@staticmethod
def _req_attr_matches(match, request):
def _req_attr_matches(
match: "_MatcherIterable", request: "PreparedRequest"
) -> Tuple[bool, str]:
for matcher in match:
valid, reason = matcher(request)
if not valid:
return False, reason

return True, ""

def get_headers(self):
def get_headers(self) -> HTTPHeaderDict:
headers = HTTPHeaderDict() # Duplicate headers are legal
if self.content_type is not None:
headers["Content-Type"] = self.content_type
if self.headers:
headers.extend(self.headers)
return headers

def get_response(self, request):
def get_response(self, request: "PreparedRequest") -> None:
raise NotImplementedError

def matches(self, request):
def matches(self, request: "PreparedRequest") -> Tuple[bool, str]:
if request.method != self.method:
return False, "Method does not match"

Expand All @@ -408,17 +422,17 @@ def matches(self, request):
class Response(BaseResponse):
def __init__(
self,
method,
url,
body="",
json=None,
status=200,
headers=None,
stream=None,
content_type=_UNSET,
auto_calculate_content_length=False,
method: str,
url: "Union[Pattern[str], str]",
body: _Body = "",
json: Optional[Any] = None,
status: int = 200,
headers: Optional[Mapping[str, str]] = None,
stream: bool = None,
content_type: Optional[str] = _UNSET,
auto_calculate_content_length: bool = False,
**kwargs,
):
) -> None:
# if we were passed a `json` argument,
# override the body and content_type
if json is not None:
Expand All @@ -433,22 +447,22 @@ def __init__(
else:
content_type = "text/plain"

self.body = body
self.status = status
self.headers = headers
self.body: _Body = body
self.status: int = status
self.headers: Optional[Mapping[str, str]] = headers

if stream is not None:
warn(
"stream argument is deprecated. Use stream parameter in request directly",
DeprecationWarning,
)

self.stream = stream
self.content_type = content_type
self.auto_calculate_content_length = auto_calculate_content_length
self.stream: bool = stream
self.content_type: Optional[str] = content_type
self.auto_calculate_content_length: bool = auto_calculate_content_length
super().__init__(method, url, **kwargs)

def get_response(self, request):
def get_response(self, request: "PreparedRequest") -> HTTPResponse:
if self.body and isinstance(self.body, Exception):
raise self.body

Expand All @@ -473,7 +487,7 @@ def get_response(self, request):
preload_content=False,
)

def __repr__(self):
def __repr__(self) -> str:
return (
"<Response(url='{url}' status={status} "
"content_type='{content_type}' headers='{headers}')>".format(
Expand All @@ -487,20 +501,26 @@ def __repr__(self):

class CallbackResponse(BaseResponse):
def __init__(
self, method, url, callback, stream=None, content_type="text/plain", **kwargs
):
self,
method: str,
url: "Union[Pattern[str], str]",
callback: Callable[[Any], Any],
stream: bool = None,
content_type: Optional[str] = "text/plain",
**kwargs,
) -> None:
self.callback = callback

if stream is not None:
warn(
"stream argument is deprecated. Use stream parameter in request directly",
DeprecationWarning,
)
self.stream = stream
self.content_type = content_type
self.stream: bool = stream
self.content_type: Optional[str] = content_type
super().__init__(method, url, **kwargs)

def get_response(self, request):
def get_response(self, request: "PreparedRequest") -> HTTPResponse:
headers = self.get_headers()

result = self.callback(request)
Expand Down Expand Up @@ -538,7 +558,7 @@ def get_response(self, request):


class PassthroughResponse(BaseResponse):
passthrough = True
passthrough: bool = True


class OriginalResponseShim(object):
Expand Down