Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 119 additions & 2 deletions oqpy/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,17 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING, Iterable, Iterator, Optional, TypeVar, overload
from typing import (
TYPE_CHECKING,
Any,
ContextManager,
Iterable,
Iterator,
Literal,
Optional,
TypeVar,
overload,
)

from openpulse import ast

Expand All @@ -38,7 +48,7 @@
from oqpy.program import Program


__all__ = ["If", "Else", "ForIn", "While", "Range"]
__all__ = ["If", "Else", "ForIn", "While", "Range", "Switch", "Case", "Default"]


@contextlib.contextmanager
Expand Down Expand Up @@ -176,3 +186,110 @@ def While(program: Program, condition: OQPyExpression) -> Iterator[None]:
yield
state = program._pop()
program._add_statement(ast.WhileLoop(to_ast(program, condition), state.body))


class Switch(ContextManager[None]):
"""Context manager for switch statement control flow.

.. code-block:: python

selector = IntVar(0)
with Switch(program, selector):
with Case(program, 0):
program.increment(result, 1)
with Case(program, 1, 2): # Multiple values in one case
program.increment(result, 2)
with Default(program):
program.increment(result, 100)

"""

def __init__(self, program: "Program", target: OQPyExpression):
self.program = program
self.target = target
self.cases: list[tuple[list[ast.Expression], list[ast.Statement]]] = []
self.default: list[ast.Statement] | None = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense for Oqpy in general to allow an empty default, to give full flexibility as per the OpenQASM spec. However, the OpenQASM ast implementation notes

# Note that `None` is quite different to `[]` in this case; the latter is
# an explicitly empty body, whereas the absence of a default might mean
# that the switch is inexhaustive, and a linter might want to complain.

Do we want to add an optional flag to this class to toggle whether a None default is allowed? Or should that be handled by oqpy's consumers?

Two other options on the table are defaulting to an empty block (perhaps risky) or raising an error/warning if no default is given (perhaps annoying)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None default seems like sane behavior since the produced openqasm is closest to what was written. I'm not sure openqasm needs to be the one enforcing default behavior.


def __enter__(self) -> None:
self.program._push()
self.program._state.active_switch = self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> Literal[False]:
# Pop the switch context state
self.program._pop()
if exc_type is not None:
# Don't add statement if an exception occurred; propagate the exception
return False
# Build the case tuples as (list of expressions, CompoundStatement)
case_tuples = [(values, ast.CompoundStatement(body)) for values, body in self.cases]
default_stmt = ast.CompoundStatement(self.default) if self.default else None
stmt = ast.SwitchStatement(
to_ast(self.program, self.target),
case_tuples,
default_stmt,
)
self.program._add_statement(stmt)
# Return False to indicate exceptions should not be suppressed
return False


@contextlib.contextmanager
def Case(program: "Program", *values: AstConvertible) -> Iterator[None]:
"""Context manager for a case within a switch statement.

Must be used inside a Switch context. Multiple values can be provided
for a single case block.

.. code-block:: python

with Switch(program, selector):
with Case(program, 0):
# Handle case 0
program.increment(result, 1)
with Case(program, 1, 2):
# Handle cases 1 and 2
program.increment(result, 2)

"""
if not values:
raise ValueError("Case requires at least one value")
switch = program._state.active_switch
if switch is None:
raise RuntimeError("Case must be used inside a Switch context")
program._push()
yield
state = program._pop()
case_values = [to_ast(program, v) for v in values]
switch.cases.append((case_values, state.body))


@contextlib.contextmanager
def Default(program: "Program") -> Iterator[None]:
"""Context manager for the default case within a switch statement.

Must be used inside a Switch context.

.. code-block:: python

with Switch(program, selector):
with Case(program, 0):
program.increment(result, 1)
with Default(program):
# Handle all other cases
program.increment(result, 100)

"""
switch = program._state.active_switch
if switch is None:
raise RuntimeError("Default must be used inside a Switch context")
if switch.default is not None:
raise RuntimeError("Switch statement can only have one default case")
program._push()
yield
state = program._pop()
switch.default = state.body
29 changes: 25 additions & 4 deletions oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import warnings
from copy import deepcopy
from typing import Any, Hashable, Iterable, Iterator, Optional
from typing import TYPE_CHECKING, Any, Hashable, Iterable, Iterator, Optional

