Full typing for scrapy/spiders. (#6356)
wRAR committed May 13, 2024
1 parent 4ed5c5a commit b8e333c
Showing 6 changed files with 174 additions and 89 deletions.
6 changes: 3 additions & 3 deletions scrapy/spiders/__init__.py
@@ -39,7 +39,7 @@ class Spider(object_ref):

def __init__(self, name: Optional[str] = None, **kwargs: Any):
if name is not None:
self.name = name
self.name: str = name
elif not getattr(self, "name", None):
raise ValueError(f"{type(self).__name__} must have a name")
self.__dict__.update(kwargs)
@@ -67,8 +67,8 @@ def from_crawler(cls, crawler: Crawler, *args: Any, **kwargs: Any) -> Self:
return spider

def _set_crawler(self, crawler: Crawler) -> None:
self.crawler = crawler
self.settings = crawler.settings
self.crawler: Crawler = crawler
self.settings: BaseSettings = crawler.settings
crawler.signals.connect(self.close, signals.spider_closed)

def start_requests(self) -> Iterable[Request]:
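For illustration, a minimal subclass sketch of what the new annotations on Spider give a type checker; the spider name, start URL, selector, and settings key below are invented for this example and are not part of the commit:

from scrapy import Spider


class QuotesSpider(Spider):
    name = "quotes"  # the base class now declares name as str
    start_urls = ["https://quotes.toscrape.com/"]

    def parse(self, response):
        # self.settings is annotated as BaseSettings, so a type checker
        # knows accessors such as getbool() exist on it.
        log_enabled = self.settings.getbool("LOG_ENABLED")
        for text in response.css("div.quote span.text::text").getall():
            yield {"text": text, "log_enabled": log_enabled}
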
140 changes: 96 additions & 44 deletions scrapy/spiders/crawl.py
@@ -8,9 +8,27 @@
from __future__ import annotations

import copy
from typing import TYPE_CHECKING, AsyncIterable, Awaitable, Sequence
from typing import (
TYPE_CHECKING,
Any,
AsyncIterable,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
TypeVar,
Union,
cast,
)

from twisted.python.failure import Failure

from scrapy.http import HtmlResponse, Request, Response
from scrapy.link import Link
from scrapy.linkextractors import LinkExtractor
from scrapy.spiders import Spider
from scrapy.utils.asyncgen import collect_asyncgen
@@ -20,20 +38,32 @@
# typing.Self requires Python 3.11
from typing_extensions import Self

from scrapy.crawler import Crawler

def _identity(x):

_T = TypeVar("_T")
ProcessLinksT = Callable[[List[Link]], List[Link]]
ProcessRequestT = Callable[[Request, Response], Optional[Request]]


def _identity(x: _T) -> _T:
return x


def _identity_process_request(request, response):
def _identity_process_request(
request: Request, response: Response
) -> Optional[Request]:
return request


def _get_method(method, spider):
def _get_method(
method: Union[Callable, str, None], spider: Spider
) -> Optional[Callable]:
if callable(method):
return method
if isinstance(method, str):
return getattr(spider, method, None)
return None


_default_link_extractor = LinkExtractor()
@@ -42,84 +72,104 @@ def _get_method(method, spider):
class Rule:
def __init__(
self,
link_extractor=None,
callback=None,
cb_kwargs=None,
follow=None,
process_links=None,
process_request=None,
errback=None,
link_extractor: Optional[LinkExtractor] = None,
callback: Union[Callable, str, None] = None,
cb_kwargs: Optional[Dict[str, Any]] = None,
follow: Optional[bool] = None,
process_links: Union[ProcessLinksT, str, None] = None,
process_request: Union[ProcessRequestT, str, None] = None,
errback: Union[Callable[[Failure], Any], str, None] = None,
):
self.link_extractor = link_extractor or _default_link_extractor
self.callback = callback
self.errback = errback
self.cb_kwargs = cb_kwargs or {}
self.process_links = process_links or _identity
self.process_request = process_request or _identity_process_request
self.follow = follow if follow is not None else not callback

def _compile(self, spider):
self.link_extractor: LinkExtractor = link_extractor or _default_link_extractor
self.callback: Union[Callable, str, None] = callback
self.errback: Union[Callable[[Failure], Any], str, None] = errback
self.cb_kwargs: Dict[str, Any] = cb_kwargs or {}
self.process_links: Union[ProcessLinksT, str] = process_links or _identity
self.process_request: Union[ProcessRequestT, str] = (
process_request or _identity_process_request
)
self.follow: bool = follow if follow is not None else not callback

def _compile(self, spider: Spider) -> None:
# this replaces method names with methods and we can't express this in type hints
self.callback = _get_method(self.callback, spider)
self.errback = _get_method(self.errback, spider)
self.process_links = _get_method(self.process_links, spider)
self.process_request = _get_method(self.process_request, spider)
self.errback = cast(Callable[[Failure], Any], _get_method(self.errback, spider))
self.process_links = cast(
ProcessLinksT, _get_method(self.process_links, spider)
)
self.process_request = cast(
ProcessRequestT, _get_method(self.process_request, spider)
)


class CrawlSpider(Spider):
rules: Sequence[Rule] = ()
_rules: List[Rule]
_follow_links: bool

def __init__(self, *a, **kw):
def __init__(self, *a: Any, **kw: Any):
super().__init__(*a, **kw)
self._compile_rules()

def _parse(self, response, **kwargs):
def _parse(self, response: Response, **kwargs: Any) -> Any:
return self._parse_response(
response=response,
callback=self.parse_start_url,
cb_kwargs=kwargs,
follow=True,
)

def parse_start_url(self, response, **kwargs):
def parse_start_url(self, response: Response, **kwargs: Any) -> Any:
return []

def process_results(self, response: Response, results: list):
def process_results(self, response: Response, results: Any) -> Any:
return results

def _build_request(self, rule_index, link):
def _build_request(self, rule_index: int, link: Link) -> Request:
return Request(
url=link.url,
callback=self._callback,
errback=self._errback,
meta={"rule": rule_index, "link_text": link.text},
)

def _requests_to_follow(self, response):
def _requests_to_follow(self, response: Response) -> Iterable[Optional[Request]]:
if not isinstance(response, HtmlResponse):
return
seen = set()
seen: Set[Link] = set()
for rule_index, rule in enumerate(self._rules):
links = [
links: List[Link] = [
lnk
for lnk in rule.link_extractor.extract_links(response)
if lnk not in seen
]
for link in rule.process_links(links):
for link in cast(ProcessLinksT, rule.process_links)(links):
seen.add(link)
request = self._build_request(rule_index, link)
yield rule.process_request(request, response)
yield cast(ProcessRequestT, rule.process_request)(request, response)

def _callback(self, response, **cb_kwargs):
rule = self._rules[response.meta["rule"]]
def _callback(self, response: Response, **cb_kwargs: Any) -> Any:
rule = self._rules[cast(int, response.meta["rule"])]
return self._parse_response(
response, rule.callback, {**rule.cb_kwargs, **cb_kwargs}, rule.follow
response,
cast(Callable, rule.callback),
{**rule.cb_kwargs, **cb_kwargs},
rule.follow,
)

def _errback(self, failure):
rule = self._rules[failure.request.meta["rule"]]
return self._handle_failure(failure, rule.errback)
def _errback(self, failure: Failure) -> Iterable[Any]:
rule = self._rules[cast(int, failure.request.meta["rule"])] # type: ignore[attr-defined]
return self._handle_failure(
failure, cast(Callable[[Failure], Any], rule.errback)
)

async def _parse_response(self, response, callback, cb_kwargs, follow=True):
async def _parse_response(
self,
response: Response,
callback: Optional[Callable],
cb_kwargs: Dict[str, Any],
follow: bool = True,
) -> AsyncIterable[Any]:
if callback:
cb_res = callback(response, **cb_kwargs) or ()
if isinstance(cb_res, AsyncIterable):
@@ -134,21 +184,23 @@ async def _parse_response(self, response, callback, cb_kwargs, follow=True):
for request_or_item in self._requests_to_follow(response):
yield request_or_item

def _handle_failure(self, failure, errback):
def _handle_failure(
self, failure: Failure, errback: Optional[Callable[[Failure], Any]]
) -> Iterable[Any]:
if errback:
results = errback(failure) or ()
yield from iterate_spider_output(results)

def _compile_rules(self):
def _compile_rules(self) -> None:
self._rules = []
for rule in self.rules:
self._rules.append(copy.copy(rule))
self._rules[-1]._compile(self)

@classmethod
def from_crawler(cls, crawler, *args, **kwargs) -> Self:
def from_crawler(cls, crawler: Crawler, *args: Any, **kwargs: Any) -> Self:
spider = super().from_crawler(crawler, *args, **kwargs)
spider._follow_links = crawler.settings.getbool( # type: ignore[attr-defined]
spider._follow_links = crawler.settings.getbool(
"CRAWLSPIDER_FOLLOW_LINKS", True
)
return spider
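
For illustration, a minimal CrawlSpider sketch that relies on the typed Rule signature above; callbacks can still be passed either as callables or as method-name strings, and Rule._compile() later swaps the strings for bound methods, which is why the attributes keep their Union types and are cast after compiling. The spider name, site, and URL patterns here are invented for this example:

from scrapy.linkextractors import LinkExtractor
from scrapy.spiders import CrawlSpider, Rule


class BooksSpider(CrawlSpider):
    name = "books"
    start_urls = ["https://books.toscrape.com/"]

    rules = (
        # No callback: follow defaults to True for this rule.
        Rule(LinkExtractor(allow=r"/catalogue/category/")),
        # callback given as a str; _compile() replaces it with the bound method.
        Rule(LinkExtractor(allow=r"/catalogue/"), callback="parse_item"),
    )

    def parse_item(self, response):
        yield {"title": response.css("h1::text").get(), "url": response.url}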
48 changes: 30 additions & 18 deletions scrapy/spiders/feed.py
@@ -5,7 +5,10 @@
See documentation in docs/topics/spiders.rst
"""

from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

from scrapy.exceptions import NotConfigured, NotSupported
from scrapy.http import Response, TextResponse
from scrapy.selector import Selector
from scrapy.spiders import Spider
from scrapy.utils.iterators import csviter, xmliter_lxml
@@ -22,11 +25,13 @@ class XMLFeedSpider(Spider):
use iternodes, since it's faster and cleaner.
"""

iterator = "iternodes"
itertag = "item"
namespaces = ()
iterator: str = "iternodes"
itertag: str = "item"
namespaces: Sequence[Tuple[str, str]] = ()

def process_results(self, response, results):
def process_results(
self, response: Response, results: Iterable[Any]
) -> Iterable[Any]:
"""This overridable method is called for each result (item or request)
returned by the spider, and it's intended to perform any last-minute
processing required before returning the results to the framework core,
@@ -36,20 +41,20 @@ def process_results(self, response, results):
"""
return results

def adapt_response(self, response):
def adapt_response(self, response: Response) -> Response:
"""You can override this function in order to make any changes you want
to the feed before parsing it. This function must return a
response.
"""
return response

def parse_node(self, response, selector):
def parse_node(self, response: Response, selector: Selector) -> Any:
"""This method must be overridden with your custom spider functionality"""
if hasattr(self, "parse_item"): # backward compatibility
return self.parse_item(response, selector)
raise NotImplementedError

def parse_nodes(self, response, nodes):
def parse_nodes(self, response: Response, nodes: Iterable[Selector]) -> Any:
"""This method is called for the nodes matching the provided tag name
(itertag). Receives the response and a Selector for each node.
Overriding this method is mandatory. Otherwise, your spider won't work.
@@ -61,20 +66,25 @@ def parse_nodes(self, response, nodes):
ret = iterate_spider_output(self.parse_node(response, selector))
yield from self.process_results(response, ret)

def _parse(self, response, **kwargs):
def _parse(self, response: Response, **kwargs: Any) -> Any:
if not hasattr(self, "parse_node"):
raise NotConfigured(
"You must define parse_node method in order to scrape this XML feed"
)

response = self.adapt_response(response)
nodes: Iterable[Selector]
if self.iterator == "iternodes":
nodes = self._iternodes(response)
elif self.iterator == "xml":
if not isinstance(response, TextResponse):
raise ValueError("Response content isn't text")
selector = Selector(response, type="xml")
self._register_namespaces(selector)
nodes = selector.xpath(f"//{self.itertag}")
elif self.iterator == "html":
if not isinstance(response, TextResponse):
raise ValueError("Response content isn't text")
selector = Selector(response, type="html")
self._register_namespaces(selector)
nodes = selector.xpath(f"//{self.itertag}")
@@ -83,12 +93,12 @@ def _parse(self, response, **kwargs):

return self.parse_nodes(response, nodes)

def _iternodes(self, response):
def _iternodes(self, response: Response) -> Iterable[Selector]:
for node in xmliter_lxml(response, self.itertag):
self._register_namespaces(node)
yield node

def _register_namespaces(self, selector):
def _register_namespaces(self, selector: Selector) -> None:
for prefix, uri in self.namespaces:
selector.register_namespace(prefix, uri)

@@ -102,27 +112,29 @@ class CSVFeedSpider(Spider):
and the file's headers.
"""

delimiter = (
delimiter: Optional[str] = (
None # When this is None, python's csv module's default delimiter is used
)
quotechar = (
quotechar: Optional[str] = (
None # When this is None, python's csv module's default quotechar is used
)
headers = None
headers: Optional[List[str]] = None

def process_results(self, response, results):
def process_results(
self, response: Response, results: Iterable[Any]
) -> Iterable[Any]:
"""This method has the same purpose as the one in XMLFeedSpider"""
return results

def adapt_response(self, response):
def adapt_response(self, response: Response) -> Response:
"""This method has the same purpose as the one in XMLFeedSpider"""
return response

def parse_row(self, response, row):
def parse_row(self, response: Response, row: Dict[str, str]) -> Any:
"""This method must be overridden with your custom spider functionality"""
raise NotImplementedError

def parse_rows(self, response):
def parse_rows(self, response: Response) -> Any:
"""Receives a response and a dict (representing each row) with a key for
each provided (or detected) header of the CSV file. This spider also
gives the opportunity to override adapt_response and
@@ -135,7 +147,7 @@ def parse_rows(self, response):
ret = iterate_spider_output(self.parse_row(response, row))
yield from self.process_results(response, ret)

def _parse(self, response, **kwargs):
def _parse(self, response: Response, **kwargs: Any) -> Any:
if not hasattr(self, "parse_row"):
raise NotConfigured(
"You must define parse_row method in order to scrape this CSV feed"
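For illustration, a minimal CSVFeedSpider sketch exercising the annotated class attributes and the Dict[str, str] row passed to parse_row; the spider name, URL, and headers below are invented for this example:

from scrapy.spiders import CSVFeedSpider


class ProductsSpider(CSVFeedSpider):
    name = "products_csv"
    start_urls = ["https://example.com/products.csv"]
    delimiter = ";"                      # Optional[str]; None uses the csv module default
    headers = ["id", "name", "price"]    # Optional[List[str]]

    def parse_row(self, response, row):
        # row is a Dict[str, str] keyed by the declared headers.
        yield {"id": row["id"], "name": row["name"], "price": float(row["price"])}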
