Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use load_replay_file directly when file is directly provided. #1491

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
5 changes: 2 additions & 3 deletions cookiecutter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from cookiecutter.exceptions import InvalidModeException
from cookiecutter.generate import generate_context, generate_files
from cookiecutter.prompt import prompt_for_config
from cookiecutter.replay import dump, load
from cookiecutter.replay import dump, load, load_replay_file
from cookiecutter.repository import determine_repo_dir
from cookiecutter.utils import rmtree

Expand Down Expand Up @@ -79,8 +79,7 @@ def cookiecutter(
if isinstance(replay, bool):
context = load(config_dict['replay_dir'], template_name)
else:
path, template_name = os.path.split(os.path.splitext(replay)[0])
context = load(path, template_name)
context = load_replay_file(replay)
else:
context_file = os.path.join(repo_dir, 'cookiecutter.json')
logger.debug('context_file is %s', context_file)
Expand Down
18 changes: 11 additions & 7 deletions cookiecutter/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,21 @@ def dump(replay_dir, template_name, context):
json.dump(context, outfile, indent=2)


def load(replay_dir, template_name):
"""Read json data from file."""
if not isinstance(template_name, str):
raise TypeError('Template name is required to be of type str')

replay_file = get_file_name(replay_dir, template_name)

def load_replay_file(replay_file):
"""Read cookiecutter's parameter values from replay file."""
with open(replay_file, 'r') as infile:
context = json.load(infile)

if 'cookiecutter' not in context:
raise ValueError('Context is required to contain a cookiecutter key')

return context


def load(replay_dir, template_name):
"""Read json data from file."""
if not isinstance(template_name, str):
raise TypeError('Template name is required to be of type str')

replay_file = get_file_name(replay_dir, template_name)
return load_replay_file(replay_file)
13 changes: 8 additions & 5 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Collection of tests around cookiecutter's replay feature."""
from cookiecutter.main import cookiecutter
from cookiecutter import replay


def test_replay_dump_template_name(
Expand Down Expand Up @@ -58,13 +59,15 @@ def test_custom_replay_file(monkeypatch, mocker, user_config_file):
"""Check that reply.load is called with the custom replay_file."""
monkeypatch.chdir('tests/fake-repo-tmpl')

mock_replay_load = mocker.patch('cookiecutter.main.load')
mock_replay_load = mocker.patch(
'cookiecutter.main.load_replay_file', side_effect=replay.load_replay_file
)
mocker.patch('cookiecutter.main.generate_files')

cookiecutter(
'.', replay='./custom-replay-file', config_file=user_config_file,
'.',
replay='../test-replay/cookiedozer_load.json',
config_file=user_config_file,
)

mock_replay_load.assert_called_once_with(
'.', 'custom-replay-file',
)
mock_replay_load.assert_called_once_with('../test-replay/cookiedozer_load.json',)