Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Sanitizing search entry titles #3560

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
72 changes: 37 additions & 35 deletions mkdocs/contrib/search/search_index.py
Expand Up @@ -10,7 +10,6 @@

if TYPE_CHECKING:
from mkdocs.structure.pages import Page
from mkdocs.structure.toc import AnchorLink, TableOfContents

try:
from lunr import lunr # type: ignore
Expand All @@ -20,6 +19,7 @@
haslunrpy = False

log = logging.getLogger(__name__)
WHITESPACE_RE = re.compile(r'\s+')


class SearchIndex:
Expand All @@ -32,19 +32,6 @@ def __init__(self, **config) -> None:
self._entries: list[dict] = []
self.config = config

def _find_toc_by_id(self, toc, id_: str | None) -> AnchorLink | None:
    """
    Depth-first search of a table of contents for the anchor
    whose HTML ID equals `id_`; return it, or None when absent.
    """
    # Explicit worklist gives the same pre-order traversal as the
    # recursive form: an item is checked before any of its children,
    # and a subtree is exhausted before the next sibling is visited.
    pending = list(toc)
    while pending:
        item = pending.pop(0)
        if item.id == id_:
            return item
        pending = list(item.children) + pending
    return None

def _add_entry(self, title: str | None, text: str, loc: str) -> None:
"""A simple wrapper to add an entry, dropping bad characters."""
text = text.replace('\u00a0', ' ')
Expand Down Expand Up @@ -76,21 +63,17 @@ def add_entry_from_context(self, page: Page) -> None:

if self.config['indexing'] in ['full', 'sections']:
for section in parser.data:
self.create_entry_for_section(section, page.toc, url)
self.create_entry_for_section(section, url)

def create_entry_for_section(
self, section: ContentSection, toc: TableOfContents, abs_url: str
) -> None:
def create_entry_for_section(self, section: ContentSection, abs_url: str) -> None:
"""
Given a section on the page, the table of contents and
the absolute url for the page create an entry in the
index.
Given a section of a page and the absolute url for the page
create an entry in the index.
"""
toc_item = self._find_toc_by_id(toc, section.id)

text = ' '.join(section.text) if self.config['indexing'] == 'full' else ''
if toc_item is not None:
self._add_entry(title=toc_item.title, text=text, loc=abs_url + toc_item.url)
if section.id is not None:
text = ' '.join(section.text) if self.config['indexing'] == 'full' else ''
title = WHITESPACE_RE.sub(' ', section.title).strip()
self._add_entry(title=title, text=text, loc=f'{abs_url}#{section.id}')

def generate_search_index(self) -> str:
"""Python to json conversion."""
Expand Down Expand Up @@ -153,11 +136,14 @@ def __init__(
) -> None:
self.text = text or []
self.id = id_
self.title = title
self.title = title or ''

def __eq__(self, other):
    """Two sections are equal when text, id and title all match."""
    # Compared in the same order as the original `and` chain so that
    # short-circuiting (and any AttributeError) happens identically.
    return all(
        getattr(self, name) == getattr(other, name)
        for name in ('text', 'id', 'title')
    )

def __repr__(self):
    """Debug representation mirroring the constructor arguments."""
    # Fix: the format string was missing its closing parenthesis,
    # producing e.g. "ContentSection(text=[], id='x', title='T" .
    return f"{type(self).__name__}(text={self.text!r}, id={self.id!r}, title={self.title!r})"


_HEADER_TAGS = tuple(f"h{x}" for x in range(1, 7))

Expand All @@ -175,26 +161,37 @@ def __init__(self, *args, **kwargs) -> None:
self.data: list[ContentSection] = []
self.section: ContentSection | None = None
self.is_header_tag = False
self.is_permalink = False
self._stripped_html: list[str] = []

def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
    """Called at the start of every HTML tag.

    Opens a new ContentSection for each heading tag, and flags
    permalink anchors inside headings so their text can be skipped.
    """
    # An <a class="headerlink"> inside a heading is the permalink;
    # mark it so handle_data() ignores its text (e.g. the pilcrow).
    # Any <a> is never a heading itself, so returning here is safe.
    if self.is_header_tag and tag == 'a' and len(attrs):
        atts = dict(attrs)
        if 'headerlink' in (atts.get('class') or '').split():
            self.is_permalink = True
        return

    # We only care about the opening tag for headings.
    if tag not in _HEADER_TAGS:
        return

    # We are dealing with a new header: create a new section for it
    # and record the heading's ID, if it has one. (The superseded
    # per-attribute scan for "id" duplicated this lookup and is removed.)
    atts = dict(attrs)
    self.is_header_tag = True
    self.section = ContentSection()
    self.section.id = atts.get('id')
    self.data.append(self.section)

def handle_endtag(self, tag: str) -> None:
"""Called at the end of every HTML tag."""
# Check for permalinks
if self.is_permalink and tag == 'a':
self.is_permalink = False
return

# We only care about the opening tag for headings.
if tag not in _HEADER_TAGS:
return
Expand All @@ -203,6 +200,10 @@ def handle_endtag(self, tag: str) -> None:

