Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
kdeldycke committed Jun 1, 2023
1 parent d597c55 commit ed7ced2
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 17 deletions.
339 changes: 333 additions & 6 deletions click_extra/testing.py
Expand Up @@ -18,16 +18,29 @@
from __future__ import annotations

import os
import sys
import shlex
import subprocess
from pathlib import Path
from textwrap import indent
from typing import Iterable, Mapping, Optional, Union, cast, IO, Any, Mapping, Optional,ContextManager, Literal
from typing import Iterable, Mapping, Optional, Union, cast, IO, Any, Type, Mapping, BinaryIO, Optional,ContextManager, Literal, Sequence, Iterator, Tuple
import os
import inspect
from functools import partial
from pathlib import Path
from contextlib import nullcontext
from unittest.mock import patch
import io
import contextlib
import io
import os
import shlex
import sys
from types import TracebackType

from click import formatting
from click import termui
from click import utils

import click
import click.testing
Expand Down Expand Up @@ -146,19 +159,128 @@ def run_cmd(*args, extra_env: EnvVars | None = None, print_output: bool = True):
We need to collect them to help us identify which extra parameters passed to
``invoke()`` collides with its original signature.
.. note::
.. warning::
This has been `reported upstream to Click project
<https://github.com/pallets/click/issues/2110>`_ but was not considered an issue.
<https://github.com/pallets/click/issues/2110>`_ but has been rejected and not
considered an issue worth fixing.
"""


class BytesIOCopy(io.BytesIO):
"""Patch ``io.BytesIO`` to let the written stream be copied to another.
.. caution::
This has been `proposed upstream to Click project
<https://github.com/pallets/click/pull/2523>`_ but has not been merged yet.
"""

def __init__(self, copy_to: io.BytesIO) -> None:
super().__init__()
self.copy_to = copy_to

def flush(self) -> None:
super().flush()
self.copy_to.flush()

def write(self, b) -> int:
self.copy_to.write(b)
return super().write(b)


class StreamMixer:
"""Mixes ``<stdout>`` and ``<stderr>`` streams if ``mix_stderr=True``.
The result is available in the ``output`` attribute.
If ``mix_stderr=False``, the ``<stdout>`` and ``<stderr>`` streams are kept
independent and the ``output`` is the same as the ``<stdout>`` stream.
.. caution::
This has been `proposed upstream to Click project
<https://github.com/pallets/click/pull/2523>`_ but has not been merged yet.
"""

def __init__(self, mix_stderr: bool) -> None:
if not mix_stderr:
self.stdout = io.BytesIO()
self.stderr = io.BytesIO()
self.output = self.stdout

else:
self.output = io.BytesIO()
self.stdout = BytesIOCopy(copy_to=self.output)
self.stderr = BytesIOCopy(copy_to=self.output)


class ExtraResult(click.testing.Result):
"""Like ``click.testing.Result``, with finer ``<stdout>`` and ``<stderr>`` streams.
.. caution::
This has been `proposed upstream to Click project
<https://github.com/pallets/click/pull/2523>`_ but has not been merged yet.
"""

def __init__(
self,
runner: click.testing.CliRunner,
stdout_bytes: bytes,
stderr_bytes: bytes,
output_bytes: bytes,
return_value: Any,
exit_code: int,
exception: Optional[BaseException],
exc_info: Optional[
Tuple[Type[BaseException], BaseException, TracebackType]
] = None,
):
"""Same as original but adds ``output_bytes`` parameter.
Also makes ``stderr_bytes`` mandatory.
"""
self.output_bytes = output_bytes
super().__init__(
runner=runner,
stdout_bytes=stdout_bytes,
stderr_bytes=stderr_bytes,
return_value=return_value,
exit_code=exit_code,
exception=exception,
exc_info=exc_info,
)

@property
def output(self) -> str:
"""The terminal output as unicode string, as the user would see it.
.. caution::
Contrary to original ``click.testing.Result.output``, it is not a proxy for
``self.stdout``. It now possess its own stream to mix ``<stdout>`` and
``<stderr>`` depending on the ``mix_stderr`` value.
"""
return self.output_bytes.decode(self.runner.charset, "replace").replace(
"\r\n", "\n"
)

@property
def stderr(self) -> str:
"""The standard error as unicode string.
.. caution::
Contrary to original ``click.testing.Result.stderr``, it no longer raise an
exception, and always returns the ``<stderr>`` string.
"""
return self.stderr_bytes.decode(self.runner.charset, "replace").replace(
"\r\n", "\n"
)


class ExtraCliRunner(click.testing.CliRunner):
"""Augment ``click.testing.CliRunner`` with extra features and bug fixes."""

force_color: bool = False
"""Global class attribute to override the ``color`` parameter in ``invoke``.
.. note::
.. info::
This was initially developed to `force the initialization of the runner during
the setup of Sphinx new directives <sphinx#click_extra.sphinx.setup>`_. This
was the only way we found, as to patch some code we had to operate at the class
Expand All @@ -183,6 +305,210 @@ def __init__(
mix_stderr=mix_stderr
)

