diff --git a/cookiecutter/main.py b/cookiecutter/main.py index 047d30a08..bf1e82369 100644 --- a/cookiecutter/main.py +++ b/cookiecutter/main.py @@ -4,8 +4,10 @@ The code in this module is also a good example of how to use Cookiecutter as a library rather than a script. """ +from copy import copy import logging import os +import sys from cookiecutter.config import get_user_config from cookiecutter.exceptions import InvalidModeException @@ -72,15 +74,17 @@ def cookiecutter( password=password, directory=directory, ) + import_patch = _patch_import_path_for_repo(repo_dir) template_name = os.path.basename(os.path.abspath(repo_dir)) if replay: - if isinstance(replay, bool): - context = load(config_dict['replay_dir'], template_name) - else: - path, template_name = os.path.split(os.path.splitext(replay)[0]) - context = load(path, template_name) + with import_patch: + if isinstance(replay, bool): + context = load(config_dict['replay_dir'], template_name) + else: + path, template_name = os.path.split(os.path.splitext(replay)[0]) + context = load(path, template_name) else: context_file = os.path.join(repo_dir, 'cookiecutter.json') logger.debug('context_file is %s', context_file) @@ -93,7 +97,8 @@ def cookiecutter( # prompt the user to manually configure at the command line. # except when 'no-input' flag is set - context['cookiecutter'] = prompt_for_config(context, no_input) + with import_patch: + context['cookiecutter'] = prompt_for_config(context, no_input) # include template dir or url in the context dict context['cookiecutter']['_template'] = template @@ -104,17 +109,31 @@ def cookiecutter( dump(config_dict['replay_dir'], template_name, context) # Create project from local context and project template. - result = generate_files( - repo_dir=repo_dir, - context=context, - overwrite_if_exists=overwrite_if_exists, - skip_if_file_exists=skip_if_file_exists, - output_dir=output_dir, - accept_hooks=accept_hooks, - ) + with import_patch: + result = generate_files( + repo_dir=repo_dir, + context=context, + overwrite_if_exists=overwrite_if_exists, + skip_if_file_exists=skip_if_file_exists, + output_dir=output_dir, + accept_hooks=accept_hooks, + ) # Cleanup (if required) if cleanup: rmtree(repo_dir) return result + + +class _patch_import_path_for_repo: + def __init__(self, repo_dir): + self._repo_dir = repo_dir + self._path = None + + def __enter__(self): + self._path = copy(sys.path) + sys.path.append(self._repo_dir) + + def __exit__(self, type, value, traceback): + sys.path = self._path diff --git a/cookiecutter/utils.py b/cookiecutter/utils.py index 5935a249c..a13236a98 100644 --- a/cookiecutter/utils.py +++ b/cookiecutter/utils.py @@ -8,6 +8,7 @@ import sys from cookiecutter.prompt import read_user_yes_no +from jinja2.ext import Extension logger = logging.getLogger(__name__) @@ -105,3 +106,15 @@ def prompt_and_delete(path, no_input=False): return False sys.exit() + + +def simple_filter(filter_function): + """Decorate a function to wrap it in a simplified jinja2 extension.""" + + class SimpleFilterExtension(Extension): + def __init__(self, environment): + super(SimpleFilterExtension, self).__init__(environment) + environment.filters[filter_function.__name__] = filter_function + + SimpleFilterExtension.__name__ = filter_function.__name__ + return SimpleFilterExtension diff --git a/docs/advanced/index.rst b/docs/advanced/index.rst index 90e7f2933..66d5faadd 100644 --- a/docs/advanced/index.rst +++ b/docs/advanced/index.rst @@ -23,3 +23,4 @@ Various advanced topics regarding cookiecutter usage. template_extensions directories new_line_characters + local_extensions diff --git a/docs/advanced/local_extensions.rst b/docs/advanced/local_extensions.rst new file mode 100644 index 000000000..bf87caf7b --- /dev/null +++ b/docs/advanced/local_extensions.rst @@ -0,0 +1,60 @@ +.. _`template extensions`: + +Local Extensions +---------------- + +*New in Cookiecutter X.x* + +A template may extend the Cookiecutter environment with local extensions. +These can be part of the template itself, providing it with more sophisticated custom tags and filters. + +To do so, a template author must specify the required extensions in ``cookiecutter.json`` as follows: + +.. code-block:: json + + { + "project_slug": "Foobar", + "year": "{% now 'utc', '%Y' %}", + "_extensions": ["local_extensions.FoobarExtension"] + } + +This example assumes that a ``local_extensions`` folder (python module) exists in the template root. +It will contain a ``main.py`` file, containing the following (for instance): + +.. code-block:: python + + # -*- coding: utf-8 -*- + + from jinja2.ext import Extension + + + class FoobarExtension(Extension): + def __init__(self, environment): + super(FoobarExtension, self).__init__(environment) + environment.filters['foobar'] = lambda v: v * 2 + +This will register the ``foobar`` filter for the template. + +For many cases, this will be unneccessarily complicated. It's likely that we'd only want to register a single function +as a filter. For this, we can use the ``simple_filter`` decorator: + +.. code-block:: json + + { + "project_slug": "Foobar", + "year": "{% now 'utc', '%Y' %}", + "_extensions": ["local_extensions.simplefilterextension"] + } + +.. code-block:: python + + # -*- coding: utf-8 -*- + + from cookiecutter.utils import simple_filter + + + @simple_filter + def simplefilterextension(v): + return v * 2 + +This snippet will achieve the exact same result as the previous one. diff --git a/tests/test-extensions/local_extension/cookiecutter.json b/tests/test-extensions/local_extension/cookiecutter.json new file mode 100644 index 000000000..8141fd508 --- /dev/null +++ b/tests/test-extensions/local_extension/cookiecutter.json @@ -0,0 +1,10 @@ +{ + "project_slug": "Foobar", + "test_value_class_based": "{{cookiecutter.project_slug | foobar}}", + "test_value_function_based": "{{cookiecutter.project_slug | simplefilterextension}}", + "_extensions": [ + "local_extensions.simplefilterextension", + "local_extensions.FoobarExtension" + ] +} + diff --git a/tests/test-extensions/local_extension/local_extensions/__init__.py b/tests/test-extensions/local_extension/local_extensions/__init__.py new file mode 100644 index 000000000..94e854abd --- /dev/null +++ b/tests/test-extensions/local_extension/local_extensions/__init__.py @@ -0,0 +1 @@ +from .main import FoobarExtension, simplefilterextension # noqa diff --git a/tests/test-extensions/local_extension/local_extensions/main.py b/tests/test-extensions/local_extension/local_extensions/main.py new file mode 100644 index 000000000..53f6f8f95 --- /dev/null +++ b/tests/test-extensions/local_extension/local_extensions/main.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- + +"""Provides custom extension, exposing a ``foobar`` filter.""" + +from jinja2.ext import Extension +from cookiecutter.utils import simple_filter + + +class FoobarExtension(Extension): + """Simple jinja2 extension for cookiecutter test purposes.""" + + def __init__(self, environment): + """Foobar Extension Constructor.""" + super(FoobarExtension, self).__init__(environment) + environment.filters['foobar'] = lambda v: v * 2 + + +@simple_filter +def simplefilterextension(v): + """Provide a simple function-based filter extension.""" + return v.upper() diff --git a/tests/test-extensions/local_extension/{{cookiecutter.project_slug}}/HISTORY.rst b/tests/test-extensions/local_extension/{{cookiecutter.project_slug}}/HISTORY.rst new file mode 100644 index 000000000..8bb7c6136 --- /dev/null +++ b/tests/test-extensions/local_extension/{{cookiecutter.project_slug}}/HISTORY.rst @@ -0,0 +1,8 @@ +History +------- + +0.1.0 +----- + +First release of {{cookiecutter.test_value_class_based}} on PyPI. +{{cookiecutter.test_value_function_based}} diff --git a/tests/test_cli.py b/tests/test_cli.py index a2d8accee..705ce1e64 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,6 +9,8 @@ from cookiecutter import utils from cookiecutter.__main__ import main +from cookiecutter.environment import StrictEnvironment +from cookiecutter.exceptions import UnknownExtension from cookiecutter.main import cookiecutter @@ -414,6 +416,35 @@ def test_echo_unknown_extension_error(tmpdir, cli_runner): assert 'Unable to load extension: ' in result.output +def test_local_extension(tmpdir, cli_runner): + """Test to verify correct work of extension, included in template.""" + output_dir = str(tmpdir.mkdir('output')) + template_path = 'tests/test-extensions/local_extension/' + + result = cli_runner( + '--no-input', + '--default-config', + '--output-dir', + output_dir, + template_path, + ) + assert result.exit_code == 0 + with open(os.path.join(output_dir, 'Foobar', 'HISTORY.rst')) as f: + data = f.read() + assert 'FoobarFoobar' in data + assert 'FOOBAR' in data + + +def test_local_extension_not_available(tmpdir, cli_runner): + """Test handling of included but unavailable local extension.""" + context = {'cookiecutter': {'_extensions': ['foobar']}} + + with pytest.raises(UnknownExtension) as err: + StrictEnvironment(context=context, keep_trailing_newline=True) + + assert 'Unable to load extension: ' in str(err.value) + + @pytest.mark.usefixtures('remove_fake_project_dir') def test_cli_extra_context(cli_runner): """Cli invocation replace content if called with replacement pairs."""