-
Notifications
You must be signed in to change notification settings - Fork 615
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add raytune examples / tests (#4053)
- Loading branch information
Showing
17 changed files
with
506 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1006,6 +1006,8 @@ workflows: | |
- "tf21" | ||
- "tf25" | ||
- "tf26" | ||
- "ray112" | ||
- "ray2" | ||
- "service" | ||
- "noml" | ||
- "grpc" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
|
||
import requests | ||
|
||
|
||
def get_wandb_api_key() -> str: | ||
base_url = os.environ.get("WANDB_BASE_URL", "https://api.wandb.ai") | ||
api_key = os.environ.get("WANDB_API_KEY") | ||
if not api_key: | ||
auth = requests.utils.get_netrc_auth(base_url) | ||
if not auth: | ||
raise ValueError( | ||
f"must configure api key by env or in netrc for {base_url}" | ||
) | ||
api_key = auth[-1] | ||
return api_key | ||
|
||
|
||
def get_wandb_api_key_file(file_name: str = None) -> str: | ||
file_name = file_name or ".wandb-api-key.secret" | ||
api_key = get_wandb_api_key() | ||
with open(file_name, "w") as f: | ||
f.write(api_key) | ||
return os.path.abspath(file_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
"""ray-tune test. | ||
Based on: | ||
https://docs.wandb.ai/guides/integrations/other/ray-tune | ||
""" | ||
|
||
import random | ||
|
||
from _test_support import get_wandb_api_key | ||
import numpy as np | ||
from ray import tune | ||
from ray.tune.integration.wandb import wandb_mixin | ||
import wandb | ||
|
||
|
||
@wandb_mixin | ||
def train_fn(config): | ||
for _i in range(10): | ||
loss = config["a"] + config["b"] | ||
wandb.log({"loss": loss}) | ||
tune.report(loss=loss) | ||
|
||
|
||
# Make test deterministic | ||
random.seed(2022) | ||
np.random.seed(2022) | ||
|
||
|
||
tune.run( | ||
train_fn, | ||
config={ | ||
# define search space here | ||
"a": tune.choice([1, 2, 3]), | ||
"b": tune.choice([4, 5, 6]), | ||
# wandb configuration | ||
"wandb": {"project": "Optimization_Project", "api_key": get_wandb_api_key()}, | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
plugin: | ||
- wandb | ||
tag: | ||
shard: ray2 | ||
depend: | ||
requirements: | ||
- pandas | ||
- ray[tune]>=2.0.0rc0 # ray1 would work, but limit the number of versions | ||
assert: | ||
- :yea:exit: 0 | ||
- :wandb:runs_len: 1 | ||
- :wandb:runs[0][config]: {'a': 2, 'b': 4} | ||
- :op:>: | ||
- :wandb:runs[0][summary][loss] | ||
- 0 | ||
- :wandb:runs[0][exitcode]: 0 | ||
- :op:contains: | ||
- :wandb:runs[0][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[0][telemetry][2] # imports_finish | ||
- 30 # ray |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""ray-tune test. | ||
Based on: | ||
https://docs.ray.io/en/master/tune/examples/tune-wandb.html | ||
""" | ||
|
||
from _test_support import get_wandb_api_key_file | ||
import numpy as np | ||
from ray import air, tune | ||
from ray.air import session | ||
from ray.air.callbacks.wandb import WandbLoggerCallback | ||
|
||
|
||
def objective(config, checkpoint_dir=None): | ||
for _i in range(30): | ||
loss = config["mean"] + config["sd"] * np.random.randn() | ||
session.report({"loss": loss}) | ||
|
||
|
||
def tune_function(api_key_file): | ||
"""Example for using a WandbLoggerCallback with the function API""" | ||
tuner = tune.Tuner( | ||
objective, | ||
tune_config=tune.TuneConfig( | ||
metric="loss", | ||
mode="min", | ||
), | ||
run_config=air.RunConfig( | ||
callbacks=[ | ||
WandbLoggerCallback(api_key_file=api_key_file, project="Wandb_example") | ||
], | ||
), | ||
param_space={ | ||
"mean": tune.grid_search([1, 2, 3, 4, 5]), | ||
"sd": tune.uniform(0.2, 0.8), | ||
}, | ||
) | ||
results = tuner.fit() | ||
|
||
return results.get_best_result().config | ||
|
||
|
||
def main(): | ||
api_key_file = get_wandb_api_key_file() | ||
tune_function(api_key_file) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
plugin: | ||
- wandb | ||
tag: | ||
shard: ray2 | ||
depend: | ||
requirements: | ||
- pandas | ||
- ray[tune]>=2.0.0rc0 | ||
assert: | ||
- :yea:exit: 0 | ||
- :wandb:runs_len: 5 | ||
- :wandb:runs[0][exitcode]: 0 | ||
- :wandb:runs[1][exitcode]: 0 | ||
- :wandb:runs[2][exitcode]: 0 | ||
- :wandb:runs[3][exitcode]: 0 | ||
- :wandb:runs[4][exitcode]: 0 | ||
- :op:contains: | ||
- :wandb:runs[0][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[1][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[2][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[3][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[4][telemetry][1] # imports_init | ||
- 30 # ray |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
"""ray-tune test. | ||
Based on: | ||
https://docs.ray.io/en/master/tune/examples/tune-wandb.html | ||
""" | ||
|
||
from _test_support import get_wandb_api_key_file | ||
import numpy as np | ||
from ray import tune | ||
from ray.air import session | ||
from ray.tune.integration.wandb import wandb_mixin | ||
import wandb | ||
|
||
|
||
@wandb_mixin | ||
def decorated_objective(config, checkpoint_dir=None): | ||
for _i in range(30): | ||
loss = config["mean"] + config["sd"] * np.random.randn() | ||
session.report({"loss": loss}) | ||
wandb.log(dict(loss=loss)) | ||
|
||
|
||
def tune_decorated(api_key_file): | ||
"""Example for using the @wandb_mixin decorator with the function API""" | ||
tuner = tune.Tuner( | ||
decorated_objective, | ||
tune_config=tune.TuneConfig( | ||
metric="loss", | ||
mode="min", | ||
), | ||
param_space={ | ||
"mean": tune.grid_search([1, 2, 3, 4, 5]), | ||
"sd": tune.uniform(0.2, 0.8), | ||
"wandb": {"api_key_file": api_key_file, "project": "Wandb_example"}, | ||
}, | ||
) | ||
results = tuner.fit() | ||
|
||
return results.get_best_result().config | ||
|
||
|
||
def main(): | ||
api_key_file = get_wandb_api_key_file() | ||
tune_decorated(api_key_file) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
plugin: | ||
- wandb | ||
tag: | ||
shard: ray2 | ||
depend: | ||
requirements: | ||
- pandas | ||
- ray[tune]>=2.0.0rc0 | ||
assert: | ||
- :yea:exit: 0 | ||
- :wandb:runs_len: 5 | ||
- :wandb:runs[0][exitcode]: 0 | ||
- :wandb:runs[1][exitcode]: 0 | ||
- :wandb:runs[2][exitcode]: 0 | ||
- :wandb:runs[3][exitcode]: 0 | ||
- :wandb:runs[4][exitcode]: 0 | ||
- :op:contains: | ||
- :wandb:runs[0][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[1][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[2][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[3][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[4][telemetry][1] # imports_init | ||
- 30 # ray |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
"""ray-tune test. | ||
Based on: | ||
https://docs.ray.io/en/master/tune/examples/tune-wandb.html | ||
""" | ||
|
||
from _test_support import get_wandb_api_key_file | ||
import numpy as np | ||
from ray import tune | ||
from ray.tune import Trainable | ||
from ray.tune.integration.wandb import WandbTrainableMixin | ||
import wandb | ||
|
||
|
||
class WandbTrainable(WandbTrainableMixin, Trainable): | ||
def step(self): | ||
for _i in range(30): | ||
loss = self.config["mean"] + self.config["sd"] * np.random.randn() | ||
wandb.log({"loss": loss}) | ||
return {"loss": loss, "done": True} | ||
|
||
|
||
def tune_trainable(api_key_file): | ||
"""Example for using a WandTrainableMixin with the class API""" | ||
tuner = tune.Tuner( | ||
WandbTrainable, | ||
tune_config=tune.TuneConfig( | ||
metric="loss", | ||
mode="min", | ||
), | ||
param_space={ | ||
"mean": tune.grid_search([1, 2, 3, 4, 5]), | ||
"sd": tune.uniform(0.2, 0.8), | ||
"wandb": {"api_key_file": api_key_file, "project": "Wandb_example"}, | ||
}, | ||
) | ||
results = tuner.fit() | ||
|
||
return results.get_best_result().config | ||
|
||
|
||
def main(): | ||
api_key_file = get_wandb_api_key_file() | ||
tune_trainable(api_key_file) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
plugin: | ||
- wandb | ||
tag: | ||
shard: ray2 | ||
depend: | ||
requirements: | ||
- pandas | ||
- ray[tune]>=2.0.0rc0 | ||
assert: | ||
- :yea:exit: 0 | ||
- :wandb:runs_len: 5 | ||
- :wandb:runs[0][exitcode]: 0 | ||
- :wandb:runs[1][exitcode]: 0 | ||
- :wandb:runs[2][exitcode]: 0 | ||
- :wandb:runs[3][exitcode]: 0 | ||
- :wandb:runs[4][exitcode]: 0 | ||
- :op:contains: | ||
- :wandb:runs[0][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[1][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[2][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[3][telemetry][1] # imports_init | ||
- 30 # ray | ||
- :op:contains: | ||
- :wandb:runs[4][telemetry][1] # imports_init | ||
- 30 # ray |
Oops, something went wrong.