from openpulse import ast
from openpulse.printer import dumps
Expand All @@ -44,6 +44,9 @@
from oqpy.pulse import FrameVar, PortVar, WaveformVar
from oqpy.timing import convert_duration_to_float, convert_float_to_duration

if TYPE_CHECKING:
from oqpy.control_flow import Switch

__all__ = ["Program"]


Expand All @@ -59,6 +62,7 @@ def __init__(self) -> None:
self.body: list[ast.Statement | ast.Pragma] = []
self.if_clause: Optional[ast.BranchingStatement] = None
self.annotations: list[ast.Annotation] = []
self.active_switch: Optional["Switch"] = None # Set when inside a switch context

def add_if_clause(self, condition: ast.Expression, if_clause: list[ast.Statement]) -> None:
if_clause_annotations, self.annotations = self.annotations, []
Expand All @@ -82,6 +86,10 @@ def add_statement(self, stmt: ast.Statement | ast.Pragma) -> None:
# it seems to conflict with the definition of ast.Program.
# Issue raised in https://github.com/openqasm/openqasm/issues/468
assert isinstance(stmt, (ast.Statement, ast.Pragma))
if self.active_switch is not None:
raise RuntimeError(
"Statements inside a Switch block must be within a Case or Default context"
)
if isinstance(stmt, ast.Statement) and self.annotations:
stmt.annotations = self.annotations + list(stmt.annotations)
self.annotations = []
Expand Down Expand Up @@ -133,7 +141,9 @@ def __iadd__(self, other: Program) -> Program:
self.defcals.update(other.defcals)
for name, subroutine_stmt in other.subroutines.items():
self._add_subroutine(
name, subroutine_stmt, needs_declaration=name not in other.declared_subroutines
name,
subroutine_stmt,
needs_declaration=name not in other.declared_subroutines,
)
for name, gate_stmt in other.gates.items():
self._add_gate(name, gate_stmt, needs_declaration=name not in other.declared_gates)
Expand Down Expand Up @@ -418,7 +428,9 @@ def declare(
return self

def delay(
self, time: AstConvertible, qubits_or_frames: AstConvertible | Iterable[AstConvertible] = ()
self,
time: AstConvertible,
qubits_or_frames: AstConvertible | Iterable[AstConvertible] = (),
Comment on lines +431 to +433
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like your formatter is set to 80 lines, but pyproject.toml specifies 100, perhaps we can revert this and the below changes

) -> Program:
"""Apply a delay to a set of qubits or frames."""
if not isinstance(qubits_or_frames, Iterable):
Expand Down Expand Up @@ -608,7 +620,9 @@ def reset(self, qubit: quantum_types.Qubit) -> Program:
return self

def measure(
self, qubit: quantum_types.Qubit, output_location: classical_types.BitVar | None = None
self,
qubit: quantum_types.Qubit,
output_location: classical_types.BitVar | None = None,
) -> Program:
"""Measure a particular qubit.

Expand Down Expand Up @@ -709,6 +723,13 @@ def visit_BranchingStatement(self, node: ast.BranchingStatement, context: None =
node.else_block = self.process_statement_list(node.else_block)
self.generic_visit(node, context)

def visit_SwitchStatement(self, node: ast.SwitchStatement, context: None = None) -> None:
for _, case_block in node.cases:
case_block.statements = self.process_statement_list(case_block.statements)
if node.default is not None:
node.default.statements = self.process_statement_list(node.default.statements)
self.generic_visit(node, context)

def visit_CalibrationStatement(
self, node: ast.CalibrationStatement, context: None = None
) -> None:
Expand Down
Loading
Loading