Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HOTFIX] Restrict MappableRegister use + Type hint improvements #499

Merged
merged 10 commits into from
Apr 11, 2023
2 changes: 1 addition & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.11.0
0.11.1
5 changes: 2 additions & 3 deletions pulser-core/pulser/json/abstract_repr/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from pulser.channels.base_channel import Channel
from pulser.channels.eom import RydbergBeam, RydbergEOM
from pulser.devices import Device, VirtualDevice
from pulser.devices._device_datacls import BaseDevice
from pulser.json.abstract_repr.signatures import (
BINARY_OPERATORS,
UNARY_OPERATORS,
Expand Down Expand Up @@ -414,15 +413,15 @@ def deserialize_abstract_sequence(obj_str: str) -> Sequence:
return seq


def deserialize_device(obj_str: str) -> BaseDevice:
def deserialize_device(obj_str: str) -> Device | VirtualDevice:
"""Deserialize a device from an abstract JSON object.

Args:
obj_str: the JSON string representing the device encoded
in the abstract JSON format.

Returns:
BaseDevice: The Pulser device.
The Pulser device.

Raises:
DeserializeDeviceError: Whenever the device deserialization
Expand Down
8 changes: 7 additions & 1 deletion pulser-core/pulser/json/abstract_repr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,13 @@ def convert_targets(
og_dim = target_array.ndim
if og_dim == 0:
target_array = target_array[np.newaxis]
indices = seq.register.find_indices(target_array.tolist())
try:
indices = seq.register.find_indices(target_array.tolist())
except RuntimeError:
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError(
"Serialization of sequences with local operations and"
" a mappable register is currently not supported."
)
return indices[0] if og_dim == 0 else indices

def get_kwarg_default(call_name: str, kwarg_name: str) -> Any:
Expand Down
30 changes: 23 additions & 7 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,16 @@
import warnings
from collections.abc import Iterable, Mapping
from sys import version_info
from typing import Any, Optional, Tuple, Union, cast, overload
from typing import (
Any,
Generic,
Optional,
Tuple,
TypeVar,
Union,
cast,
overload,
)

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -61,11 +70,12 @@
" `pip install typing-extensions`."
)

DeviceType = TypeVar("DeviceType", bound=BaseDevice)
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved

PROTOCOLS = Literal["min-delay", "no-delay", "wait-for-all"]


class Sequence:
class Sequence(Generic[DeviceType]):
"""A sequence of operations on a device.

A sequence is composed by
Expand Down Expand Up @@ -103,7 +113,7 @@ class Sequence:
def __init__(
self,
register: Union[BaseRegister, MappableRegister],
device: BaseDevice,
device: DeviceType,
):
"""Initializes a new pulse sequence."""
if not isinstance(device, BaseDevice):
Expand All @@ -119,7 +129,7 @@ def __init__(
device.validate_register(register)

self._register: Union[BaseRegister, MappableRegister] = register
self._device: BaseDevice = device
self._device = device
self._in_xy: bool = False
self._mag_field: Optional[tuple[float, float, float]] = None
self._calls: list[_Call] = [
Expand Down Expand Up @@ -161,7 +171,7 @@ def qubit_info(self) -> dict[QubitId, np.ndarray]:
return cast(BaseRegister, self._register).qubits

@property
def device(self) -> BaseDevice:
def device(self) -> DeviceType:
"""Device that the sequence is using."""
return self._device

Expand Down Expand Up @@ -382,6 +392,12 @@ def config_slm_mask(self, qubits: Iterable[QubitId]) -> None:
raise ValueError(
f"The '{self._device}' device does not have an SLM mask."
)

if self.is_register_mappable():
raise RuntimeError(
"The SLM mask can't be combined with a mappable register."
)

try:
targets = set(qubits)
except TypeError:
Expand All @@ -400,7 +416,7 @@ def config_slm_mask(self, qubits: Iterable[QubitId]) -> None:
self._slm_mask_targets = targets

def switch_device(
self, new_device: BaseDevice, strict: bool = False
self, new_device: DeviceType, strict: bool = False
) -> Sequence:
"""Switch the device of a sequence.

Expand Down Expand Up @@ -527,7 +543,7 @@ def check_retarget(ch_obj: Channel) -> bool:
else:
raise TypeError(ch_match_err)
# Initialize the new sequence (works for Sequence subclasses too)
new_seq = type(self)(register=self.register, device=new_device)
new_seq = type(self)(register=self._register, device=new_device)

# Copy the variables to the new sequence
new_seq._variables = self.declared_variables
Expand Down
2 changes: 1 addition & 1 deletion pulser-pasqal/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pasqal-sdk ~= 0.1.12
pasqal-sdk >= 0.1.12, < 0.1.15
18 changes: 18 additions & 0 deletions tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,24 @@ def test_default_basis(
"basis", "ground-rydberg"
)

@pytest.mark.xfail(reason="Can't get index of mappable register qubits.")
@pytest.mark.parametrize(
"op,args",
[
("target", ("q0", "raman_local")),
("phase_shift", (1, "q0", "q1")),
],
)
def test_mappable_reg_with_local_ops(
self, op, args, triangular_lattice: TriangularLatticeLayout
):
mappable_reg = triangular_lattice.make_mappable_register(5)
seq = Sequence(mappable_reg, MockDevice)

seq.declare_channel("raman_local", "raman_local")
getattr(seq, op)(*args)
seq.to_abstract_repr()


def _get_serialized_seq(
operations: list[dict] = None,
Expand Down