Skip to content

Commit

Permalink
[HOTFIX] Restrict MappableRegister use + Type hint improvements (#499)
Browse files Browse the repository at this point in the history
* Fix MappableRegister fails

* Block SLM mask with mappable register

* Improve device type hints

* UT for failed serialization

* Refact test_sequence.py to use fixtures

* Tests for switch_device with mappable registers

* Add test for SLM mask block

* Bump to version 0.11.1

* Restrict `pasqal-sdk` to avoid breaking changes

* Adding clarifying comment
  • Loading branch information
HGSilveri committed Apr 11, 2023
1 parent c26d4d8 commit 43ce7f8
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 58 deletions.
2 changes: 1 addition & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.11.0
0.11.1
5 changes: 2 additions & 3 deletions pulser-core/pulser/json/abstract_repr/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from pulser.channels.base_channel import Channel
from pulser.channels.eom import RydbergBeam, RydbergEOM
from pulser.devices import Device, VirtualDevice
from pulser.devices._device_datacls import BaseDevice
from pulser.json.abstract_repr.signatures import (
BINARY_OPERATORS,
UNARY_OPERATORS,
Expand Down Expand Up @@ -414,15 +413,15 @@ def deserialize_abstract_sequence(obj_str: str) -> Sequence:
return seq


def deserialize_device(obj_str: str) -> BaseDevice:
def deserialize_device(obj_str: str) -> Device | VirtualDevice:
"""Deserialize a device from an abstract JSON object.
Args:
obj_str: the JSON string representing the device encoded
in the abstract JSON format.
Returns:
BaseDevice: The Pulser device.
The Pulser device.
Raises:
DeserializeDeviceError: Whenever the device deserialization
Expand Down
9 changes: 8 additions & 1 deletion pulser-core/pulser/json/abstract_repr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,14 @@ def convert_targets(
og_dim = target_array.ndim
if og_dim == 0:
target_array = target_array[np.newaxis]
indices = seq.register.find_indices(target_array.tolist())
try:
indices = seq.register.find_indices(target_array.tolist())
# RuntimeError raised when calling seq.register for a MappableRegister
except RuntimeError:
raise NotImplementedError(
"Serialization of sequences with local operations and"
" a mappable register is currently not supported."
)
return indices[0] if og_dim == 0 else indices

def get_kwarg_default(call_name: str, kwarg_name: str) -> Any:
Expand Down
30 changes: 23 additions & 7 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,16 @@
import warnings
from collections.abc import Iterable, Mapping
from sys import version_info
from typing import Any, Optional, Tuple, Union, cast, overload
from typing import (
Any,
Generic,
Optional,
Tuple,
TypeVar,
Union,
cast,
overload,
)

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -61,11 +70,12 @@
" `pip install typing-extensions`."
)

DeviceType = TypeVar("DeviceType", bound=BaseDevice)

PROTOCOLS = Literal["min-delay", "no-delay", "wait-for-all"]


class Sequence:
class Sequence(Generic[DeviceType]):
"""A sequence of operations on a device.
A sequence is composed by
Expand Down Expand Up @@ -103,7 +113,7 @@ class Sequence:
def __init__(
self,
register: Union[BaseRegister, MappableRegister],
device: BaseDevice,
device: DeviceType,
):
"""Initializes a new pulse sequence."""
if not isinstance(device, BaseDevice):
Expand All @@ -119,7 +129,7 @@ def __init__(
device.validate_register(register)

self._register: Union[BaseRegister, MappableRegister] = register
self._device: BaseDevice = device
self._device = device
self._in_xy: bool = False
self._mag_field: Optional[tuple[float, float, float]] = None
self._calls: list[_Call] = [
Expand Down Expand Up @@ -161,7 +171,7 @@ def qubit_info(self) -> dict[QubitId, np.ndarray]:
return cast(BaseRegister, self._register).qubits

@property
def device(self) -> BaseDevice:
def device(self) -> DeviceType:
"""Device that the sequence is using."""
return self._device

Expand Down Expand Up @@ -382,6 +392,12 @@ def config_slm_mask(self, qubits: Iterable[QubitId]) -> None:
raise ValueError(
f"The '{self._device}' device does not have an SLM mask."
)

if self.is_register_mappable():
raise RuntimeError(
"The SLM mask can't be combined with a mappable register."
)

try:
targets = set(qubits)
except TypeError:
Expand All @@ -400,7 +416,7 @@ def config_slm_mask(self, qubits: Iterable[QubitId]) -> None:
self._slm_mask_targets = targets

def switch_device(
self, new_device: BaseDevice, strict: bool = False
self, new_device: DeviceType, strict: bool = False
) -> Sequence:
"""Switch the device of a sequence.
Expand Down Expand Up @@ -527,7 +543,7 @@ def check_retarget(ch_obj: Channel) -> bool:
else:
raise TypeError(ch_match_err)
# Initialize the new sequence (works for Sequence subclasses too)
new_seq = type(self)(register=self.register, device=new_device)
new_seq = type(self)(register=self._register, device=new_device)

# Copy the variables to the new sequence
new_seq._variables = self.declared_variables
Expand Down
2 changes: 1 addition & 1 deletion pulser-pasqal/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pasqal-sdk ~= 0.1.12
pasqal-sdk >= 0.1.12, < 0.1.15
18 changes: 18 additions & 0 deletions tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,24 @@ def test_default_basis(
"basis", "ground-rydberg"
)

@pytest.mark.xfail(reason="Can't get index of mappable register qubits.")
@pytest.mark.parametrize(
"op,args",
[
("target", ("q0", "raman_local")),
("phase_shift", (1, "q0", "q1")),
],
)
def test_mappable_reg_with_local_ops(
self, op, args, triangular_lattice: TriangularLatticeLayout
):
mappable_reg = triangular_lattice.make_mappable_register(5)
seq = Sequence(mappable_reg, MockDevice)

seq.declare_channel("raman_local", "raman_local")
getattr(seq, op)(*args)
seq.to_abstract_repr()


def _get_serialized_seq(
operations: list[dict] = None,
Expand Down

0 comments on commit 43ce7f8

Please sign in to comment.