@contextlib.contextmanager
def isolation(
self,
input: Optional[Union[str, bytes, IO[Any]]] = None,
env: Optional[Mapping[str, Optional[str]]] = None,
color: bool = False,
) -> Iterator[Tuple[io.BytesIO, io.BytesIO, io.BytesIO]]:
"""Copy of ``click.testing.CliRunner.isolation()`` with extra features.
- An additional output stream is returned, which is a mix of ``<stdout>`` and
``<stderr>`` streams if ``mix_stderr=True``.
- Always returns the ``<stderr>`` stream.
.. caution::
This is a hard-copy of the modified ``isolation()`` method `from click#2523 PR
<https://github.com/pallets/click/pull/2523/files#diff-b07fd6fad9f9ea8be5cbcbeaf34c956703b929b2de95c56229e77c328a7c6010>`_
which has not been merged upstream yet.
"""
bytes_input = click.testing.make_input_stream(input, self.charset)
echo_input = None

old_stdin = sys.stdin
old_stdout = sys.stdout
old_stderr = sys.stderr
old_forced_width = formatting.FORCED_WIDTH
formatting.FORCED_WIDTH = 80

env = self.make_env(env)

stream_mixer = StreamMixer(mix_stderr=self.mix_stderr)

if self.echo_stdin:
bytes_input = echo_input = cast(
BinaryIO, click.testing.EchoingStdin(bytes_input, stream_mixer.stdout)
)

sys.stdin = text_input = click.testing._NamedTextIOWrapper(
bytes_input, encoding=self.charset, name="<stdin>", mode="r"
)

if self.echo_stdin:
# Force unbuffered reads, otherwise TextIOWrapper reads a
# large chunk which is echoed early.
text_input._CHUNK_SIZE = 1 # type: ignore

sys.stdout = click.testing._NamedTextIOWrapper(
stream_mixer.stdout, encoding=self.charset, name="<stdout>", mode="w"
)

sys.stderr = click.testing._NamedTextIOWrapper(
stream_mixer.stderr,
encoding=self.charset,
name="<stderr>",
mode="w",
errors="backslashreplace",
)

@click.testing._pause_echo(echo_input)
def visible_input(prompt: Optional[str] = None) -> str:
sys.stdout.write(prompt or "")
val = text_input.readline().rstrip("\r\n")
sys.stdout.write(f"{val}\n")
sys.stdout.flush()
return val

@click.testing._pause_echo(echo_input)
def hidden_input(prompt: Optional[str] = None) -> str:
sys.stdout.write(f"{prompt or ''}\n")
sys.stdout.flush()
return text_input.readline().rstrip("\r\n")

@click.testing._pause_echo(echo_input)
def _getchar(echo: bool) -> str:
char = sys.stdin.read(1)

if echo:
sys.stdout.write(char)

sys.stdout.flush()
return char

default_color = color

def should_strip_ansi(
stream: Optional[IO[Any]] = None, color: Optional[bool] = None
) -> bool:
if color is None:
return not default_color
return not color

