Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/local extensions #1240

Merged
merged 11 commits into from Mar 18, 2022
19 changes: 18 additions & 1 deletion cookiecutter/environment.py
Expand Up @@ -28,7 +28,11 @@ def __init__(self, **kwargs):
'cookiecutter.extensions.SlugifyExtension',
'jinja2_time.TimeExtension',
]
extensions = default_extensions + self._read_extensions(context)
extensions = (
default_extensions
+ self._read_extensions(context)
+ self._read_local_extensions(context)
)

try:
super(ExtensionLoaderMixin, self).__init__(extensions=extensions, **kwargs)
Expand All @@ -48,6 +52,19 @@ def _read_extensions(self, context):
else:
return [str(ext) for ext in extensions]

def _read_local_extensions(self, context):
"""Return list of extension modules in the template, to be passed on to the Jinja2 env.

If context does not contain the relevant info, return an empty
list instead.
"""
try:
extensions = context['cookiecutter']['_local_extensions']
except KeyError:
return []
else:
return [str(ext) for ext in extensions]


class StrictEnvironment(ExtensionLoaderMixin, Environment):
"""Create strict Jinja2 environment.
Expand Down
48 changes: 34 additions & 14 deletions cookiecutter/main.py
Expand Up @@ -4,8 +4,11 @@
The code in this module is also a good example of how to use Cookiecutter as a
library rather than a script.
"""
from __future__ import unicode_literals

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason for adding this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, it seems I added this because I needed it for an initial change. It's been over 1.5 years, and I guess I forgot about that. I'll remove it

from copy import copy
import logging
import os
import sys

from cookiecutter.config import get_user_config
from cookiecutter.exceptions import InvalidModeException
Expand Down Expand Up @@ -72,15 +75,17 @@ def cookiecutter(
password=password,
directory=directory,
)
import_patch = _patch_import_path_for_repo(repo_dir)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like it would add all the python files at the same level as the template directory to the import path. That might cause unexpected issues for people. Would it be better to limit the scope here to just the local extensions directory?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While it makes some sense, I'm not sure it would be the best way to go.
The general idea with the current implementation is that template extension loading will essentially continue to work the same way it always has, except that they will also be loaded from local sources. While the provided example shows extensions in the "local_extensions" folder, I'm not suggesting that the extensions need to be contained in that specific folder, or even that extensions should all exist in the same folder.

I'm not convinced that it's necessary to limit the scope, but I'll make the change if I haven't changed your mind. I assume that adding another config keyword doesn't cause any problems? In that case I would suggest changing this:

    {
        "project_slug": "Foobar",
        "year": "{% now 'utc', '%Y' %}",
        "_extensions": ["local_extensions.FoobarExtension"]
    }

to

    {
        "project_slug": "Foobar",
        "year": "{% now 'utc', '%Y' %}",
        "_local_extensions_folders": ["local_extensions"],
        "_extensions": ["FoobarExtension"]
    }

any thoughts?


template_name = os.path.basename(os.path.abspath(repo_dir))

