Engine: kwargs #1156
Conversation
catalyst/engines/apex.py
Outdated
```python
model = ApexDistributedDataParallel(model, delay_allreduce=self.delay_all_reduce)
model, optimizer = amp.initialize(model, optimizer, **self.apex_kwargs)
# TODO: kwargs for Apex DDP ?
model = ApexDistributedDataParallel(model)  # , delay_allreduce=self.delay_all_reduce)
```
could we also add `ddp_kwargs` and pass them here?
in this case, we have to remove `**` from the init and make `apex_kwargs` and `ddp_kwargs` dict storages
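The change discussed above can be sketched roughly as follows. This is a minimal illustration of storing the kwargs as plain dicts instead of accepting `**kwargs` in `__init__`; the class name is hypothetical, and the actual Apex/DDP calls that would later unpack these dicts are elided.

```python
import copy


class APEXEngineSketch:
    """Hypothetical sketch: explicit dict arguments instead of **kwargs."""

    def __init__(self, apex_kwargs: dict = None, ddp_kwargs: dict = None):
        # store plain dict "storages", as suggested in the review;
        # deepcopy so later defaulting does not mutate the caller's dict
        self.apex_kwargs = copy.deepcopy(apex_kwargs) if apex_kwargs is not None else {}
        self.ddp_kwargs = copy.deepcopy(ddp_kwargs) if ddp_kwargs is not None else {}


# the dicts would later be unpacked with ** at the amp.initialize / DDP call sites
engine = APEXEngineSketch(
    apex_kwargs={"opt_level": "O1"},
    ddp_kwargs={"delay_allreduce": True},
)
```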
the PR looks amazing, nevertheless could we please make a few extra changes:
- rename `ddp_kwargs` to `dist_kwargs` or `process_kwargs`, as far as they are used for `torch.distributed.init_process_group`
- add truly `ddp_kwargs` and use them for `ApexDistributedDataParallel` and `DistributedDataParallel` wrappers
- add an extra expectation for such cases - I mean, we should raise an error if we could not wrap the model correctly
Huge thanks in advance!
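The requested split could look roughly like this. It is a sketch only: `EngineSketch` and `wrap_model` are hypothetical names, a plain callable stands in for the `DistributedDataParallel` / `ApexDistributedDataParallel` wrappers, and the real `torch.distributed.init_process_group` call is elided.

```python
class EngineSketch:
    """Hypothetical sketch of the kwargs split the review asks for."""

    def __init__(self, process_group_kwargs: dict = None, ddp_kwargs: dict = None):
        # passed to torch.distributed.init_process_group(**self.process_group_kwargs)
        self.process_group_kwargs = dict(process_group_kwargs or {})
        # passed to the DDP wrapper, e.g. DistributedDataParallel(model, **self.ddp_kwargs)
        self.ddp_kwargs = dict(ddp_kwargs or {})

    def wrap_model(self, model, wrapper=None):
        # "extra expectation": raise instead of silently returning an unwrapped model
        if wrapper is None:
            raise RuntimeError("Could not wrap the model with DistributedDataParallel")
        return wrapper(model, **self.ddp_kwargs)
```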
catalyst/engines/torch.py
Outdated
```python
if "device_ids" not in self.ddp_kwargs:
    self.ddp_kwargs["device_ids"] = [self.device]
```
we should move this to the end of `def setup_process(self, rank: int = -1, world_size: int = 1):`, because at `__init__` time `self.device` is still `None`
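The ordering issue can be sketched as follows: `self.device` is only assigned inside `setup_process`, so defaulting `ddp_kwargs["device_ids"]` has to happen at the end of that method, not in `__init__`. This is an illustrative stand-in with the process-group setup and CUDA calls elided; the class name is hypothetical.

```python
class DeviceEngineSketch:
    """Hypothetical sketch: device_ids can only be defaulted after setup_process."""

    def __init__(self, ddp_kwargs: dict = None):
        self.ddp_kwargs = dict(ddp_kwargs or {})
        self.device = None  # not known until setup_process runs

    def setup_process(self, rank: int = -1, world_size: int = 1):
        self._rank, self._world_size = rank, world_size
        # ... dist.init_process_group(**self.process_group_kwargs) would go here ...
        self.device = f"cuda:{int(self._rank)}"
        # only now is self.device valid, so default device_ids at the end
        if "device_ids" not in self.ddp_kwargs:
            self.ddp_kwargs["device_ids"] = [self.device]
```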
```python
os.environ["MASTER_ADDR"] = str(self.address)
os.environ["MASTER_PORT"] = str(self.port)
dist.init_process_group(self.backend, rank=self.rank, world_size=self.world_size)
dist.init_process_group(**self.process_group_kwargs)
torch.cuda.set_device(int(self._rank))
self.device = f"cuda:{int(self._rank)}"
```
Suggested change:

```python
self.device = f"cuda:{int(self._rank)}"
if "device_ids" not in self.ddp_kwargs:
    self.ddp_kwargs["device_ids"] = [self.device]
```
catalyst/engines/torch.py
Outdated
```python
self.ddp_kwargs = copy.deepcopy(ddp_kwargs)
if "device_ids" not in self.ddp_kwargs:
    self.ddp_kwargs["device_ids"] = [self.device]
```
Suggested change:

```python
self.ddp_kwargs = copy.deepcopy(ddp_kwargs)
```
Before submitting (checklist)
- Did you check the code style with `catalyst-make-codestyle && catalyst-check-codestyle` (`pip install -U catalyst-codestyle`)?
- Did you check the docs with `make check-docs`?
- Did you run the tests with `pytest .`?
- Did you check the code with `latest` and `minimal` requirements?

Description
Related Issue
Type of Change
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
PS