old_visible_prompt_func = termui.visible_prompt_func
old_hidden_prompt_func = termui.hidden_prompt_func
old__getchar_func = termui._getchar
old_should_strip_ansi = utils.should_strip_ansi
termui.visible_prompt_func = visible_input
termui.hidden_prompt_func = hidden_input
termui._getchar = _getchar
utils.should_strip_ansi = should_strip_ansi

old_env = {}
try:
for key, value in env.items():
old_env[key] = os.environ.get(key)
if value is None:
try:
del os.environ[key]
except Exception:
pass
else:
os.environ[key] = value
yield (stream_mixer.stdout, stream_mixer.stderr, stream_mixer.output)
finally:
for key, value in old_env.items():
if value is None:
try:
del os.environ[key]
except Exception:
pass
else:
os.environ[key] = value
sys.stdout = old_stdout
sys.stderr = old_stderr
sys.stdin = old_stdin
termui.visible_prompt_func = old_visible_prompt_func
termui.hidden_prompt_func = old_hidden_prompt_func
termui._getchar = old__getchar_func
utils.should_strip_ansi = old_should_strip_ansi
formatting.FORCED_WIDTH = old_forced_width

def invoke2(
self,
cli: click.core.BaseCommand,
args: Optional[Union[str, Sequence[str]]] = None,
input: Optional[Union[str, bytes, IO[Any]]] = None,
env: Optional[Mapping[str, Optional[str]]] = None,
catch_exceptions: bool = True,
color: bool = False,
**extra: Any,
) -> click.testing.Result:
"""Copy of ``click.testing.CliRunner.invoke()`` with
.. caution::
This is a hard-copy of the modified ``invoke()`` method `from click#2523 PR
<https://github.com/pallets/click/pull/2523/files#diff-b07fd6fad9f9ea8be5cbcbeaf34c956703b929b2de95c56229e77c328a7c6010>`_
which has not been merged upstream yet.
"""
exc_info = None
with self.isolation(input=input, env=env, color=color) as outstreams:
return_value = None
exception: Optional[BaseException] = None
exit_code = 0

if isinstance(args, str):
args = shlex.split(args)

try:
prog_name = extra.pop("prog_name")
except KeyError:
prog_name = self.get_default_prog_name(cli)

try:
return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
except SystemExit as e:
exc_info = sys.exc_info()
e_code = cast(Optional[Union[int, Any]], e.code)

if e_code is None:
e_code = 0

if e_code != 0:
exception = e

if not isinstance(e_code, int):
sys.stdout.write(str(e_code))
sys.stdout.write("\n")
e_code = 1

exit_code = e_code

except Exception as e:
if not catch_exceptions:
raise
exception = e
exit_code = 1
exc_info = sys.exc_info()
finally:
sys.stdout.flush()
stdout = outstreams[0].getvalue()
stderr = outstreams[1].getvalue()
output = outstreams[2].getvalue()

return ExtraResult(
runner=self,
stdout_bytes=stdout,
stderr_bytes=stderr,
output_bytes=output,
return_value=return_value,
exit_code=exit_code,
exception=exception,
exc_info=exc_info, # type: ignore
)


def invoke(
self,
cli: click.core.BaseCommand,
Expand Down Expand Up @@ -278,7 +604,7 @@ def invoke(
extra_params_bypass = patch.object(cli, "main", partial(cli.main, **extra_bypass))

with extra_params_bypass:
result = super().invoke(
result = self.invoke2(
cli=cli,
args=args,
input=input,
Expand All @@ -292,11 +618,12 @@ def invoke(
if color is False:
result.stdout_bytes = strip_ansi(result.stdout_bytes)
result.stderr_bytes = strip_ansi(result.stderr_bytes)
result.output_bytes = strip_ansi(result.output_bytes)

print_cli_run(
[self.get_default_prog_name(cli)] + list(args),
result.output,
result.stderr if result.stderr_bytes is not None else "",
result.stderr,
result.exit_code,
)

Expand Down

0 comments on commit ed7ced2

Please sign in to comment.