
[Feature Request] support for custom logic in submitit plugin's checkpoint function #2042

Closed
tesfaldet opened this issue Feb 18, 2022 · 16 comments · Fixed by #2044

@tesfaldet

🚀 Feature Request

Currently, when preemption/timeout occurs, the submitit plugin resubmits the job with the same initial arguments. There is no clear way of adding custom logic within the checkpoint function of the plugin.

def checkpoint(self, *args: Any, **kwargs: Any) -> Any:
    """Resubmit the current callable at its current state with the same initial arguments."""
    # lazy import to ensure plugin discovery remains fast
    import submitit

    return submitit.helpers.DelayedSubmission(self, *args, **kwargs)

Motivation

My train loop saves (pytorch) checkpoints every N iterations. When preemption/timeout occurs and the submitit plugin resubmits the job, after the job starts it will continue from the most recently saved checkpoint, which could be several hundred (or thousand) iterations old. The job will have to redo training iterations that were already done before, which is a waste of time and resources.

It would be nice if, once the submitit plugin's checkpoint function is called (in reaction to preemption/timeout), it could re-use the checkpoint-saving code that my train loop uses (or whatever logic I'd want, really), so that upon restart the job resumes from a checkpoint at most 1 iteration old.
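To make the request concrete, here is a minimal sketch of what a user-injectable hook could look like. Everything here is hypothetical (the class name, the `on_checkpoint` parameter, and the tuple return value standing in for `submitit.helpers.DelayedSubmission`), not actual Hydra/submitit API:

```python
from typing import Any, Callable, Optional


class CheckpointableLauncher:
    """Hypothetical launcher that lets the user inject save logic on preemption."""

    def __init__(self, on_checkpoint: Optional[Callable[[], None]] = None) -> None:
        self.on_checkpoint = on_checkpoint  # user-supplied save callback

    def checkpoint(self, *args: Any, **kwargs: Any) -> Any:
        # run the user's save logic (e.g. dump the latest model state) ...
        if self.on_checkpoint is not None:
            self.on_checkpoint()
        # ... then resubmit as usual (stand-in for submitit.helpers.DelayedSubmission)
        return ("resubmit", args, kwargs)


saved = []
launcher = CheckpointableLauncher(on_checkpoint=lambda: saved.append("ckpt"))
result = launcher.checkpoint(1, x=2)
print(saved, result)  # → ['ckpt'] ('resubmit', (1,), {'x': 2})
```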

Pitch

Describe the solution you'd like
From what I see so far, the checkpoint function is exposed to the current callable's context (via self):

def __call__(
    self,
    sweep_overrides: List[str],
    job_dir_key: str,
    job_num: int,
    job_id: str,
    singleton_state: Dict[type, Singleton],
) -> JobReturn:
    # lazy import to ensure plugin discovery remains fast
    import submitit

    assert self.hydra_context is not None
    assert self.config is not None
    assert self.task_function is not None

    Singleton.set_state(singleton_state)
    setup_globals()
    sweep_config = self.hydra_context.config_loader.load_sweep_config(
        self.config, sweep_overrides
    )

    with open_dict(sweep_config.hydra.job) as job:
        # Populate new job variables
        job.id = submitit.JobEnvironment().job_id  # type: ignore
        sweep_config.hydra.job.num = job_num

    return run_job(
        hydra_context=self.hydra_context,
        task_function=self.task_function,
        config=sweep_config,
        job_dir_key=job_dir_key,
        job_subdir_key="hydra.sweep.subdir",
    )

The callable returns a JobReturn object; the task_function(...) (i.e., the train loop, which holds all the relevant train-loop data such as the model, its state dict, the optimizer's state dict, etc.) is executed within run_job(...):
ret.return_value = task_function(task_cfg)

An issue I see here is that the task_function's context is inaccessible, so it's not possible to pass through objects such as the model's current state. That means the checkpoint function won't have access to that data either. You'd need to figure out how to pass the task_function's context up to __call__(...) so that when checkpoint(...) is called it has access to job-specific data, and the user can inject custom logic (maybe through a yaml config with _target_ set to a save() function) that acts upon this job-specific data, e.g., checkpointing/dumping it.

Here are some steps forward I think would make sense (correct me if I'm wrong):

  1. It looks like you could provide a TaskFunction class with a __call__ function instead of a task_function from what I see here:

    hydra/hydra/main.py

    Lines 15 to 18 in b16baa7

    def main(
        config_path: Optional[str] = _UNSPECIFIED_,
        config_name: Optional[str] = None,
    ) -> Callable[[TaskFunction], Any]:

    That way, you can treat each task_function as an instance and save your relevant train data using self.<whatever> = <whatever> in __call__. From there, you can access each task_function's context after it's been called. More importantly, run_job(...) will have access to its context.
  2. Have run_job(...) optionally accept a Launcher context so that it could modify it. This would allow it to pass through task_function's context back up to BaseSubmititLauncher, which would make it available to its checkpoint function. The task_function's context can be saved in self.task_function_context.
  3. Have the checkpoint function check for an on_job_preempt_or_timeout callback using _get_callbacks_for_run_job(...) and execute the callback while passing in self.task_function_context.
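A toy sketch of how these three steps could fit together. All classes and names here (Train, run_job, Launcher, task_function_context, on_preempt) are mocks illustrating the proposed flow, not Hydra's actual API:

```python
from typing import Any, Callable, Dict, Optional


class Train:
    """Step 1: the task function as a callable instance that keeps its own context."""

    def __init__(self) -> None:
        self.model_state: Optional[Dict[str, Any]] = None

    def __call__(self, cfg: Dict[str, Any]) -> float:
        self.model_state = {"step": 42}  # stand-in for real training state
        return 0.0


def run_job(task_function: Callable, config: Dict[str, Any], launcher: "Launcher") -> Any:
    # Step 2: run_job optionally accepts the launcher so it can pass context back up
    ret = task_function(config)
    launcher.task_function_context = task_function
    return ret


class Launcher:
    def __init__(self, on_preempt: Optional[Callable[[Any], None]] = None) -> None:
        self.task_function_context: Any = None
        self.on_preempt = on_preempt  # stand-in for an on_job_preempt_or_timeout callback

    def checkpoint(self) -> None:
        # Step 3: invoke the callback, passing in the task function's context
        if self.on_preempt is not None:
            self.on_preempt(self.task_function_context)


seen = []
launcher = Launcher(on_preempt=lambda ctx: seen.append(ctx.model_state))
run_job(Train(), {}, launcher)
launcher.checkpoint()
print(seen)  # → [{'step': 42}]
```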

(Also, as a bonus feature request that could be tackled simultaneously with the above steps, it'd be nice to be able to pass a task_function's context to each of the callbacks in run_job(...). I could submit a separate feature request for this. The quick motivation: it would make each task's wandb run object, initialized within the task via self.run = wandb.init(...), accessible to callbacks, so you could use wandb's Alerts interface to send Slack alerts. There are ways to do this already, but they involve reinitializing a wandb run within each callback separately, with the same run ID used before, just to send an alert, which introduces unnecessary overhead due to the init process.)

Describe alternatives you've considered
I've considered saving checkpoints every iteration but that's just time and space consuming...

Are you willing to open a pull request? (See CONTRIBUTING)
I could be willing to fork and try out the steps above then submit a PR, but I'd like to get a feel for what y'all think about this request before I try implementing it myself.

Additional context

None.

@tesfaldet added the enhancement label Feb 18, 2022
@Jasha10
Collaborator

Jasha10 commented Feb 18, 2022

Hi @tesfaldet,

I'm not an expert in using submitit, but let me nevertheless share an idea:
If your main concern is being able to inject custom logic into the submitit launcher's checkpoint method, might this be achieved by defining a custom subclass of the submitit launcher plugin?

Here is what I am thinking:
Next to your python script (say my_app.py), you could create a file hydra_plugins/custom_submitit_logic.py which defines subclasses of the BaseSubmititLauncher class from Hydra's submitit plugin:

$ tree
.
├── conf
│   └── config.yaml
├── hydra_plugins
│   └── custom_submitit_logic.py
└── my_app.py

Here is the main python script:

$ cat my_app.py
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config")
def app(cfg: DictConfig) -> None:
    print(cfg)


if __name__ == "__main__":
    app()

Here is the yaml configuration file:

$ cat conf/config.yaml
defaults:
  - override hydra/launcher: submitit_local
  # - override hydra/launcher: submitit_slurm
  - _self_

hydra:
  launcher:
    _target_: hydra_plugins.custom_submitit_logic.MyLocalLauncher
    # _target_: hydra_plugins.custom_submitit_logic.MySlurmLauncher

And here is the custom plugin subclassing BaseSubmititLauncher:

$ cat hydra_plugins/custom_submitit_logic.py
from typing import Any

from hydra_plugins.hydra_submitit_launcher.submitit_launcher import BaseSubmititLauncher


class MyBaseSubmititLauncher(BaseSubmititLauncher):
    def __init__(self, *args, **kwargs) -> None:
        print("INITIALIZING CUSTOM SUBMITIT LAUNCHER")
        ...
        super().__init__(*args, **kwargs)

    def checkpoint(self, *args: Any, **kwargs: Any) -> Any:
        """This method is a modified version of the BaseSubmititLauncher.checkpoint method"""

        #########################
        ### CUSTOM LOGIC HERE ###
        #########################

        return super().checkpoint(*args, **kwargs)


class MyLocalLauncher(MyBaseSubmititLauncher):
    _EXECUTOR = "local"


class MySlurmLauncher(MyBaseSubmititLauncher):
    _EXECUTOR = "slurm"

When running my_app.py in --multirun mode, the call to print in MyBaseSubmititLauncher.__init__ verifies that the custom subclass is being used:

$ python my_app.py --multirun
INITIALIZING CUSTOM SUBMITIT LAUNCHER
[2022-02-18 17:15:33,990][HYDRA] Submitit 'local' sweep output dir : multirun/2022-02-18/17-15-33
[2022-02-18 17:15:33,991][HYDRA]        #0 :

Would subclassing the submitit plugin in this way be sufficiently flexible for your use-case?

@Jasha10
Collaborator

Jasha10 commented Feb 18, 2022

I have tried passing a callable class instance (instead of passing a function) to hydra.main:

from typing import Any

import hydra
from omegaconf import DictConfig


class MyTaskFunction:
    def __init__(self, context: Any) -> None:
        self._context = context

    def __call__(self, cfg: DictConfig) -> None:
        print(cfg)

if __name__ == "__main__":
    my_task_function = MyTaskFunction(context=123)
    app = hydra.main(config_path="conf", config_name="config")(my_task_function)
    app()

It is not working currently, but should be easy to support.
For now, would you be able to use e.g. a global variable to store the context that is relevant for checkpointing?
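A minimal sketch of the global-variable idea (all names hypothetical), keeping in mind that each submitit job runs in its own process, so a module-level store is only visible within that process:

```python
from typing import Any, Dict

# Module-level store for checkpoint-relevant state (process-local!)
_CHECKPOINT_CONTEXT: Dict[str, Any] = {}


def task_function(cfg: Dict[str, Any]) -> None:
    # the train loop stashes whatever the checkpoint hook will later need
    _CHECKPOINT_CONTEXT["model_state"] = {"step": 7}


def checkpoint_hook() -> Dict[str, Any]:
    # a custom launcher's checkpoint method could read the store back
    return _CHECKPOINT_CONTEXT.get("model_state", {})


task_function({})
print(checkpoint_hook())  # → {'step': 7}
```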

@tesfaldet
Author

tesfaldet commented Feb 19, 2022

Thanks for the quick response @Jasha10.

Your suggestion is a great one and I think it could work.

Specifically, my suggestion of using a TaskFunction class with a __call__ function that's decorated with the main decorator combined with your suggestion could work. I see that you already tried the idea and that it wasn't working for you, but I think I may have an MVP below:

from omegaconf import DictConfig, OmegaConf
from typing import Any, Optional, Callable
from functools import wraps
from hydra.core.utils import JobReturn


TaskFunction = Callable[[Any], Any]  # local stand-in for hydra.types.TaskFunction

_UNSPECIFIED_: Any = object()


def main(config_path: Optional[str] = _UNSPECIFIED_,
         config_name: Optional[str] = None) -> Callable[[TaskFunction], Any]:
  def main_decorator(task_function: TaskFunction) -> Callable[[], None]:
    @wraps(task_function)
    def decorator_main(cfg_passthrough: Optional[DictConfig] = None, checkpoint: Optional[bool] = False) -> Any:
      if cfg_passthrough is not None:
        return task_function(cfg_passthrough)
      else:
        conf = OmegaConf.create({'config_name': config_name, 'config_path': config_path})
        launcher = BaseSubmititLauncher()
        launcher.setup(task_function=task_function, config=conf)
        ret = launcher.launch()
        if checkpoint:
          launcher.checkpoint()
        return ret
    return decorator_main
  return main_decorator


def run_job(task_function: TaskFunction,
            config: DictConfig):
  return task_function(config)


class BaseSubmititLauncher():
  def __init__(self) -> None:
    self.config: Optional[DictConfig] = None
    self.task_function: Optional[TaskFunction] = None

  def setup(
    self,
    *,
    task_function: TaskFunction,
    config: DictConfig
  ) -> None:
    self.config = config
    self.task_function = task_function

  def __call__(self) -> JobReturn:
    assert self.config is not None
    assert self.task_function is not None

    return run_job(
      task_function=self.task_function,
      config=self.config
    )

  def checkpoint(self, *args: Any, **kwargs: Any) -> Any:
    """Resubmit the current callable at its current state with the same initial arguments."""
    print(f'in checkpoint with model {self.task_function.model}')

  def launch(self) -> JobReturn:
    return self()


class Train(TaskFunction):
  def __init__(self):
    self.model = None
    self.config = None

  # main train loop code
  def __call__(self, cfg: DictConfig) -> float:
    self.model = 'PyTorchModel'
    self.config = cfg
    return 3.3

  def __str__(self) -> str:
    return f'Train(self.model = {self.model}, self.config = {self.config})'


if __name__ == '__main__':
  train = Train()
  print(f'in __main__ after instantiating Train object: {train}')

  wrapped = main(config_path='./conf', config_name='config')(train)
  print(f'in __main__ after wrapping Train object with main decorator: {train}')

  wrapped()
  print(f'in __main__ after calling train() with no config passthrough: {train}')

  wrapped(checkpoint=True)
  print(f'in __main__ after calling train() with no config passthrough and calling checkpoint: {train}')

  wrapped(OmegaConf.create({'foo': 'bar'}))
  print(f'in __main__ after calling train() with config passthrough: {train}')

outputs:

in __main__ after instantiating Train object: Train(self.model = None, self.config = None)
in __main__ after wrapping Train object with main decorator: Train(self.model = None, self.config = None)
in __main__ after calling train() with no config passthrough: Train(self.model = PyTorchModel, self.config = {'config_name': 'config', 'config_path': './conf'})
in checkpoint with model PyTorchModel
in __main__ after calling train() with no config passthrough and calling checkpoint: Train(self.model = PyTorchModel, self.config = {'config_name': 'config', 'config_path': './conf'})
in __main__ after calling train() with config passthrough: Train(self.model = PyTorchModel, self.config = {'foo': 'bar'})

What do you think? It follows a similar execution path as Hydra to show that this might be possible without any change to Hydra's code :)

Basically, since the launcher already has access to task_function, then by making task_function a Callable instance (i.e., a TaskFunction object) I could access its attributes. These attributes would be the relevant pytorch data. Combine that with a custom launcher that's a subclass of BaseSubmititLauncher that has its own checkpoint function that accesses task_function and its attributes and I'm pretty sure we're golden.

@tesfaldet
Author

tesfaldet commented Feb 19, 2022

I just took a closer look at your example and it's basically the same as mine 😅 I'm surprised it didn't work for you! I haven't tested mine with Hydra's actual main function since I made my own version of it, but it's fairly similar and it worked...

I'm not at my computer for the weekend so I can't test it out myself with Hydra's main at the moment but I'm curious what error you got.

@tesfaldet
Author

tesfaldet commented Feb 19, 2022

For now, would you be able to use e.g. a global variable to store the context that is relevant for checkpointing?

Correct me if I'm wrong, but doesn't the submitit plugin launch a subprocess for each list of sweep params, with each subprocess executing your app's task function with those params? Wouldn't that mean that a global variable outside of the app's task function won't be accessible to each subprocess unless there's some IPC implemented?

@Jasha10
Copy link
Collaborator

Jasha10 commented Feb 21, 2022

I'm not at my computer for the weekend so I can't test it out myself with Hydra's main at the moment but I'm curious what error you got.

When I run the example from my previous comment, I get the following error:

$ python my_app.py
Traceback (most recent call last):
  File "my_app.py", line 17, in <module>
    app()
  File "/home/jasha10/hydra.git/hydra/main.py", line 52, in decorated_main
    config_name=config_name,
  File "/home/jasha10/hydra.git/hydra/_internal/utils.py", line 320, in _run_hydra
    ) = detect_calling_file_or_module_from_task_function(task_function)
  File "/home/jasha10/hydra.git/hydra/_internal/utils.py", line 51, in detect_calling_file_or_module_from_task_function
    calling_file = task_function.__code__.co_filename
AttributeError: 'MyTaskFunction' object has no attribute '__code__'

I am almost done with a PR to patch this so that a callable can be passed to hydra.main.
Edit: See PR #2044.

@Jasha10
Collaborator

Jasha10 commented Feb 21, 2022

Wouldn't that mean that a global variable outside of the app's task function won't be accessible to each subprocess unless there's some IPC implemented?

Aah yes, good point. I haven't tried it myself, but I suspect you're right.

@tesfaldet
Author

tesfaldet commented Feb 21, 2022

@Jasha10 I believe the error you're experiencing occurs because you're passing the callable object and not its function, since only functions have a __code__ attribute. So perhaps the PR is not necessary?

import hydra
from omegaconf import DictConfig


class MyTaskFunction:
	def __init__(self) -> None:
		self._context = None

	def __call__(self, cfg: DictConfig) -> None:
		self._context = 123
		print(cfg, self._context)

if __name__ == "__main__":
	my_task_function = MyTaskFunction()
	print(my_task_function._context)
	app = hydra.main(config_path=None, config_name=None)(my_task_function.__call__)
	app()
	print(my_task_function._context)

outputs:

None
{} 123
123

@tesfaldet
Author

tesfaldet commented Feb 21, 2022

Nvm, the PR is super necessary haha. Although the above works, the problem with passing in my_task_function.__call__ is that you're not passing in the object, which is what's needed for the custom MyBaseSubmititLauncher's checkpoint function to access the context of my_task_function. So forget what I said!

@tesfaldet
Author

the problem with passing in my_task_function.__call__ is that you're not passing in the object

I just realized that you could access the object through my_task_function.__call__.__self__. Welp, I'm gonna test this out right now with a custom MyBaseSubmititLauncher and report back :)
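For illustration: a bound method always carries a reference back to its instance, which is what makes this work.

```python
class MyTaskFunction:
    def __call__(self, cfg):
        return cfg


mt = MyTaskFunction()
bound = mt.__call__           # a bound method object
print(bound.__self__ is mt)   # → True: the instance is recoverable from the method
```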

@tesfaldet
Author

tesfaldet commented Feb 23, 2022

I finished trying it out and it worked :) It was a hassle to get working properly for a couple of reasons:

  1. Providing the path to the file containing MyBaseSubmititLauncher (e.g., src.launchers.MyBaseSubmititLauncher) as hydra.launcher._target_ only works if the path starts with hydra_plugins, i.e., hydra_plugins.launchers.MyBaseSubmititLauncher. However, during instantiation, hydra complains that hydra_plugins is not a module. If you put an __init__.py file there, it clashes with the import line from hydra_plugins.hydra_submitit_launcher.submitit_launcher import BaseSubmititLauncher inside launchers.py. So you instead have to make a subfolder within hydra_plugins and put the __init__.py there along with the custom plugin: launchers.py becomes a /launchers folder containing my_submitit_launcher.py, and the target becomes hydra_plugins.launchers.my_submitit_launcher.MyBaseSubmititLauncher.
  2. After submitit's pickle save-then-load-context process before starting each job (essentially sending pickled contexts to each node, which will unpickle the context and run the task), if your cwd (hydra.sweep.dir) is not in the same location as where the main app is located, then when hydra tries to load the hydra_plugins.launchers module you made in the previous step, it won't be able to find it. A workaround is:
# otherwise pickle.load from submitit launcher won't find it since cloudpickle registers by reference by default
import hydra_plugins.launchers
cloudpickle.register_pickle_by_value(hydra_plugins.launchers)

at the top of your main app file. In fact, you need to do this register_pickle_by_value for any module that you've created in your project folder hierarchy if you want to use a hydra working directory that's not the same as the project directory.

There are a couple of weird gotchas but I'm too tired to list them out right now. I've managed to get it working with my own custom submitit launcher with its own checkpoint function and had it send me a wandb alert during pre-emption, re-using the run = wandb.init() context that was only available within the task function. The only issue right now is that upon job restart, specifically during the pickle load process of submitit when it's reloading the associated job's pickle, it fails with a ModuleNotFoundError for a module that exists in my project hierarchy. Which is very odd considering I already used cloudpickle.register_pickle_by_value on this module.

@Jasha10
Collaborator

Jasha10 commented Feb 23, 2022

I see, very interesting! I'm glad it's working for you.

I'm a bit surprised that creating a file hydra_plugins/my_submitit_launcher.py (with no __init__.py file next to it) did not work. I seem to recall that this strategy worked without a hitch when I posted this comment above. Note that omitting the __init__.py file from the hydra_plugins directory is highly recommended -- see the note about namespace packages at the top of this doc regarding plugin development. (Including an __init__.py file inside a directory nested one level under hydra_plugins is not a problem though.)

@tesfaldet
Author

tesfaldet commented Feb 23, 2022

It's working up to a point!

Here's a more detailed explanation of the issue(s) I've been experiencing. Let's say the below is an example folder hierarchy:

/home/project/
├─ main.py
├─ src/
│  ├─ utils.py
│  ├─ __init__.py
/network/datasets/
/network/scratch/

main.py is where my task function is located. /network/datasets, /network/scratch, and /home/project are all located on different servers, each optimized for their respective use cases and all existing under a SLURM environment. When using the Submitit plugin to do a multi-run launch of main.py with sweep params across a SLURM cluster, it will use cloudpickle to serialize main.py and the files it references through its import statements. Let's say main.py contains an import src statement. Now consider that hydra.sweep.dir was set to /network/scratch for saving results, which includes the hydra outputs and application outputs. Since hydra treats hydra.sweep.dir as both the output and working directory, when cloudpickle unpickles main.py it will raise an error since it won't be able to find the package src. This is because the unpickling happened outside of /home/project: the working directory is /network/scratch, so sys.path does not contain /home/project, which is where the src package is located.

Cloudpickle (and pickle) serializes files by reference by default. Cloudpickle can automatically serialize by value but only during an interactive session. However, you could use their experimental cloudpickle.register_pickle_by_value(...) function which is designed for this very use case to explicitly pickle by value. Specifically, in main.py you'd have import cloudpickle; import src; cloudpickle.register_pickle_by_value(src) and the src package and its contents/modules will be serialized by value during submitit's pickling process (submitit uses cloudpickle). Now when main.py is being unpickled from the /network/scratch working directory, it will be able to find src since it was serialized by value.
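To illustrate the "by reference" default with the stdlib pickle (cloudpickle follows the same default for importable modules): pickling a module-level function stores only its module and qualified name, not its code, so unpickling requires that module to be importable on the other side.

```python
import pickle


def helper() -> int:
    return 42


# The payload embeds the function's name as a reference; no bytecode is shipped.
payload = pickle.dumps(helper)
assert b"helper" in payload

# Round-tripping works here only because 'helper' is importable in this process.
restored = pickle.loads(payload)
print(restored())  # → 42
```

This is exactly why unpickling in a working directory where `src` is not importable fails, and why register_pickle_by_value (which embeds the module's code in the payload instead) fixes it.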

You'd think doing import sys; sys.path.append('/home/project') in main.py would solve this problem but unfortunately it doesn't. It seems that when cloudpickle is unpickling, it is not working under the context of main.py since that is yet to be executed by the launcher. If you wanted to go the sys.path.append route, then you'd have to do it within the cloudpickle.load(...) function which is not what you'd want since you'd be modifying library code.

Anyways, register_pickle_by_value(src) works when following the prepare sweep params -> prepare jobs with params -> pickle jobs -> send to nodes -> unpickle jobs -> execute jobs hydra submitit plugin process. But it fails when resuming from a checkpointed job after it was pre-empted or timed out. Specifically, it only seems to fail when the task function is the __call__ function from my TaskFunction subclass: Train. When my task function is just a regular function like train() with the hydra.main decorator, job resumption works fine. But when my task function is the __call__ function within my Train(TaskFunction): class, and I'm feeding it into the decorator as train = Train(); app = hydra.main(config_path='conf', config_name='config')(train.__call__); app() it fails with the above error. I'm not sure why.... perhaps I should place import cloudpickle; import src; cloudpickle.register_pickle_by_value(src) within the Train class?

@tesfaldet
Author

I'm a bit surprised that creating a file hydra_plugins/my_submitit_launcher.py (with no __init__.py file next to it) did not work. I seem to recall that this strategy worked without a hitch when I posted #2042 (comment) above. Note that omitting the __init__.py file from the hydra_plugins directory is highly recommended -- see the note about namespace packages at the top of this doc regarding plugin development. (Including an __init__.py file inside a directory nested one level under hydra_plugins is not a problem though.)

The below error is what I got when I put my own custom plugin within a hydra_plugins folder with no __init__.py. Hydra is able to resolve the plugin just fine by inspecting its _target_ and using its own process for resolving classes given a dot-delimited path. However, when the submitit plugin is unpickling the code to execute as a job, it tries to search for the hydra_plugins.my_submitit_launcher module, since an instance of hydra_plugins.my_submitit_launcher.MySlurmLauncher is what's returned by hydra's instantiate function that is used in the execution stack with the main code. A way around this is to create an additional folder, launchers, within the hydra_plugins folder and place an __init__.py inside launchers. Then move my_submitit_launcher.py to the launchers folder and place import hydra_plugins.launchers; cloudpickle.register_pickle_by_value(hydra_plugins.launchers) in your main code. Finally, change the submitit launcher's target to hydra_plugins.launchers.my_submitit_launcher.MySlurmLauncher. It's quite the hassle....

Exception has occurred: FailedJobError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
Job (task=0) failed during processing with trace:
----------------------
Traceback (most recent call last):
  File "/home/mila/m/mattie.tesfaldet/.conda/envs/pytorch1.10/lib/python3.8/site-packages/submitit/core/submission.py", line 51, in process_job
    delayed = utils.DelayedSubmission.load(paths.submitted_pickle)
  File "/home/mila/m/mattie.tesfaldet/.conda/envs/pytorch1.10/lib/python3.8/site-packages/submitit/core/utils.py", line 138, in load
    obj = pickle_load(filepath)
  File "/home/mila/m/mattie.tesfaldet/.conda/envs/pytorch1.10/lib/python3.8/site-packages/submitit/core/utils.py", line 227, in pickle_load
    return pickle.load(ifile)
ModuleNotFoundError: No module named 'hydra_plugins.my_submitit_launcher'

@tesfaldet
Author

Anyways, register_pickle_by_value(src) works when following the prepare sweep params -> prepare jobs with params -> pickle jobs -> send to nodes -> unpickle jobs -> execute jobs hydra submitit plugin process. But it fails when resuming from a checkpointed job after it was pre-empted or timed out. Specifically, it only seems to fail when the task function is the __call__ function from my TaskFunction subclass: Train. When my task function is just a regular function like train() with the hydra.main decorator, job resumption works fine. But when my task function is the __call__ function within my Train(TaskFunction): class, and I'm feeding it into the decorator as train = Train(); app = hydra.main(config_path='conf', config_name='config')(train.__call__); app() it fails with the above error. I'm not sure why....

I finally fixed the above issue! Specifically, the ModuleNotFoundError: No module named 'src' error. It required digging deep into submitit's code and figuring out how its checkpointing works. Here's why src can be found during the unpickle jobs part of the prepare sweep params -> prepare jobs with params -> pickle jobs -> send to nodes -> unpickle jobs -> execute jobs process (let's call this the "main pipeline") but not during the unpickle job part of the pre-empt/time-out job -> checkpoint -> pickle job -> ask slurm to rerun job -> unpickle job -> execute job process (let's call this the "checkpoint pipeline"): cloudpickle.register_pickle_by_value(src) was executed in the main pipeline (specifically in main.py) before the task function __call__ was dumped with cloudpickle, but when the task function is unpickled and executed, it's in a new process, so cloudpickle no longer has src in its _PICKLE_BY_VALUE_MODULES registry. When the checkpoint pipeline begins, it's therefore operating off a context whose cloudpickle _PICKLE_BY_VALUE_MODULES registry differs from the main pipeline's. So, a hackish way around this is to re-register src inside the checkpoint function:

def checkpoint(self, *args: Any, **kwargs: Any) -> Any:
	"""This method is a modified version of the BaseSubmititLauncher.checkpoint method"""
	run = self.task_function.__self__.run
	run.alert(
		title='Job Pre-empted/Timed-out',
		text=f'Job {run.name} in group {run.group} has either been pre-empted or timed-out.',
		level=wandb.AlertLevel.INFO
	)

	import cloudpickle
	cloudpickle.register_pickle_by_value(self.task_function.__self__.src)

	return super().checkpoint(*args, **kwargs)

As you can tell, this meant I had to pass a reference to src via the task function's attribute self.src. So, in my main.py:

import hydra
import cloudpickle
import src
# otherwise pickle.load from submitit launcher won't find src since cloudpickle registers by reference by default
cloudpickle.register_pickle_by_value(src)

class Train(TaskFunction):
	def __init__(self):
		self.model_and_trainer = None
		self.run = None
		self.cfg = None
		self.src = src

	# main train loop code
	def __call__(self, cfg: DictConfig) -> float:
        ...

if __name__ == '__main__':
	train = Train()
	app = hydra.main(config_path='conf', config_name='config')(train.__call__)
	app()

It's...quite an annoying solution. Anyways, I hope that with your PR you consider the peculiarities of having a different working directory than the code's directory in combination with using submitit. You might bump into the same issues. The PR has the potential to clean this up.

@Jasha10
Collaborator

Jasha10 commented Feb 24, 2022

Wow, nice work! Seems like very tricky logic.
