[WIP] ENH add wandb support for benchopt run #520

Open · wants to merge 5 commits into main
21 changes: 20 additions & 1 deletion benchopt/callback.py
@@ -44,7 +44,7 @@ class _Callback:
The time when exiting the callback call.
"""

def __init__(self, objective, meta, stopping_criterion):
def __init__(self, objective, meta, stopping_criterion, wandb):
self.objective = objective
self.meta = meta
self.stopping_criterion = stopping_criterion
@@ -58,6 +58,21 @@ def __init__(self, objective, meta, stopping_criterion):
self.next_stopval = self.stopping_criterion.init_stop_val()
self.time_callback = time.perf_counter()

if wandb:
    try:
        import wandb as wb
        self.wandb = wb
    except ImportError:
        raise ImportError(
            "To use the wandb option, install and configure wandb first."
        )
    wb.init(
        project=meta['benchmark_name'], config=meta, reinit=True
    )
else:
    self.wandb = None

def __call__(self, x):
# Stop time and update computation time since the beginning
t0 = time.perf_counter()
@@ -85,12 +100,16 @@ def log_value(self, x):
time=self.time_iter,
**objective_dict, **self.info
))
if self.wandb is not None:
self.wandb.log(objective_dict)

# Check the stopping criterion
should_stop_res = self.stopping_criterion.should_stop(
self.next_stopval, self.curve
)
stop, self.status, self.next_stopval = should_stop_res
if stop and self.wandb is not None:
    self.wandb.finish()
return stop

def get_results(self):
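For context, here is a minimal sketch (not part of the diff) of the wandb lifecycle that _Callback now implements: one init per solver run, one log per callback call, and one finish when the stopping criterion fires. The meta dict and the objective values are illustrative, and the snippet assumes wandb is installed and logged in.

import wandb

# Illustrative meta; benchopt builds this dict in run_one_solver.
meta = {'benchmark_name': 'my_benchmark', 'solver_name': 'my_solver'}
run = wandb.init(project=meta['benchmark_name'], config=meta, reinit=True)

for it in range(5):
    # _Callback.log_value forwards the objective values at each call;
    # a decreasing dummy loss stands in for them here.
    run.log({'objective_value': 1.0 / (it + 1)})

run.finish()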
11 changes: 9 additions & 2 deletions benchopt/cli/main.py
@@ -56,6 +56,7 @@ def _get_run_args(cli_kwargs, config_file_kwargs):
"plot",
"html",
"pdb",
"wandb",
"profile",
"env_name",
"output",
@@ -138,6 +139,11 @@ def _get_run_args(cli_kwargs, config_file_kwargs):
is_flag=True,
help="Launch a debugger if there is an error. This will launch "
"ipdb if it is installed and default to pdb otherwise.")
@click.option('--wandb',
is_flag=True,
help="Log the results in WandB. This option reauires having "
tomMoral marked this conversation as resolved.
Show resolved Hide resolved
"installed wandb. See the wandb documentation: "
"https://wandb.ai/quickstart.")
@click.option('--local', '-l', 'env_name',
flag_value='False', default=True,
help="Run the benchmark in the local conda environment.")
@@ -179,7 +185,7 @@ def run(config_file=None, **kwargs):
(
benchmark, solver_names, forced_solvers, dataset_names,
objective_filters, max_runs, n_repetitions, timeout, n_jobs, slurm,
plot, html, pdb, do_profile, env_name, output,
plot, html, pdb, wandb, do_profile, env_name, output,
deprecated_objective_filters, old_objective_filters
) = _get_run_args(kwargs, config)

@@ -247,7 +253,7 @@ def run(config_file=None, **kwargs):
objective_filters=objective_filters,
max_runs=max_runs, n_repetitions=n_repetitions,
timeout=timeout, n_jobs=n_jobs, slurm=slurm,
plot_result=plot, html=html, pdb=pdb,
plot_result=plot, html=html, pdb=pdb, wandb=wandb,
output=output
)

@@ -324,6 +330,7 @@ def run(config_file=None, **kwargs):
rf"{'--plot' if plot else '--no-plot'} "
rf"{'--html' if html else '--no-html'} "
rf"{'--pdb' if pdb else ''} "
rf"{'--wandb' if wandb else ''} "
rf"--output {output}"
.replace('\\', '\\\\')
)
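With this option in place, a run would be launched as benchopt run ./my_benchmark --wandb. A hypothetical programmatic equivalent through the runner is sketched below; the benchmark path is illustrative and the Benchmark constructor call is an assumption, only the wandb keyword comes from this diff.

from benchopt.benchmark import Benchmark
from benchopt.runner import run_benchmark

# Assumed construction from a benchmark directory (illustrative path).
benchmark = Benchmark('./my_benchmark')

run_benchmark(
    benchmark,
    max_runs=10, n_repetitions=1, timeout=100,
    wandb=True,  # the new flag wired through by this PR
)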
47 changes: 40 additions & 7 deletions benchopt/runner.py
@@ -58,7 +58,7 @@ def run_one_resolution(objective, solver, meta, stop_val):


def run_one_to_cvg(benchmark, objective, solver, meta, stopping_criterion,
force=False, output=None, pdb=False):
force=False, output=None, pdb=False, wandb=False):
"""Run all repetitions of the solver for a value of stopping criterion.

Parameters
@@ -79,6 +79,10 @@ def run_one_to_cvg(benchmark, objective, solver, meta, stopping_criterion,
for the solver anyway. Else, use the cache if available.
pdb : bool
If pdb is set to True, open a debugger on error.
wandb : bool
If wandb is set to True, send the results to wandb. This option
requires wandb to be installed and configured. See the wandb
documentation: https://wandb.ai/quickstart.

Returns
-------
@@ -105,7 +109,7 @@ def run_one_to_cvg(benchmark, objective, solver, meta, stopping_criterion,
# If stopping strategy is 'callback', only call once to get the
# results up to convergence.
callback = _Callback(
objective, meta, stopping_criterion
objective, meta, stopping_criterion, wandb
)
solver.run(callback)
curve, ctx.status = callback.get_results()
@@ -117,6 +121,19 @@ def run_one_to_cvg(benchmark, objective, solver, meta, stopping_criterion,
run_one_resolution, force
)

# Configure W&B if necessary
if wandb:
try:
import wandb as wb
except ImportError:
raise ImportError(
    "To use the wandb option, install and configure wandb first."
)
run = wb.init(
name=meta['solver_name'],
project=meta['benchmark_name'], config=meta, reinit=True
)

# compute initial value
call_args = dict(objective=objective, solver=solver, meta=meta)

@@ -127,17 +144,24 @@ def run_one_to_cvg(benchmark, objective, solver, meta, stopping_criterion,
cost = run_one_resolution_cached(stop_val=stop_val,
**call_args)
curve.append(cost)
if wandb:
wb.log({
k: v for k, v in cost.items() if k.startswith('objective')
})

# Check the stopping criterion and update rho if necessary.
stop, ctx.status, stop_val = stopping_criterion.should_stop(
stop_val, curve
)
if wandb:
run.finish(quiet=True)

return curve, ctx.status


def run_one_solver(benchmark, dataset, objective, solver, n_repetitions,
max_runs, timeout, force=False, output=None, pdb=False):
max_runs, timeout, force=False, output=None, pdb=False,
wandb=False):
"""Run a benchmark for a given dataset, objective and solver.

Parameters
@@ -164,14 +188,18 @@ def run_one_solver(benchmark, dataset, objective, solver, n_repetitions,
Object to format string to display the progress of the solver.
pdb : bool
If pdb is set to True, open a debugger on error.
wandb : bool
If wandb is set to True, send the results to wandb. This option
requires wandb to be installed and configured. See the wandb
documentation: https://wandb.ai/quickstart.

Returns
-------
run_statistics : list
The benchmark results.
"""
run_one_to_cvg_cached = benchmark.cache(
run_one_to_cvg, ignore=['force', 'output', 'pdb']
run_one_to_cvg, ignore=['force', 'output', 'pdb', 'wandb']
)

# Set objective and skip if necessary.
@@ -192,6 +220,7 @@ def run_one_solver(benchmark, dataset, objective, solver, n_repetitions,
output.set(rep=rep)
# Get meta
meta = dict(
benchmark_name=benchmark.name,
objective_name=str(objective),
solver_name=str(solver),
data_name=str(dataset),
@@ -208,7 +237,7 @@ def run_one_solver(benchmark, dataset, objective, solver, n_repetitions,
benchmark=benchmark, objective=objective,
solver=solver, meta=meta,
stopping_criterion=stopping_criterion,
force=force, output=output, pdb=pdb
force=force, output=output, pdb=pdb, wandb=wandb
)
if status in ['diverged', 'error', 'interrupted']:
run_statistics = []
@@ -237,7 +266,7 @@ def run_benchmark(benchmark, solver_names=None, forced_solvers=None,
dataset_names=None, objective_filters=None, max_runs=10,
n_repetitions=1, timeout=100, n_jobs=1, slurm=None,
plot_result=True, html=True, show_progress=True, pdb=False,
output="None"):
wandb=False, output="None"):
"""Run full benchmark.

Parameters
@@ -278,6 +307,10 @@ def run_benchmark(benchmark, solver_names=None, forced_solvers=None,
If show_progress is set to True, display the progress of the benchmark.
pdb : bool
If pdb is set to True, open a debugger on error.
wandb : bool
If wandb is set to True, send the results to wandb. This option
requires wandb to be installed and configured. See the wandb
documentation: https://wandb.ai/quickstart.
output : str
Filename for the parquet output. If given, the results will
be stored at <BENCHMARK>/outputs/<filename>.parquet.
@@ -304,7 +337,7 @@ def run_benchmark(benchmark, solver_names=None, forced_solvers=None,
)
common_kwargs = dict(
benchmark=benchmark, n_repetitions=n_repetitions, max_runs=max_runs,
timeout=timeout, pdb=pdb
timeout=timeout, pdb=pdb, wandb=wandb
)

if slurm is not None:
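The wb.log call above forwards only the objective columns of each result row; bookkeeping entries such as timings or solver names stay local. A minimal sketch of that filtering, with illustrative keys:

# Only keys prefixed with 'objective' are sent to wandb.
cost = {
    'objective_value': 0.42,
    'objective_duality_gap': 1e-3,
    'time': 1.27,                # dropped: no 'objective' prefix
    'solver_name': 'my-solver',  # dropped as well
}
logged = {k: v for k, v in cost.items() if k.startswith('objective')}
print(logged)
# {'objective_value': 0.42, 'objective_duality_gap': 0.001}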
3 changes: 3 additions & 0 deletions setup.cfg
@@ -72,3 +72,6 @@ doc =
slurm =
submitit
rich

wandb =
wandb
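
With this extras section, the optional dependency can be installed together with benchopt, e.g. pip install 'benchopt[wandb]' from PyPI or pip install '.[wandb]' from a source checkout.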