Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] ENH add wandb support for benchopt run #520

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ __pycache__
/doc/generated
/examples/benchmark_*
.coverage
/wandb/

# Output files
**/outputs
Expand Down
7 changes: 6 additions & 1 deletion benchopt/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class _Callback:
Contains objective and data names, problem dimension, etc.
stopping_criterion : instance of StoppingCriterion
Object to check if we need to stop a solver.
wandb_cb : callable or None
    Callback called after each objective evaluation to log the results
    to wandb, or None to disable wandb logging.

Attributes
----------
Expand Down Expand Up @@ -44,10 +46,11 @@ class _Callback:
The time when exiting the callback call.
"""

def __init__(self, objective, meta, stopping_criterion):
def __init__(self, objective, meta, stopping_criterion, wandb_cb=None):
self.objective = objective
self.meta = meta
self.stopping_criterion = stopping_criterion
self.wandb_cb = wandb_cb

# Initialize local variables
self.info = get_sys_info()
Expand Down Expand Up @@ -85,6 +88,8 @@ def log_value(self, x):
time=self.time_iter,
**objective_dict, **self.info
))
if self.wandb_cb is not None:
self.wandb_cb(self.curve[-1])

# Check the stopping criterion
should_stop_res = self.stopping_criterion.should_stop(
Expand Down
11 changes: 9 additions & 2 deletions benchopt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def _get_run_args(cli_kwargs, config_file_kwargs):
"plot",
"html",
"pdb",
"wandb",
"profile",
"env_name",
"output",
Expand Down Expand Up @@ -138,6 +139,11 @@ def _get_run_args(cli_kwargs, config_file_kwargs):
is_flag=True,
help="Launch a debugger if there is an error. This will launch "
"ipdb if it is installed and default to pdb otherwise.")
@click.option('--wandb',
is_flag=True,
help="Log the results in WandB. This option requires having "
"installed wandb. See the wandb documentation: "
"https://wandb.ai/quickstart.")
@click.option('--local', '-l', 'env_name',
flag_value='False', default=True,
help="Run the benchmark in the local conda environment.")
Expand Down Expand Up @@ -179,7 +185,7 @@ def run(config_file=None, **kwargs):
(
benchmark, solver_names, forced_solvers, dataset_names,
objective_filters, max_runs, n_repetitions, timeout, n_jobs, slurm,
plot, html, pdb, do_profile, env_name, output,
plot, html, pdb, wandb, do_profile, env_name, output,
deprecated_objective_filters, old_objective_filters
) = _get_run_args(kwargs, config)

Expand Down Expand Up @@ -247,7 +253,7 @@ def run(config_file=None, **kwargs):
objective_filters=objective_filters,
max_runs=max_runs, n_repetitions=n_repetitions,
timeout=timeout, n_jobs=n_jobs, slurm=slurm,
plot_result=plot, html=html, pdb=pdb,
plot_result=plot, html=html, pdb=pdb, wandb=wandb,
output=output
)

Expand Down Expand Up @@ -324,6 +330,7 @@ def run(config_file=None, **kwargs):
rf"{'--plot' if plot else '--no-plot'} "
rf"{'--html' if html else '--no-html'} "
rf"{'--pdb' if pdb else ''} "
rf"{'--wandb' if wandb else ''} "
rf"--output {output}"
.replace('\\', '\\\\')
)
Expand Down
34 changes: 26 additions & 8 deletions benchopt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .utils.sys_info import get_sys_info
from .utils.files import uniquify_results
from .utils.pdb_helpers import exception_handler
from .utils.logging.wandb import wandb_ctx
from .utils.terminal_output import TerminalOutput

##################################
Expand Down Expand Up @@ -58,7 +59,7 @@ def run_one_resolution(objective, solver, meta, stop_val):


def run_one_to_cvg(benchmark, objective, solver, meta, stopping_criterion,
force=False, output=None, pdb=False):
force=False, output=None, pdb=False, wandb=False):
"""Run all repetitions of the solver for a value of stopping criterion.