def handle_data(self, data: str) -> None:
"""Called for the text contents of each tag."""
# Do not retain permalink text.
if self.is_permalink:
return

self._stripped_html.append(data)

if self.section is None:
Expand All @@ -212,12 +213,13 @@ def handle_data(self, data: str) -> None:
# overall page entry in the search. So just skip it.
return

# If this is a header, then the data is the title.
# Otherwise it is content of something under that header
# section.
if self.is_header_tag:
self.section.title = data
# Write text data to title, being sure not to overwrite
# text from previous children of heading. Text data
# retains its whitespace and so none is added here.
self.section.title += data
else:
# Write text data for elements under a heading section.
self.section.text.append(data.rstrip('\n'))

@property
Expand Down
64 changes: 37 additions & 27 deletions mkdocs/tests/search_tests.py
Expand Up @@ -9,8 +9,7 @@
from mkdocs.contrib.search import search_index
from mkdocs.structure.files import File
from mkdocs.structure.pages import Page
from mkdocs.structure.toc import get_toc
from mkdocs.tests.base import dedent, get_markdown_toc, load_config
from mkdocs.tests.base import dedent, load_config


def strip_whitespace(string):
Expand Down Expand Up @@ -283,7 +282,19 @@ def test_content_parser(self):
parser.close()

self.assertEqual(
parser.data, [search_index.ContentSection(text=["TEST"], id_="title", title="Title")]
parser.data,
[search_index.ContentSection(text=["TEST"], id_="title", title="Title")],
)

def test_content_parser_header_has_child(self):
    """A heading with inline child elements keeps all of its text."""
    html = '<h1 id="title">Title <span>title</span> TITLE</h1>TEST'
    expected = [
        search_index.ContentSection(text=["TEST"], id_="title", title="Title title TITLE")
    ]

    parser = search_index.ContentParser()
    parser.feed(html)
    parser.close()

    self.assertEqual(parser.data, expected)

def test_content_parser_no_id(self):
Expand All @@ -293,7 +304,8 @@ def test_content_parser_no_id(self):
parser.close()

self.assertEqual(
parser.data, [search_index.ContentSection(text=["TEST"], id_=None, title="Title")]
parser.data,
[search_index.ContentSection(text=["TEST"], id_=None, title="Title")],
)

def test_content_parser_content_before_header(self):
Expand All @@ -303,7 +315,8 @@ def test_content_parser_content_before_header(self):
parser.close()

self.assertEqual(
parser.data, [search_index.ContentSection(text=["TEST"], id_=None, title="Title")]
parser.data,
[search_index.ContentSection(text=["TEST"], id_=None, title="Title")],
)

def test_content_parser_no_sections(self):
Expand All @@ -313,30 +326,30 @@ def test_content_parser_no_sections(self):

self.assertEqual(parser.data, [])

def test_find_toc_by_id(self):
"""Test finding the relevant TOC item by the tag ID."""
index = search_index.SearchIndex()
def test_content_parser_with_permalink(self):
parser = search_index.ContentParser()

md = dedent(
"""
# Heading 1
## Heading 2
### Heading 3
"""
parser.feed(
'<h1 id="title">Title<a class="headerlink" href="#title" title="Permanent link">&para;</a></h1>TEST'
)
toc = get_toc(get_markdown_toc(md))
parser.close()

toc_item = index._find_toc_by_id(toc, "heading-1")
self.assertEqual(toc_item.url, "#heading-1")
self.assertEqual(toc_item.title, "Heading 1")
self.assertEqual(
parser.data,
[search_index.ContentSection(text=["TEST"], id_="title", title="Title")],
)

toc_item2 = index._find_toc_by_id(toc, "heading-2")
self.assertEqual(toc_item2.url, "#heading-2")
self.assertEqual(toc_item2.title, "Heading 2")
def test_content_parser_with_nonpermalink(self):
parser = search_index.ContentParser()

# Ensure only the whole class name is being matched.
parser.feed('<h1 id="title">Title <a class="fooheaderlink" href="#">title</a></h1>TEST')
parser.close()

toc_item3 = index._find_toc_by_id(toc, "heading-3")
self.assertEqual(toc_item3.url, "#heading-3")
self.assertEqual(toc_item3.title, "Heading 3")
self.assertEqual(
parser.data,
[search_index.ContentSection(text=["TEST"], id_="title", title="Title title")],
)

def test_create_search_index(self):
html_content = """
Expand Down Expand Up @@ -369,7 +382,6 @@ def test_create_search_index(self):
### Heading 3
"""
)
toc = get_toc(get_markdown_toc(md))

full_content = ''.join(f"Heading{i}Content{i}" for i in range(1, 4))

Expand All @@ -379,7 +391,6 @@ def test_create_search_index(self):
for page in pages:
# Fake page.read_source() and page.render()
page.markdown = md
page.toc = toc
page.content = html_content

index = search_index.SearchIndex(**plugin.config)
Expand Down Expand Up @@ -425,7 +436,6 @@ def test_page(title, filename, config):
## Heading 2
### Heading 3"""
)
test_page.toc = get_toc(get_markdown_toc(test_page.markdown))
return test_page

def validate_full(data, page):
Expand Down