if replay:
if isinstance(replay, bool):
context = load(config_dict['replay_dir'], template_name)
else:
path, template_name = os.path.split(os.path.splitext(replay)[0])
context = load(path, template_name)
with import_patch:
if isinstance(replay, bool):
context = load(config_dict['replay_dir'], template_name)
else:
path, template_name = os.path.split(os.path.splitext(replay)[0])
context = load(path, template_name)
else:
context_file = os.path.join(repo_dir, 'cookiecutter.json')
logger.debug('context_file is %s', context_file)
Expand All @@ -93,7 +98,8 @@ def cookiecutter(

# prompt the user to manually configure at the command line.
# except when 'no-input' flag is set
context['cookiecutter'] = prompt_for_config(context, no_input)
with import_patch:
context['cookiecutter'] = prompt_for_config(context, no_input)

# include template dir or url in the context dict
context['cookiecutter']['_template'] = template
Expand All @@ -104,17 +110,31 @@ def cookiecutter(
dump(config_dict['replay_dir'], template_name, context)

# Create project from local context and project template.
result = generate_files(
repo_dir=repo_dir,
context=context,
overwrite_if_exists=overwrite_if_exists,
skip_if_file_exists=skip_if_file_exists,
output_dir=output_dir,
accept_hooks=accept_hooks,
)
with import_patch:
result = generate_files(
repo_dir=repo_dir,
context=context,
overwrite_if_exists=overwrite_if_exists,
skip_if_file_exists=skip_if_file_exists,
output_dir=output_dir,
accept_hooks=accept_hooks,
)

# Cleanup (if required)
if cleanup:
rmtree(repo_dir)

return result


class _patch_import_path_for_repo:
def __init__(self, repo_dir):
self._repo_dir = repo_dir
self._path = None

def __enter__(self):
self._path = copy(sys.path)
sys.path.append(self._repo_dir)

def __exit__(self, type, value, traceback):
sys.path = self._path
13 changes: 13 additions & 0 deletions cookiecutter/utils.py
Expand Up @@ -8,6 +8,7 @@
import sys

from cookiecutter.prompt import read_user_yes_no
from jinja2.ext import Extension

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -105,3 +106,15 @@ def prompt_and_delete(path, no_input=False):
return False

sys.exit()


def simple_filter(filter_function):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😍

"""Decorate a function to wrap it in a simplified jinja2 extension."""

class SimpleFilterExtension(Extension):
def __init__(self, environment):
super(SimpleFilterExtension, self).__init__(environment)
environment.filters[filter_function.__name__] = filter_function

SimpleFilterExtension.__name__ = filter_function.__name__
return SimpleFilterExtension
1 change: 1 addition & 0 deletions docs/advanced/index.rst
Expand Up @@ -23,3 +23,4 @@ Various advanced topics regarding cookiecutter usage.
template_extensions
directories
new_line_characters
local_extensions
60 changes: 60 additions & 0 deletions docs/advanced/local_extensions.rst
@@ -0,0 +1,60 @@
.. _`template extensions`:

Local Extensions
----------------

*New in Cookiecutter X.x*

A template may extend the Cookiecutter environment with local extensions.
These can be part of the template itself, providing it with more sophisticated custom tags and filters.

To do so, a template author must specify the required extensions in ``cookiecutter.json`` as follows:

.. code-block:: json

{
"project_slug": "Foobar",
"year": "{% now 'utc', '%Y' %}",
"_local_extensions": ["local_extensions.FoobarExtension"]
}

This example assumes that a ``local_extensions`` folder (python module) exists in the template root.
It will contain a ``main.py`` file, containing the following (for instance):
Comment on lines +21 to +22

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I only need one or two extensions and they're not that complicated, I can just define them directly in a module local_extensions.py, correct? Maybe instead of having the documentation suggest a more-complicated package setup, it could suggest simply that you have a module local_extensions available in the template root with extensions defined, with no reference to a folder or main.py. That would leave it up to the user to use a standalone module or a package structure depending on their needs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that seems fair. i'll make the change

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I gather reading the code you are adding the path of the local directory. So in effect you could create a folder that is named extensions that contains an __init__.py file. In that file you'll have the extensions and you'd refer to that in the cookiecutter.json file as 'extensions.MyExtension' and it should all work.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be updated to reflect the fact that we are adding the folder to sys path instead and the extension names would be derived from the modules/paths in the repo dir instead? The main.py may throw people off.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm honestly not sure what it is you're referring to. This is meant to be an example involving a python module as a "container" for extensions, and that's literally what is being explained here.

Maybe if we're indeed going with the "local extension folder config option" idea, then maybe this example could be simplified a bit, but the basic principle is kept as simple as possible while still being pythonic, and the changes in main.py are only to facilitate the extension loading and make it all possible

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

main.py must be __init__.py


.. code-block:: python

# -*- coding: utf-8 -*-

from jinja2.ext import Extension


class FoobarExtension(Extension):
def __init__(self, environment):
super(FoobarExtension, self).__init__(environment)
environment.filters['foobar'] = lambda v: v * 2

This will register the ``foobar`` filter for the template.

For many cases, this will be unneccessarily complicated. It's likely that we'd only want to register a single function
as a filter. For this, we can use the ``simple_filter`` decorator:

.. code-block:: json

{
"project_slug": "Foobar",
"year": "{% now 'utc', '%Y' %}",
"_local_extensions": ["local_extensions.foobarextension"]
}

.. code-block:: python

# -*- coding: utf-8 -*-

from cookiecutter.utils import simple_filter


@simple_filter
def simplefilterextension(v):
return v * 2

This snippet will achieve the exact same result as the previous one.
10 changes: 10 additions & 0 deletions tests/test-extensions/local_extension/cookiecutter.json
@@ -0,0 +1,10 @@
{
"project_slug": "Foobar",
"test_value_class_based": "{{cookiecutter.project_slug | foobar}}",
"test_value_function_based": "{{cookiecutter.project_slug | simplefilterextension}}",
"_local_extensions": [
"local_extensions.simplefilterextension",
"local_extensions.FoobarExtension"
]
}

@@ -0,0 +1 @@
from .main import FoobarExtension, simplefilterextension # noqa
21 changes: 21 additions & 0 deletions tests/test-extensions/local_extension/local_extensions/main.py
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-

"""Provides custom extension, exposing a ``foobar`` filter."""

from jinja2.ext import Extension
from cookiecutter.utils import simple_filter


class FoobarExtension(Extension):
"""Simple jinja2 extension for cookiecutter test purposes."""

def __init__(self, environment):
"""Foobar Extension Constructor."""
super(FoobarExtension, self).__init__(environment)
environment.filters['foobar'] = lambda v: v * 2


@simple_filter
def simplefilterextension(v):
"""Provide a simple function-based filter extension."""
return v.upper()
@@ -0,0 +1,8 @@
History
-------

0.1.0
-----

First release of {{cookiecutter.test_value_class_based}} on PyPI.
{{cookiecutter.test_value_function_based}}
27 changes: 27 additions & 0 deletions tests/test_cli.py
Expand Up @@ -9,6 +9,8 @@

from cookiecutter import utils
from cookiecutter.__main__ import main
from cookiecutter.environment import StrictEnvironment
from cookiecutter.exceptions import UnknownExtension
from cookiecutter.main import cookiecutter


Expand Down Expand Up @@ -414,6 +416,31 @@ def test_echo_unknown_extension_error(tmpdir, cli_runner):
assert 'Unable to load extension: ' in result.output


def test_local_extension(tmpdir, cli_runner):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be good to include tests for non-happy path? I can think of two cases, but maybe there are more:

  • I include a filter in my template, but didn't include the extension in the _local_extensions list
  • I include a filter in my template and extension in the _local_extensions list, but haven't defined that extension (or it's somehow not available to be imported)
    • It looks like there's a similar test for built-in Jinja extensions

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a good idea. i'll pick that up ASAP

"""Test to verify correct work of extension, included in template."""
output_dir = str(tmpdir.mkdir('output'))
template_path = 'tests/test-extensions/local_extension/'

result = cli_runner(
'--no-input', '--default-config', '--output-dir', output_dir, template_path,
)
assert result.exit_code == 0
with open(os.path.join(output_dir, 'Foobar', 'HISTORY.rst')) as f:
data = f.read()
assert 'FoobarFoobar' in data
assert 'FOOBAR' in data


def test_local_extension_not_available(tmpdir, cli_runner):
"""Test handling of included but unavailable local extension."""
context = {'cookiecutter': {'_local_extensions': ['foobar']}}

with pytest.raises(UnknownExtension) as err:
StrictEnvironment(context=context, keep_trailing_newline=True)

assert 'Unable to load extension: ' in str(err.value)


@pytest.mark.usefixtures('remove_fake_project_dir')
def test_cli_extra_context(cli_runner):
"""Cli invocation replace content if called with replacement pairs."""
Expand Down