diff --git a/tests/unit_tests_old/tests_launch/test_launch_kubernetes.py b/tests/unit_tests_old/tests_launch/test_launch_kubernetes.py index e7efa9ab546..9a4c8362453 100644 --- a/tests/unit_tests_old/tests_launch/test_launch_kubernetes.py +++ b/tests/unit_tests_old/tests_launch/test_launch_kubernetes.py @@ -245,6 +245,7 @@ def test_launch_kube( "preemption_policy": "Never", "node_name": "test-node-name", "node_selectors": {"test-selector": "test-value"}, + "tolerations": [{"key": "test-key", "value": "test-value"}], }, }, } @@ -269,6 +270,7 @@ def test_launch_kube( assert job.spec.template.spec.restart_policy == args["restart_policy"] assert job.spec.template.spec.preemption_policy == args["preemption_policy"] assert job.spec.template.spec.node_name == args["node_name"] + assert job.spec.template.spec.tolerations == args["tolerations"] assert ( job.spec.template.spec.node_selector["test-selector"] == args["node_selectors"]["test-selector"] diff --git a/wandb/sdk/launch/runner/kubernetes.py b/wandb/sdk/launch/runner/kubernetes.py index ecbe57cc016..6dd5ca40531 100644 --- a/wandb/sdk/launch/runner/kubernetes.py +++ b/wandb/sdk/launch/runner/kubernetes.py @@ -167,6 +167,8 @@ def populate_pod_spec( pod_spec["nodeName"] = resource_args.get("node_name") if resource_args.get("node_selectors"): pod_spec["nodeSelectors"] = resource_args.get("node_selectors") + if resource_args.get("tolerations"): + pod_spec["tolerations"] = resource_args.get("tolerations") def populate_container_resources( self, containers: List[Dict[str, Any]], resource_args: Dict[str, Any]