From c0e7698097d905938cd9ecd1a08d9a2c99c4f887 Mon Sep 17 00:00:00 2001 From: Marco Westerhof Date: Fri, 18 Mar 2022 17:04:19 +0100 Subject: [PATCH] Feature/local extensions (#1240) Co-authored-by: Andrey Shpak --- cookiecutter/main.py | 47 ++++++++++----- cookiecutter/utils.py | 13 ++++ docs/advanced/index.rst | 1 + docs/advanced/local_extensions.rst | 60 +++++++++++++++++++ .../local_extension/cookiecutter.json | 10 ++++ .../local_extensions/__init__.py | 1 + .../local_extension/local_extensions/main.py | 21 +++++++ .../{{cookiecutter.project_slug}}/HISTORY.rst | 8 +++ tests/test_cli.py | 31 ++++++++++ 9 files changed, 178 insertions(+), 14 deletions(-) create mode 100644 docs/advanced/local_extensions.rst create mode 100644 tests/test-extensions/local_extension/cookiecutter.json create mode 100644 tests/test-extensions/local_extension/local_extensions/__init__.py create mode 100644 tests/test-extensions/local_extension/local_extensions/main.py create mode 100644 tests/test-extensions/local_extension/{{cookiecutter.project_slug}}/HISTORY.rst diff --git a/cookiecutter/main.py b/cookiecutter/main.py index 7c9eef731..bc2f262df 100644 --- a/cookiecutter/main.py +++ b/cookiecutter/main.py @@ -4,8 +4,10 @@ The code in this module is also a good example of how to use Cookiecutter as a library rather than a script. """ +from copy import copy import logging import os +import sys from cookiecutter.config import get_user_config from cookiecutter.exceptions import InvalidModeException @@ -73,15 +75,17 @@ def cookiecutter( password=password, directory=directory, ) + import_patch = _patch_import_path_for_repo(repo_dir) template_name = os.path.basename(os.path.abspath(repo_dir)) if replay: - if isinstance(replay, bool): - context = load(config_dict['replay_dir'], template_name) - else: - path, template_name = os.path.split(os.path.splitext(replay)[0]) - context = load(path, template_name) + with import_patch: + if isinstance(replay, bool): + context = load(config_dict['replay_dir'], template_name) + else: + path, template_name = os.path.split(os.path.splitext(replay)[0]) + context = load(path, template_name) else: context_file = os.path.join(repo_dir, 'cookiecutter.json') logger.debug('context_file is %s', context_file) @@ -94,7 +98,8 @@ def cookiecutter( # prompt the user to manually configure at the command line. # except when 'no-input' flag is set - context['cookiecutter'] = prompt_for_config(context, no_input) + with import_patch: + context['cookiecutter'] = prompt_for_config(context, no_input) # include template dir or url in the context dict context['cookiecutter']['_template'] = template @@ -105,17 +110,31 @@ def cookiecutter( dump(config_dict['replay_dir'], template_name, context) # Create project from local context and project template. - result = generate_files( - repo_dir=repo_dir, - context=context, - overwrite_if_exists=overwrite_if_exists, - skip_if_file_exists=skip_if_file_exists, - output_dir=output_dir, - accept_hooks=accept_hooks, - ) + with import_patch: + result = generate_files( + repo_dir=repo_dir, + context=context, + overwrite_if_exists=overwrite_if_exists, + skip_if_file_exists=skip_if_file_exists, + output_dir=output_dir, + accept_hooks=accept_hooks, + ) # Cleanup (if required) if cleanup: rmtree(repo_dir) return result + + +class _patch_import_path_for_repo: + def __init__(self, repo_dir): + self._repo_dir = repo_dir + self._path = None + + def __enter__(self): + self._path = copy(sys.path) + sys.path.append(self._repo_dir) + + def __exit__(self, type, value, traceback): + sys.path = self._path diff --git a/cookiecutter/utils.py b/cookiecutter/utils.py index 19b727a52..ef533171a 100644 --- a/cookiecutter/utils.py +++ b/cookiecutter/utils.py @@ -8,6 +8,7 @@ import sys from cookiecutter.prompt import read_user_yes_no +from jinja2.ext import Extension logger = logging.getLogger(__name__) @@ -105,3 +106,15 @@ def prompt_and_delete(path, no_input=False): return False sys.exit() + + +def simple_filter(filter_function): + """Decorate a function to wrap it in a simplified jinja2 extension.""" + + class SimpleFilterExtension(Extension): + def __init__(self, environment): + super(SimpleFilterExtension, self).__init__(environment) + environment.filters[filter_function.__name__] = filter_function + + SimpleFilterExtension.__name__ = filter_function.__name__ + return SimpleFilterExtension diff --git a/docs/advanced/index.rst b/docs/advanced/index.rst index 90e7f2933..66d5faadd 100644 --- a/docs/advanced/index.rst +++ b/docs/advanced/index.rst @@ -23,3 +23,4 @@ Various advanced topics regarding cookiecutter usage. template_extensions directories new_line_characters + local_extensions diff --git a/docs/advanced/local_extensions.rst b/docs/advanced/local_extensions.rst new file mode 100644 index 000000000..bf87caf7b --- /dev/null +++ b/docs/advanced/local_extensions.rst @@ -0,0 +1,60 @@ +.. _`template extensions`: + +Local Extensions +---------------- + +*New in Cookiecutter X.x* + +A template may extend the Cookiecutter environment with local extensions. +These can be part of the template itself, providing it with more sophisticated custom tags and filters. + +To do so, a template author must specify the required extensions in ``cookiecutter.json`` as follows: + +.. code-block:: json + + { + "project_slug": "Foobar", + "year": "{% now 'utc', '%Y' %}", + "_extensions": ["local_extensions.FoobarExtension"] + } + +This example assumes that a ``local_extensions`` folder (python module) exists in the template root. +It will contain a ``main.py`` file, containing the following (for instance): + +.. code-block:: python + + # -*- coding: utf-8 -*- + + from jinja2.ext import Extension + + + class FoobarExtension(Extension): + def __init__(self, environment): + super(FoobarExtension, self).__init__(environment) + environment.filters['foobar'] = lambda v: v * 2 + +This will register the ``foobar`` filter for the template. + +For many cases, this will be unneccessarily complicated. It's likely that we'd only want to register a single function +as a filter. For this, we can use the ``simple_filter`` decorator: + +.. code-block:: json + + { + "project_slug": "Foobar", + "year": "{% now 'utc', '%Y' %}", + "_extensions": ["local_extensions.simplefilterextension"] + } + +.. code-block:: python + + # -*- coding: utf-8 -*- + + from cookiecutter.utils import simple_filter + + + @simple_filter + def simplefilterextension(v): + return v * 2 + +This snippet will achieve the exact same result as the previous one. diff --git a/tests/test-extensions/local_extension/cookiecutter.json b/tests/test-extensions/local_extension/cookiecutter.json new file mode 100644 index 000000000..8141fd508 --- /dev/null +++ b/tests/test-extensions/local_extension/cookiecutter.json @@ -0,0 +1,10 @@ +{ + "project_slug": "Foobar", + "test_value_class_based": "{{cookiecutter.project_slug | foobar}}", + "test_value_function_based": "{{cookiecutter.project_slug | simplefilterextension}}", + "_extensions": [ + "local_extensions.simplefilterextension", + "local_extensions.FoobarExtension" + ] +} + diff --git a/tests/test-extensions/local_extension/local_extensions/__init__.py b/tests/test-extensions/local_extension/local_extensions/__init__.py new file mode 100644 index 000000000..94e854abd --- /dev/null +++ b/tests/test-extensions/local_extension/local_extensions/__init__.py @@ -0,0 +1 @@ +from .main import FoobarExtension, simplefilterextension # noqa diff --git a/tests/test-extensions/local_extension/local_extensions/main.py b/tests/test-extensions/local_extension/local_extensions/main.py new file mode 100644 index 000000000..53f6f8f95 --- /dev/null +++ b/tests/test-extensions/local_extension/local_extensions/main.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- + +"""Provides custom extension, exposing a ``foobar`` filter.""" + +from jinja2.ext import Extension +from cookiecutter.utils import simple_filter + + +class FoobarExtension(Extension): + """Simple jinja2 extension for cookiecutter test purposes.""" + + def __init__(self, environment): + """Foobar Extension Constructor.""" + super(FoobarExtension, self).__init__(environment) + environment.filters['foobar'] = lambda v: v * 2 + + +@simple_filter +def simplefilterextension(v): + """Provide a simple function-based filter extension.""" + return v.upper() diff --git a/tests/test-extensions/local_extension/{{cookiecutter.project_slug}}/HISTORY.rst b/tests/test-extensions/local_extension/{{cookiecutter.project_slug}}/HISTORY.rst new file mode 100644 index 000000000..8bb7c6136 --- /dev/null +++ b/tests/test-extensions/local_extension/{{cookiecutter.project_slug}}/HISTORY.rst @@ -0,0 +1,8 @@ +History +------- + +0.1.0 +----- + +First release of {{cookiecutter.test_value_class_based}} on PyPI. +{{cookiecutter.test_value_function_based}} diff --git a/tests/test_cli.py b/tests/test_cli.py index 19740ef26..ad6abd1e0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -10,6 +10,8 @@ from cookiecutter import utils from cookiecutter.__main__ import main +from cookiecutter.environment import StrictEnvironment +from cookiecutter.exceptions import UnknownExtension from cookiecutter.main import cookiecutter @@ -412,6 +414,35 @@ def test_echo_unknown_extension_error(output_dir, cli_runner): assert 'Unable to load extension: ' in result.output +def test_local_extension(tmpdir, cli_runner): + """Test to verify correct work of extension, included in template.""" + output_dir = str(tmpdir.mkdir('output')) + template_path = 'tests/test-extensions/local_extension/' + + result = cli_runner( + '--no-input', + '--default-config', + '--output-dir', + output_dir, + template_path, + ) + assert result.exit_code == 0 + with open(os.path.join(output_dir, 'Foobar', 'HISTORY.rst')) as f: + data = f.read() + assert 'FoobarFoobar' in data + assert 'FOOBAR' in data + + +def test_local_extension_not_available(tmpdir, cli_runner): + """Test handling of included but unavailable local extension.""" + context = {'cookiecutter': {'_extensions': ['foobar']}} + + with pytest.raises(UnknownExtension) as err: + StrictEnvironment(context=context, keep_trailing_newline=True) + + assert 'Unable to load extension: ' in str(err.value) + + @pytest.mark.usefixtures('remove_fake_project_dir') def test_cli_extra_context(cli_runner): """Cli invocation replace content if called with replacement pairs."""