Skip to content

Commit

Permalink
Feature/local extensions (#1240)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrey Shpak <ashpak@ashpak.ru>
  • Loading branch information
mwesterhof and insspb committed Mar 18, 2022
1 parent 0b40625 commit c0e7698
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 14 deletions.
47 changes: 33 additions & 14 deletions cookiecutter/main.py
Expand Up @@ -4,8 +4,10 @@
The code in this module is also a good example of how to use Cookiecutter as a
library rather than a script.
"""
from copy import copy
import logging
import os
import sys

from cookiecutter.config import get_user_config
from cookiecutter.exceptions import InvalidModeException
Expand Down Expand Up @@ -73,15 +75,17 @@ def cookiecutter(
password=password,
directory=directory,
)
import_patch = _patch_import_path_for_repo(repo_dir)

template_name = os.path.basename(os.path.abspath(repo_dir))

if replay:
if isinstance(replay, bool):
context = load(config_dict['replay_dir'], template_name)
else:
path, template_name = os.path.split(os.path.splitext(replay)[0])
context = load(path, template_name)
with import_patch:
if isinstance(replay, bool):
context = load(config_dict['replay_dir'], template_name)
else:
path, template_name = os.path.split(os.path.splitext(replay)[0])
context = load(path, template_name)
else:
context_file = os.path.join(repo_dir, 'cookiecutter.json')
logger.debug('context_file is %s', context_file)
Expand All @@ -94,7 +98,8 @@ def cookiecutter(

# prompt the user to manually configure at the command line.
# except when 'no-input' flag is set
context['cookiecutter'] = prompt_for_config(context, no_input)
with import_patch:
context['cookiecutter'] = prompt_for_config(context, no_input)

# include template dir or url in the context dict
context['cookiecutter']['_template'] = template
Expand All @@ -105,17 +110,31 @@ def cookiecutter(
dump(config_dict['replay_dir'], template_name, context)

# Create project from local context and project template.
result = generate_files(
repo_dir=repo_dir,
context=context,
overwrite_if_exists=overwrite_if_exists,
skip_if_file_exists=skip_if_file_exists,
output_dir=output_dir,
accept_hooks=accept_hooks,
)
with import_patch:
result = generate_files(
repo_dir=repo_dir,
context=context,
overwrite_if_exists=overwrite_if_exists,
skip_if_file_exists=skip_if_file_exists,
output_dir=output_dir,
accept_hooks=accept_hooks,
)

# Cleanup (if required)
if cleanup:
rmtree(repo_dir)

return result


class _patch_import_path_for_repo:
def __init__(self, repo_dir):
self._repo_dir = repo_dir
self._path = None

def __enter__(self):
self._path = copy(sys.path)
sys.path.append(self._repo_dir)

def __exit__(self, type, value, traceback):
sys.path = self._path
13 changes: 13 additions & 0 deletions cookiecutter/utils.py
Expand Up @@ -8,6 +8,7 @@
import sys

from cookiecutter.prompt import read_user_yes_no
from jinja2.ext import Extension

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -105,3 +106,15 @@ def prompt_and_delete(path, no_input=False):
return False

sys.exit()


def simple_filter(filter_function):
"""Decorate a function to wrap it in a simplified jinja2 extension."""

class SimpleFilterExtension(Extension):
def __init__(self, environment):
super(SimpleFilterExtension, self).__init__(environment)
environment.filters[filter_function.__name__] = filter_function

SimpleFilterExtension.__name__ = filter_function.__name__
return SimpleFilterExtension
1 change: 1 addition & 0 deletions docs/advanced/index.rst
Expand Up @@ -23,3 +23,4 @@ Various advanced topics regarding cookiecutter usage.
template_extensions
directories
new_line_characters
local_extensions
60 changes: 60 additions & 0 deletions docs/advanced/local_extensions.rst
@@ -0,0 +1,60 @@
.. _`template extensions`:

Local Extensions
----------------

*New in Cookiecutter X.x*

A template may extend the Cookiecutter environment with local extensions.
These can be part of the template itself, providing it with more sophisticated custom tags and filters.

To do so, a template author must specify the required extensions in ``cookiecutter.json`` as follows:

.. code-block:: json
{
"project_slug": "Foobar",
"year": "{% now 'utc', '%Y' %}",
"_extensions": ["local_extensions.FoobarExtension"]
}
This example assumes that a ``local_extensions`` folder (python module) exists in the template root.
It will contain a ``main.py`` file, containing the following (for instance):

This comment has been minimized.

Copy link
@alkatar21

alkatar21 May 25, 2022

Contributor

The description is wrong here, because not the main.py is searched, but a module as far as I can see. In the test also the examples are provided via the __init__.py.

Furthermore, for small extensions a local_extensions.py would also work, this is also mentioned in #1240 (comment), but it was not adapted in the PR.


.. code-block:: python
# -*- coding: utf-8 -*-
from jinja2.ext import Extension
class FoobarExtension(Extension):
def __init__(self, environment):
super(FoobarExtension, self).__init__(environment)
environment.filters['foobar'] = lambda v: v * 2
This will register the ``foobar`` filter for the template.

For many cases, this will be unneccessarily complicated. It's likely that we'd only want to register a single function
as a filter. For this, we can use the ``simple_filter`` decorator:

.. code-block:: json
{
"project_slug": "Foobar",
"year": "{% now 'utc', '%Y' %}",
"_extensions": ["local_extensions.simplefilterextension"]
}
.. code-block:: python
# -*- coding: utf-8 -*-
from cookiecutter.utils import simple_filter
@simple_filter
def simplefilterextension(v):
return v * 2
This snippet will achieve the exact same result as the previous one.
10 changes: 10 additions & 0 deletions tests/test-extensions/local_extension/cookiecutter.json
@@ -0,0 +1,10 @@
{
"project_slug": "Foobar",
"test_value_class_based": "{{cookiecutter.project_slug | foobar}}",
"test_value_function_based": "{{cookiecutter.project_slug | simplefilterextension}}",
"_extensions": [
"local_extensions.simplefilterextension",
"local_extensions.FoobarExtension"
]
}

@@ -0,0 +1 @@
from .main import FoobarExtension, simplefilterextension # noqa
21 changes: 21 additions & 0 deletions tests/test-extensions/local_extension/local_extensions/main.py
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-

"""Provides custom extension, exposing a ``foobar`` filter."""

from jinja2.ext import Extension
from cookiecutter.utils import simple_filter


class FoobarExtension(Extension):
"""Simple jinja2 extension for cookiecutter test purposes."""

def __init__(self, environment):
"""Foobar Extension Constructor."""
super(FoobarExtension, self).__init__(environment)
environment.filters['foobar'] = lambda v: v * 2


@simple_filter
def simplefilterextension(v):
"""Provide a simple function-based filter extension."""
return v.upper()
@@ -0,0 +1,8 @@
History
-------

0.1.0
-----

First release of {{cookiecutter.test_value_class_based}} on PyPI.
{{cookiecutter.test_value_function_based}}
31 changes: 31 additions & 0 deletions tests/test_cli.py
Expand Up @@ -10,6 +10,8 @@

from cookiecutter import utils
from cookiecutter.__main__ import main
from cookiecutter.environment import StrictEnvironment
from cookiecutter.exceptions import UnknownExtension
from cookiecutter.main import cookiecutter


Expand Down Expand Up @@ -412,6 +414,35 @@ def test_echo_unknown_extension_error(output_dir, cli_runner):
assert 'Unable to load extension: ' in result.output


def test_local_extension(tmpdir, cli_runner):
"""Test to verify correct work of extension, included in template."""
output_dir = str(tmpdir.mkdir('output'))
template_path = 'tests/test-extensions/local_extension/'

result = cli_runner(
'--no-input',
'--default-config',
'--output-dir',
output_dir,
template_path,
)
assert result.exit_code == 0
with open(os.path.join(output_dir, 'Foobar', 'HISTORY.rst')) as f:
data = f.read()
assert 'FoobarFoobar' in data
assert 'FOOBAR' in data


def test_local_extension_not_available(tmpdir, cli_runner):
"""Test handling of included but unavailable local extension."""
context = {'cookiecutter': {'_extensions': ['foobar']}}

with pytest.raises(UnknownExtension) as err:
StrictEnvironment(context=context, keep_trailing_newline=True)

assert 'Unable to load extension: ' in str(err.value)


@pytest.mark.usefixtures('remove_fake_project_dir')
def test_cli_extra_context(cli_runner):
"""Cli invocation replace content if called with replacement pairs."""
Expand Down

1 comment on commit c0e7698

@frankhuurman
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we already use local extensions after pip installing cookiecutter?

Please sign in to comment.