Enable support for Intel XPU devices (AKA Intel GPUs) #19443

Draft: coreyjadams wants to merge 40 commits into base: master

Commits (40, showing changes from all commits)
4b1c2f3
Enable Intel XPU as an accelerator and automatic GPU
Jul 20, 2023
3cd7503
Merge branch 'Lightning-AI:master' into master
coreyjadams Jul 20, 2023
494a95a
Merge branch 'master' of https://github.com/Lightning-AI/lightning
Nov 8, 2023
54af860
Fixing some things since my last updates: accelerator structure chang…
Nov 8, 2023
d5274ac
Merge branch 'master' of https://github.com/Lightning-AI/lightning
Dec 4, 2023
5de217e
Merge branch 'Lightning-AI:master' into master
coreyjadams Feb 9, 2024
bd12150
Enable DDP for XPU. There is a bug, probably in the CCL layer, where…
Feb 9, 2024
711281a
Update throughput_monitor.py
coreyjadams Feb 9, 2024
0c36f3c
Update accelerator_connector.py
coreyjadams Feb 9, 2024
4477301
Update module.py
coreyjadams Feb 9, 2024
6b76644
Update saving.py
coreyjadams Feb 9, 2024
0fac170
Add further support for XPU devices from Intel
Feb 9, 2024
3bb66e3
Merge branch 'Lightning-AI:master' into master
coreyjadams Feb 9, 2024
365126d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2024
5b6c3a8
Fix typo
Feb 9, 2024
d08bda8
Fix type error in memory stats. Enable oneccl in distributed mode
Feb 12, 2024
cdc9ae5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2024
68ffe7a
Address precommit.ci errors
Feb 12, 2024
0242636
Merge branch 'master' of github.com:coreyjadams/lightning
Feb 12, 2024
5103a16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2024
0aad417
Missed a line that was too long
Feb 12, 2024
0866cc9
Merge branch 'Lightning-AI:master' into master
coreyjadams Feb 12, 2024
eab1302
Merge branch 'master' of github.com:coreyjadams/lightning
Feb 12, 2024
036cc1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2024
6936c24
Wrap ipex imports more carefully
Feb 12, 2024
a4bb6dc
Merge branch 'master' of github.com:coreyjadams/lightning
Feb 12, 2024
6ab7e1a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2024
eee4061
Fix typo of nncl to nccl
Feb 12, 2024
be66729
Add function typing and return signature to one XPU function. Add ov…
Feb 12, 2024
b5ab237
Fix missing import
Feb 12, 2024
d62c74f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2024
74f2e24
Merge branch 'master' of https://github.com/Lightning-AI/lightning
Apr 3, 2024
89865b5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2024
5820c75
Fix typing error in register_accelerator
Apr 3, 2024
3680639
Merge branch 'master' of github.com:coreyjadams/lightning
Apr 3, 2024
b3b2832
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2024
8a443c7
Update xpu.py
coreyjadams Apr 3, 2024
e87d366
Merge branch 'master' of https://github.com/Lightning-AI/lightning
Apr 24, 2024
11d9dbe
Merge branch 'Lightning-AI:master' into master
coreyjadams May 21, 2024
09da87b
Merge branch 'Lightning-AI:master' into master
coreyjadams May 26, 2024
1 change: 1 addition & 0 deletions src/lightning/fabric/accelerators/__init__.py
@@ -18,6 +18,7 @@
from lightning.fabric.accelerators.mps import MPSAccelerator # noqa: F401
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.accelerators.xla import XLAAccelerator # noqa: F401
from lightning.fabric.accelerators.xpu import XPUAccelerator # noqa: F401
from lightning.fabric.utilities.registry import _register_classes

ACCELERATOR_REGISTRY = _AcceleratorRegistry()
113 changes: 113 additions & 0 deletions src/lightning/fabric/accelerators/xpu.py
@@ -0,0 +1,113 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import lru_cache
from typing import Any, Dict, List, Union

import torch
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator


class XPUAccelerator(Accelerator):
"""Support for a Intel Discrete Graphics Cards 'XPU'."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
if not _IPEX_AVAILABLE:
raise ModuleNotFoundError(str(_IPEX_AVAILABLE))
super().__init__(*args, **kwargs)

@staticmethod
@override
def parse_devices(devices: Any) -> Any:
# Put parsing logic here how devices can be passed into the Trainer
# via the `devices` argument
from lightning.fabric.utilities.device_parser import _parse_gpu_ids

return _parse_gpu_ids(devices, include_xpu=True)

@staticmethod
@override
def get_parallel_devices(devices: Any) -> Any:
# Here, convert the device indices to actual device objects

return [torch.device("xpu", idx) for idx in devices]

@staticmethod
@override
def auto_device_count() -> int:
# Return a value for auto-device selection when `Trainer(devices="auto")`
return num_xpu_devices()

@staticmethod
@override
def is_available() -> bool:
# Carefully check before trying to import:
if _IPEX_AVAILABLE:
import intel_extension_for_pytorch as ipex

return ipex.xpu.is_available()
return False

@override
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
# Return optional device statistics for loggers
return {}

@override
def setup_device(self, device: torch.device) -> None:
pass

@override
def teardown(self) -> None:
pass

@classmethod
@override
def register_accelerators(cls, accelerator_registry: Any) -> None:
accelerator_registry.register(
"xpu",
cls,
description=cls.__name__,
)


_IPEX_AVAILABLE = RequirementCache("intel_extension_for_pytorch>=2.0", "intel_extension_for_pytorch")


@lru_cache(1)
def num_xpu_devices() -> int:
"""Returns the number of available XPU devices.

Unlike :func:`torch.xpu.device_count`, this function does its best not to create an XPU context for fork support,
if the platform allows it.

"""
if _IPEX_AVAILABLE:
import intel_extension_for_pytorch as ipex

return ipex.xpu.device_count()
return 0


def _get_all_visible_xpu_devices() -> List[int]:
"""Returns a list of all visible Intel XPU devices.

Devices masked by the environment variable ``ZE_AFFINITY_MASK`` won't be returned here. For example, assume you
have 8 physical GPUs. If ``ZE_AFFINITY_MASK="1,3,6"``, then this function will return the list ``[0, 1, 2]``
because these are the three visible GPUs after applying the mask ``ZE_AFFINITY_MASK``.

"""
return list(range(num_xpu_devices()))
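
For context, a minimal usage sketch of the accelerator added above. It is not part of the diff; it assumes intel_extension_for_pytorch is installed and at least one XPU device is visible, and uses the existing Fabric public API.

# Usage sketch, assuming an IPEX install with at least one visible XPU device.
import torch
import lightning as L

fabric = L.Fabric(accelerator="xpu", devices=1)  # "gpu" or "auto" would also resolve to XPU on such a machine
fabric.launch()

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)  # parameters end up on torch.device("xpu", 0)

x = torch.randn(4, 8, device=fabric.device)
loss = model(x).sum()
fabric.backward(loss)
optimizer.step()
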
2 changes: 1 addition & 1 deletion src/lightning/fabric/cli.py
@@ -36,7 +36,7 @@
_CLICK_AVAILABLE = RequirementCache("click")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")

_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "xpu")


def _get_supported_strategies() -> List[str]:
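
With "xpu" added to _SUPPORTED_ACCELERATORS, the Fabric CLI would accept it as an accelerator choice. A hedged sketch of what that enables (the script name is illustrative; the CLI flags follow the existing `fabric run` interface):

from lightning.fabric.cli import _SUPPORTED_ACCELERATORS

assert "xpu" in _SUPPORTED_ACCELERATORS
# so an invocation such as `fabric run train.py --accelerator=xpu --devices=2`
# becomes valid (train.py is a placeholder script name).
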
9 changes: 7 additions & 2 deletions src/lightning/fabric/connector.py
@@ -23,6 +23,7 @@
from lightning.fabric.accelerators.cuda import CUDAAccelerator
from lightning.fabric.accelerators.mps import MPSAccelerator
from lightning.fabric.accelerators.xla import XLAAccelerator
from lightning.fabric.accelerators.xpu import XPUAccelerator
from lightning.fabric.plugins import (
BitsandbytesPrecision,
CheckpointIO,
@@ -321,6 +322,8 @@ def _choose_auto_accelerator(self) -> str:
return "mps"
if CUDAAccelerator.is_available():
return "cuda"
if XPUAccelerator.is_available():
return "xpu"
return "cpu"

@staticmethod
@@ -329,6 +332,8 @@ def _choose_gpu_accelerator_backend() -> str:
return "mps"
if CUDAAccelerator.is_available():
return "cuda"
if XPUAccelerator.is_available():
return "xpu"
raise RuntimeError("No supported gpu backend found!")

def _set_parallel_devices_and_init_accelerator(self) -> None:
@@ -399,8 +404,8 @@ def _choose_strategy(self) -> Union[Strategy, str]:
if self._num_nodes_flag > 1:
return "ddp"
if len(self._parallel_devices) <= 1:
if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator, XPUAccelerator)) or (
isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps", "xpu")
):
device = _determine_root_gpu_device(self._parallel_devices)
else:
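
In practice, the two connector changes mean that on a machine whose only accelerators are Intel XPUs (so the CUDA and MPS checks return False), both accelerator="auto" and accelerator="gpu" resolve to the XPU backend. A sketch of that behavior under that assumption:

import lightning as L

# accelerator="auto": checked in order mps -> cuda -> xpu -> cpu
fabric_auto = L.Fabric(accelerator="auto", devices=1)
print(fabric_auto.device)  # device(type='xpu', index=0) on the machine assumed above

# accelerator="gpu": _choose_gpu_accelerator_backend() falls through to "xpu"
fabric_gpu = L.Fabric(accelerator="gpu", devices=1)
print(fabric_gpu.device)   # device(type='xpu', index=0)
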
2 changes: 1 addition & 1 deletion src/lightning/fabric/fabric.py
@@ -97,7 +97,7 @@ class Fabric:

Args:
accelerator: The hardware to run on. Possible choices are:
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
``"cpu"``, ``"cuda"``, ``"mps"``, ``"xpu"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
strategy: Strategy for how to run across multiple devices. Possible choices are:
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
9 changes: 8 additions & 1 deletion src/lightning/fabric/strategies/ddp.py
@@ -123,8 +123,15 @@ def setup_environment(self) -> None:
def setup_module(self, module: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self._determine_ddp_device_ids()
print(self.root_device)
print(self.root_device.type)
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
if self.root_device.type == "cuda":
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
elif self.root_device.type == "xpu":
ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext()
else:
ctx = nullcontext()
with ctx:
return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

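
The device-type dispatch added above follows the pattern from the linked PyTorch CUDA notes: enter a side stream while wrapping the module in DistributedDataParallel. A standalone sketch of the same dispatch, assuming torch.xpu is populated (by intel_extension_for_pytorch or an XPU-enabled PyTorch build); it mirrors the diff and is not part of it:

from contextlib import nullcontext
import torch

def _ddp_wrap_context(root_device: torch.device, device_ids):
    # Illustrative helper mirroring the logic added above.
    if device_ids is None:
        return nullcontext()
    if root_device.type == "cuda":
        return torch.cuda.stream(torch.cuda.Stream())
    if root_device.type == "xpu":  # requires torch.xpu (IPEX or an XPU build of PyTorch)
        return torch.xpu.stream(torch.xpu.Stream())
    return nullcontext()
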
2 changes: 2 additions & 0 deletions src/lightning/fabric/strategies/strategy.py
@@ -326,6 +326,8 @@ def load_checkpoint(

"""
torch.cuda.empty_cache()
if hasattr(torch, "xpu"):
torch.xpu.empty_cache()
checkpoint = self.checkpoint_io.load_checkpoint(path)
if not state:
return checkpoint
24 changes: 24 additions & 0 deletions src/lightning/fabric/utilities/device_dtype_mixin.py
@@ -44,6 +44,9 @@ def device(self) -> torch.device:
if device.type == "cuda" and device.index is None:
return torch.device(f"cuda:{torch.cuda.current_device()}")

if hasattr(torch, "xpu") and device.type == "xpu" and device.index is None:
return torch.device(f"xpu:{torch.xpu.current_device()}")

return device

@override
@@ -75,6 +78,27 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
_update_properties(self, device=device)
return super().cuda(device=device)

@override
def xpu(self, device: Optional[Union[torch.device, int]] = None) -> Self:
"""Moves all model parameters and buffers to the XPU GPU. This also makes associated parameters and buffers
different objects. So it should be called before constructing optimizer if the module will live on GPU while
being optimized.

Arguments:
device: If specified, all parameters will be copied to that device. If `None`, the current XPU device
index will be used.

Returns:
Module: self

"""
if device is None:
device = torch.device("xpu", torch.xpu.current_device())
elif isinstance(device, int):
device = torch.device("xpu", index=device)
_update_properties(self, device=device)
return super().xpu(device=device)

@override
def cpu(self) -> Self:
"""See :meth:`torch.nn.Module.cpu`."""
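
A hedged usage sketch of the new xpu() override; TinyModule is an illustrative class (not from the diff), and the calls assume an XPU-enabled PyTorch/IPEX environment:

import torch
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin

class TinyModule(_DeviceDtypeModuleMixin):  # illustrative module
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)

module = TinyModule().xpu(0)                      # parameters and buffers move to xpu:0
assert module.device == torch.device("xpu", 0)    # the tracked `device` property stays in sync
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)  # construct the optimizer after the move
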
36 changes: 25 additions & 11 deletions src/lightning/fabric/utilities/device_parser.py
@@ -49,6 +49,7 @@ def _parse_gpu_ids(
gpus: Optional[Union[int, str, List[int]]],
include_cuda: bool = False,
include_mps: bool = False,
include_xpu: bool = False,
) -> Optional[List[int]]:
"""Parses the GPU IDs given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer`.

@@ -60,6 +61,7 @@ def _parse_gpu_ids(
Any int N > 0 indicates that GPUs [0..N) should be used.
include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing.
include_mps: A boolean value indicating whether to include MPS devices for GPU parsing.
include_xpu: A boolean value indicating whether to include XPU devices for GPU parsing.

Returns:
A list of GPUs to be used or ``None`` if no GPUs were requested
@@ -69,7 +71,7 @@
If no GPUs are available but the value of gpus variable indicates request for GPUs

.. note::
``include_cuda`` and ``include_mps`` default to ``False`` so that you only
``include_cuda``, ``include_mps``, and ``include_xpu`` default to ``False`` so that you only
have to specify which device type to use and all other devices are not disabled.

"""
@@ -83,23 +85,26 @@
# We know the user requested GPUs therefore if some of the
# requested GPUs are not available an exception is thrown.
gpus = _normalize_parse_gpu_string_input(gpus)
gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps)
gpus = _normalize_parse_gpu_input_to_list(
gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
)
if not gpus:
raise MisconfigurationException("GPUs requested but none are available.")

if (
torch.distributed.is_available()
and torch.distributed.is_torchelastic_launched()
and len(gpus) != 1
and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1
and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu))
== 1
):
# Omit sanity check on torchelastic because by default it shows one visible GPU per process
return gpus

# Check that GPUs are unique. Duplicate GPUs are not supported by the backend.
_check_unique(gpus)

return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps)
return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)


def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
@@ -112,7 +117,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in
return int(s.strip())


def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]:
def _sanitize_gpu_ids(
gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
) -> List[int]:
"""Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the
GPUs is not available.

@@ -127,9 +134,11 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps:
If machine has fewer available GPUs than requested.

"""
if sum((include_cuda, include_mps)) == 0:
if sum((include_cuda, include_mps, include_xpu)) == 0:
raise ValueError("At least one gpu type should be specified!")
all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
all_available_gpus = _get_all_available_gpus(
include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
)
for gpu in gpus:
if gpu not in all_available_gpus:
raise MisconfigurationException(
@@ -139,7 +148,7 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps:


def _normalize_parse_gpu_input_to_list(
gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool
gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool, include_xpu: bool
) -> Optional[List[int]]:
assert gpus is not None
if isinstance(gpus, (MutableSequence, tuple)):
@@ -149,22 +158,27 @@ def _normalize_parse_gpu_input_to_list(
if not gpus: # gpus==0
return None
if gpus == -1:
return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)

return list(range(gpus))


def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]:
def _get_all_available_gpus(
include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
) -> List[int]:
"""
Returns:
A list of all available GPUs
"""

from lightning.fabric.accelerators.cuda import _get_all_visible_cuda_devices
from lightning.fabric.accelerators.mps import _get_all_available_mps_gpus
from lightning.fabric.accelerators.xpu import _get_all_visible_xpu_devices

cuda_gpus = _get_all_visible_cuda_devices() if include_cuda else []
mps_gpus = _get_all_available_mps_gpus() if include_mps else []
return cuda_gpus + mps_gpus
xpu_gpus = _get_all_visible_xpu_devices() if include_xpu else []
return cuda_gpus + mps_gpus + xpu_gpus


def _check_unique(device_ids: List[int]) -> None:
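
How the include_xpu flag plays out in _parse_gpu_ids, assuming a machine with four visible XPUs (the device counts are illustrative, not from the diff):

from lightning.fabric.utilities.device_parser import _parse_gpu_ids

_parse_gpu_ids(-1, include_xpu=True)      # -> [0, 1, 2, 3]  (all visible XPUs)
_parse_gpu_ids(2, include_xpu=True)       # -> [0, 1]
_parse_gpu_ids("1,3", include_xpu=True)   # -> [1, 3]
_parse_gpu_ids(0, include_xpu=True)       # -> None  (no GPUs requested)
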
10 changes: 9 additions & 1 deletion src/lightning/fabric/utilities/distributed.py
@@ -289,6 +289,10 @@ def _init_dist_connection(
os.environ["MASTER_ADDR"] = cluster_environment.main_address
os.environ["MASTER_PORT"] = str(cluster_environment.main_port)
log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")

if torch_distributed_backend.lower() == "ccl":
pass

torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)

# On rank=0 let everyone know training is starting
@@ -301,7 +305,11 @@


def _get_default_process_group_backend_for_device(device: torch.device) -> str:
return "nccl" if device.type == "cuda" else "gloo"
if device.type == "cuda":
return "nccl"
if device.type == "xpu":
return "ccl"
return "gloo"


class _DatasetSamplerWrapper(Dataset):
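
The change above maps each root device type to a default process-group backend. A small sketch of the expected mapping; the oneCCL note is an assumption about typical ccl setups, not something stated in the diff:

import torch
from lightning.fabric.utilities.distributed import _get_default_process_group_backend_for_device

_get_default_process_group_backend_for_device(torch.device("cuda", 0))  # -> "nccl"
_get_default_process_group_backend_for_device(torch.device("xpu", 0))   # -> "ccl"
_get_default_process_group_backend_for_device(torch.device("cpu"))      # -> "gloo"
# Initializing a "ccl" process group at runtime typically also requires
# oneccl_bindings_for_pytorch to be importable so the backend is registered.
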