Parameters
Expand All @@ -79,6 +80,10 @@ def run_one_to_cvg(benchmark, objective, solver, meta, stopping_criterion,
for the solver anyway. Else, use the cache if available.
pdb : bool
If pdb is set to True, open a debugger on error.
wandb : bool
If wandb is set to True, send the results to a wandb page. This option
requires wandb to be installed and configured. See the wandb
documentation: https://wandb.ai/quickstart.

Returns
-------
Expand All @@ -89,7 +94,8 @@ def run_one_to_cvg(benchmark, objective, solver, meta, stopping_criterion,
"""

curve = []
with exception_handler(output, pdb=pdb) as ctx:
with (exception_handler(output, pdb=pdb) as ctx,
wandb_ctx(meta=meta, wandb=wandb) as wandb_cb):
tomMoral marked this conversation as resolved.
Show resolved Hide resolved

if solver._solver_strategy == "callback":
output.progress('empty run for compilation')
Expand All @@ -105,7 +111,7 @@ def run_one_to_cvg(benchmark, objective, solver, meta, stopping_criterion,
# If stopping strategy is 'callback', only call once to get the
# results up to convergence.
callback = _Callback(
objective, meta, stopping_criterion
objective, meta, stopping_criterion, wandb_cb=wandb_cb
)
solver.run(callback)
curve, ctx.status = callback.get_results()
Expand All @@ -127,6 +133,8 @@ def run_one_to_cvg(benchmark, objective, solver, meta, stopping_criterion,
cost = run_one_resolution_cached(stop_val=stop_val,
**call_args)
curve.append(cost)
if wandb_cb is not None:
wandb_cb(cost)

# Check the stopping criterion and update rho if necessary.
stop, ctx.status, stop_val = stopping_criterion.should_stop(
Expand All @@ -137,7 +145,8 @@ def run_one_to_cvg(benchmark, objective, solver, meta, stopping_criterion,


def run_one_solver(benchmark, dataset, objective, solver, n_repetitions,
max_runs, timeout, force=False, output=None, pdb=False):
max_runs, timeout, force=False, output=None, pdb=False,
wandb=False):
"""Run a benchmark for a given dataset, objective and solver.

Parameters
Expand All @@ -164,14 +173,18 @@ def run_one_solver(benchmark, dataset, objective, solver, n_repetitions,
Object to format string to display the progress of the solver.
pdb : bool
If pdb is set to True, open a debugger on error.
wandb : bool
If wandb is set to True, send the results to a wandb page. This option
requires wandb to be installed and configured. See the wandb
documentation: https://wandb.ai/quickstart.

Returns
-------
run_statistics : list
The benchmark results.
"""
run_one_to_cvg_cached = benchmark.cache(
run_one_to_cvg, ignore=['force', 'output', 'pdb']
run_one_to_cvg, ignore=['force', 'output', 'pdb', 'wandb']
)

# Set objective an skip if necessary.
Expand All @@ -192,6 +205,7 @@ def run_one_solver(benchmark, dataset, objective, solver, n_repetitions,
output.set(rep=rep)
# Get meta
meta = dict(
benchmark_name=benchmark.name,
objective_name=str(objective),
solver_name=str(solver),
data_name=str(dataset),
Expand All @@ -208,7 +222,7 @@ def run_one_solver(benchmark, dataset, objective, solver, n_repetitions,
benchmark=benchmark, objective=objective,
solver=solver, meta=meta,
stopping_criterion=stopping_criterion,
force=force, output=output, pdb=pdb
force=force, output=output, pdb=pdb, wandb=wandb
)
if status in ['diverged', 'error', 'interrupted']:
run_statistics = []
Expand Down Expand Up @@ -237,7 +251,7 @@ def run_benchmark(benchmark, solver_names=None, forced_solvers=None,
dataset_names=None, objective_filters=None, max_runs=10,
n_repetitions=1, timeout=100, n_jobs=1, slurm=None,
plot_result=True, html=True, show_progress=True, pdb=False,
output="None"):
wandb=False, output="None"):
"""Run full benchmark.

Parameters
Expand Down Expand Up @@ -278,6 +292,10 @@ def run_benchmark(benchmark, solver_names=None, forced_solvers=None,
If show_progress is set to True, display the progress of the benchmark.
pdb : bool
If pdb is set to True, open a debugger on error.
wandb : bool
If wandb is set to True, send the results to a wandb page. This option
requires wandb to be installed and configured. See the wandb
documentation: https://wandb.ai/quickstart.
output_name : str
Filename for the parquet output. If given, the results will
be stored at <BENCHMARK>/outputs/<filename>.parquet.
Expand All @@ -304,7 +322,7 @@ def run_benchmark(benchmark, solver_names=None, forced_solvers=None,
)
common_kwargs = dict(
benchmark=benchmark, n_repetitions=n_repetitions, max_runs=max_runs,
timeout=timeout, pdb=pdb
timeout=timeout, pdb=pdb, wandb=wandb
)

if slurm is not None:
Expand Down
Empty file.
100 changes: 100 additions & 0 deletions benchopt/utils/logging/wandb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import random
from pathlib import Path
from contextlib import contextmanager


WORDS_URL = (
    'https://raw.githubusercontent.com/dwyl/english-words/'
    'master/words_dictionary.json'
)


def generate_id(seed, n_words=2):
    """Generate an identifier of ``n_words`` English words from a seed.

    Parameters
    ----------
    seed : hashable
        Non-mutable seed for the random generator: the same seed always
        produces the same identifier.
    n_words : int
        Number of words, joined by ``'-'``, in the identifier.

    Returns
    -------
    name : str
        Identifier of the form ``"word-word"``.
    """
    # If not present, download and preprocess a list of english words,
    # cached on disk so the download only happens once per machine.
    word_list = Path("~/.cache/benchopt/word_list.txt").expanduser()
    if not word_list.exists():
        from urllib.request import urlopen
        # parents=True so this also works when ~/.cache does not exist yet.
        word_list.parent.mkdir(parents=True, exist_ok=True)
        with urlopen(WORDS_URL) as f:
            # Each line of the JSON file looks like b'"word": 1,'; keep
            # only words of reasonable length and strip the quotes.
            words = [
                s.decode().split(":")[0].split()[0][1:-1]
                for s in f.readlines() if 16 > len(s) > 12
            ]
            word_list.write_text("\n".join(words))
    else:
        words = word_list.read_text().splitlines()

    rng = random.Random(seed)
    # Use rng.choice: the previous words[rng.randint(0, len(words))] could
    # raise IndexError, as random.randint is inclusive on both bounds.
    return "-".join(rng.choice(words) for _ in range(n_words))


@contextmanager
def wandb_ctx(meta, wandb):
    """Context manager to init and close wandb logging.

    This yields a callback that can be used to log the objective in wandb,
    or None when wandb logging is disabled.

    Parameters
    ----------
    meta : dict
        Meta data on the benchmark run. Used to setup the run names and stored
        as the run configuration. Must contain the keys ``benchmark_name``,
        ``solver_name``, ``data_name`` and ``objective_name`` when wandb
        logging is enabled.
    wandb : bool
        If set to False, this context does nothing and yields None. Else,
        setup wandb and yield a callback to log the information to wandb.

    Raises
    ------
    ImportError
        If wandb logging is requested but wandb is not installed.
    RuntimeError
        If wandb is installed but the user is not logged in.
    """
    if not wandb:
        yield None
        return

    try:
        import wandb as wb
    except ImportError as err:
        raise ImportError(
            "To be able to use wandb, please install and configure it. "
            "See first step in https://wandb.ai/quickstart/python-script."
        ) from err
    try:
        # wb.login() returns a falsy value when no credentials are
        # configured. Do not use `assert` for this check: it would be
        # stripped when running with `python -O`.
        logged_in = wb.login()
    except wb.errors.UsageError:
        logged_in = False
    if not logged_in:
        raise RuntimeError(
            "wandb is not setup. Need to run `wandb login` to allow for wandb "
            "reports upload."
        )

    # In order to get separate plots for different datasets and objectives,
    # group the metric based on a tag. This separates each setup in
    # different panels.
    tag = f'{meta["data_name"]}/{meta["objective_name"]}'
    # In order to make it easy to navigate the different panels, we group
    # the runs by common tags (i.e. couple of data and objective names), and
    # we add tags to simplify filtering the results.
    # For tags and group, we cannot use directly the names as the length
    # of these fields is limited. We thus generate reasonable length
    # two-word identifiers with `generate_id`.
    # NOTE: `wb.init` is called *before* entering the try/finally so that
    # `run` is always bound when `run.finish` executes. The previous version
    # raised a spurious NameError in the finally clause when init failed.
    run = wb.init(
        project=meta['benchmark_name'], name=meta['solver_name'],
        group=generate_id(tag), job_type=meta['solver_name'], tags=[
            generate_id(meta["data_name"]),
            generate_id(meta["objective_name"])
        ], config=meta, reinit=True,
    )

    try:
        # Callback to be called in the runner.
        def cb(objective_dict):
            # Remove the leading "objective_" from the column name. Also log
            # the time and stop_val.
            run.log({
                f'{tag}/{k.replace("objective_", "")}': v
                for k, v in objective_dict.items()
                if k.startswith('objective') or k in ["time", "stop_val"]
            })

        yield cb
    finally:
        run.finish(quiet=True)
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ Version 1.4 - in development
CLI
---

- Add support for wandb upload of the benchmark results.
By `Thomas Moreau`_ (:gh:`520`)

API
---

Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,6 @@ doc =
slurm =
submitit
rich

wandb =
wandb