Source code for braket.default_simulator.openqasm.interpreter

# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from collections.abc import Iterable
from copy import deepcopy
from dataclasses import fields
from functools import singledispatchmethod
from logging import Logger, getLogger
from typing import Optional, Union

import numpy as np
from braket.ir.openqasm.program_v1 import io_type
from sympy import Symbol

from ._helpers.arrays import (
    convert_range_def_to_range,
    create_empty_array,
    get_elements,
    get_type_width,
)
from ._helpers.casting import (
    LiteralType,
    cast_to,
    get_identifier_name,
    is_literal,
    wrap_value_into_literal,
)
from ._helpers.functions import (
    builtin_constants,
    builtin_functions,
    evaluate_binary_expression,
    evaluate_unary_expression,
    get_operator_of_assignment_operator,
)
from ._helpers.quantum import (
    convert_phase_to_gate,
    get_ctrl_modifiers,
    get_pow_modifiers,
    invert_phase,
    is_controlled,
    is_inverted,
    modify_body,
)
from .circuit import Circuit
from .parser.openqasm_ast import (
    AccessControl,
    ArrayLiteral,
    ArrayReferenceType,
    ArrayType,
    AssignmentOperator,
    BinaryExpression,
    BitstringLiteral,
    BitType,
    BooleanLiteral,
    BranchingStatement,
    Cast,
    ClassicalArgument,
    ClassicalAssignment,
    ClassicalDeclaration,
    ConstantDeclaration,
    DiscreteSet,
    FloatLiteral,
    ForInLoop,
    FunctionCall,
    GateModifierName,
    Identifier,
    Include,
    IndexedIdentifier,
    IndexExpression,
    IntegerLiteral,
    IODeclaration,
    IOKeyword,
    Pragma,
    Program,
    QASMNode,
    QuantumGate,
    QuantumGateDefinition,
    QuantumGateModifier,
    QuantumMeasurement,
    QuantumMeasurementStatement,
    QuantumPhase,
    QuantumReset,
    QuantumStatement,
    QubitDeclaration,
    RangeDefinition,
    ReturnStatement,
    SizeOf,
    SubroutineDefinition,
    SymbolLiteral,
    UnaryExpression,
    WhileLoop,
)
from .parser.openqasm_parser import parse
from .program_context import AbstractProgramContext, ProgramContext


[docs] class Interpreter: """ The interpreter is responsible for visiting the AST of an OpenQASM program, as created by the parser, and building a braket.default_simulator.openqasm.circuit.Circuit to hand off to a simulator e.g. braket.default_simulator.state_vector_simulator.StateVectorSimulator. The interpreter keeps track of all state using a ProgramContext object. The main entry point is build_circuit(), which returns the built circuit. An alternative entry poitn, run() returns the ProgramContext object, which can be used for debugging or other customizability. """ def __init__( self, context: Optional[AbstractProgramContext] = None, logger: Optional[Logger] = None ): # context keeps track of all state self.context = context or ProgramContext() self.logger = logger or getLogger(__name__) self._uses_advanced_language_features = False
[docs] def build_circuit( self, source: str, inputs: Optional[dict[str, io_type]] = None, is_file: bool = False ) -> Circuit: """Interpret an OpenQASM program and build a Circuit IR.""" return self.run(source, inputs, is_file).circuit
[docs] def run( self, source: str, inputs: Optional[dict[str, io_type]] = None, is_file: bool = False ) -> ProgramContext: """Interpret an OpenQASM program and return the program state""" if inputs: self.context.load_inputs(inputs) if is_file: with open(source, encoding="utf-8", mode="r") as f: source = f.read() program = parse(source) self._uses_advanced_language_features = False self.visit(program) if self._uses_advanced_language_features: self.logger.warning( "This program uses OpenQASM language features that may " "not be supported on QPUs or on-demand simulators." ) return self.context
[docs] @singledispatchmethod def visit(self, node: Union[QASMNode, list[QASMNode]]) -> Optional[QASMNode]: """Generic visit function for an AST node""" if node is None: return if not isinstance(node, QASMNode): return node for field in fields(node): value = getattr(node, field.name) setattr(node, field.name, self.visit(value)) return node
@visit.register def _(self, node_list: list) -> list[QASMNode]: """Generic visit function for a list of AST nodes""" return [n for n in [self.visit(node) for node in node_list] if n is not None] @visit.register def _(self, node: Program) -> None: self.visit(node.statements) @visit.register def _(self, node: ClassicalDeclaration) -> None: node_type = self.visit(node.type) if node.init_expression is not None: init_expression = self.visit(node.init_expression) init_value = cast_to(node.type, init_expression) elif isinstance(node_type, ArrayType): init_value = create_empty_array(node_type.dimensions) elif isinstance(node_type, BitType) and node_type.size: init_value = create_empty_array([node_type.size]) else: init_value = None self.context.declare_variable(node.identifier.name, node_type, init_value) @visit.register def _(self, node: IODeclaration) -> None: if node.io_identifier == IOKeyword.output: raise NotImplementedError("Output not supported") else: # IOKeyword.input: if node.identifier.name not in self.context.inputs: # previously raised a NameError init_value = wrap_value_into_literal(Symbol(node.identifier.name)) node_type = SymbolLiteral else: init_value = wrap_value_into_literal(self.context.inputs[node.identifier.name]) node_type = node.type declaration = ClassicalDeclaration(node_type, node.identifier, init_value) self.visit(declaration) @visit.register def _(self, node: ConstantDeclaration) -> None: self._uses_advanced_language_features = True node_type = self.visit(node.type) init_expression = self.visit(node.init_expression) init_value = cast_to(node.type, init_expression) self.context.declare_variable(node.identifier.name, node_type, init_value, const=True) @visit.register def _(self, node: BinaryExpression) -> Union[BinaryExpression, LiteralType]: lhs = self.visit(node.lhs) rhs = self.visit(node.rhs) if is_literal(lhs) and is_literal(rhs): return evaluate_binary_expression(lhs, rhs, node.op) else: return BinaryExpression(node.op, lhs, rhs) @visit.register def _(self, node: UnaryExpression) -> Union[UnaryExpression, LiteralType]: expression = self.visit(node.expression) if is_literal(expression): return evaluate_unary_expression(expression, node.op) else: return UnaryExpression(node.op, expression) @visit.register def _(self, node: Cast) -> LiteralType: return cast_to(node.type, self.visit(node.argument)) @visit.register(BooleanLiteral) @visit.register(IntegerLiteral) @visit.register(FloatLiteral) def _(self, node: LiteralType) -> LiteralType: return node @visit.register def _(self, node: Identifier) -> LiteralType: if node.name.startswith("$"): return node if node.name in builtin_constants: return builtin_constants[node.name] if not self.context.is_initialized(node.name): raise NameError(f"Identifier '{node.name}' is not initialized.") return self.context.get_value_by_identifier(node) @visit.register def _(self, node: QubitDeclaration) -> None: size = self.visit(node.size).value if node.size else 1 self.context.add_qubits(node.qubit.name, size) @visit.register def _(self, node: QuantumReset) -> None: raise NotImplementedError("Reset not supported") @visit.register def _(self, node: IndexedIdentifier) -> Union[IndexedIdentifier, LiteralType]: """Returns an identifier for qubits, value for classical identifier""" name = node.name indices = [] for index in node.indices: if isinstance(index, DiscreteSet): self._uses_advanced_language_features = True indices.append(index) else: for element in index: if isinstance(element, RangeDefinition): self._uses_advanced_language_features = True element = self.visit(element) indices.append([element]) updated = IndexedIdentifier(name, indices) if name.name not in self.context.qubit_mapping: return self.context.get_value_by_identifier(updated) return updated @visit.register def _(self, node: RangeDefinition) -> RangeDefinition: self._uses_advanced_language_features = True start = self.visit(node.start) if node.start else None end = self.visit(node.end) step = self.visit(node.step) if node.step else None return RangeDefinition(start, end, step) @visit.register def _(self, node: IndexExpression) -> Union[IndexedIdentifier, ArrayLiteral]: """Returns an identifier for qubits, values for classical identifier""" type_width = None index = self.visit(node.index) if isinstance(node.collection, Identifier): # indexed QuantumArgument if isinstance(self.context.get_type(node.collection.name), type(Identifier)): return IndexedIdentifier(node.collection, [index]) var_type = self.context.get_type(get_identifier_name(node.collection)) type_width = get_type_width(var_type) collection = self.visit(node.collection) return get_elements(collection, index, type_width) @visit.register def _(self, node: QuantumGateDefinition) -> None: self._uses_advanced_language_features = True with self.context.enter_scope(): for qubit in node.qubits: self.context.declare_qubit_alias(qubit.name, qubit) for param in node.arguments: self.context.declare_variable(param.name, Identifier, param) node.body = self.inline_gate_def_body(node.body) self.context.add_gate(node.name.name, node)
[docs] def inline_gate_def_body(self, body: list[QuantumStatement]) -> list[QuantumStatement]: inlined_body = [] for statement in body: if isinstance(statement, QuantumPhase): statement.argument = self.visit(statement.argument) statement.modifiers = self.visit(statement.modifiers) if is_inverted(statement): statement = invert_phase(statement) if is_controlled(statement): statement = convert_phase_to_gate(statement) # statement is a quantum phase instruction else: inlined_body.append(statement) # this includes converted phase instructions if isinstance(statement, QuantumGate): gate_name = statement.name.name statement.arguments = self.visit(statement.arguments) statement.modifiers = self.visit(statement.modifiers) statement.qubits = self.visit(statement.qubits) if self.context.is_builtin_gate(gate_name): inlined_body.append(statement) else: with self.context.enter_scope(): gate_def = self.context.get_gate_definition(gate_name) ctrl_modifiers = get_ctrl_modifiers(statement.modifiers) pow_modifiers = get_pow_modifiers(statement.modifiers) num_ctrl = sum(mod.argument.value for mod in ctrl_modifiers) ctrl_qubits = statement.qubits[:num_ctrl] gate_qubits = statement.qubits[num_ctrl:] for qubit_called, qubit_defined in zip(gate_qubits, gate_def.qubits): self.context.declare_qubit_alias(qubit_defined.name, qubit_called) for param_called, param_defined in zip( statement.arguments, gate_def.arguments ): self.context.declare_variable( param_defined.name, Identifier, param_called ) inlined_copy = self.inline_gate_def_body(deepcopy(gate_def.body)) inlined_body += modify_body( inlined_copy, is_inverted(statement), ctrl_modifiers, ctrl_qubits, pow_modifiers, ) return inlined_body
@visit.register def _(self, node: QuantumGate) -> None: gate_name = node.name.name arguments = self.visit(node.arguments) modifiers = self.visit(node.modifiers) if self.context.in_global_scope and modifiers: self._uses_advanced_language_features = True qubits = [] for qubit in node.qubits: if isinstance(qubit, Identifier): qubits.append(self.visit(qubit)) else: # IndexedIdentifier dereffed_name = self.visit(qubit.name) simplified_indices = self.visit(qubit.indices) qubits.append(IndexedIdentifier(dereffed_name, simplified_indices)) qubit_lengths = np.array( [self.context.qubit_mapping.get_qubit_size(qubit) for qubit in qubits] ) register_lengths = qubit_lengths[qubit_lengths > 1] if register_lengths.size: reg_length = register_lengths[0] if not np.all(register_lengths == reg_length): raise ValueError("Qubit registers must all be the same length.") for i in range(reg_length): indexed_qubits = deepcopy(qubits) for j, qubit_length in enumerate(qubit_lengths): if qubit_length > 1: if isinstance(indexed_qubits[j], Identifier): indexed_qubits[j] = IndexedIdentifier( indexed_qubits[j], [[IntegerLiteral(i)]] ) else: indexed_qubits[j].indices.append([IntegerLiteral(i)]) gate_call = QuantumGate( modifiers, node.name, arguments, indexed_qubits, ) self.visit(gate_call) return if self.context.is_builtin_gate(gate_name): # to simplify indices qubits = self.visit(qubits) self.handle_builtin_gate( gate_name, arguments, qubits, modifiers, ) else: with self.context.enter_scope(): gate_def = self.context.get_gate_definition(gate_name) ctrl_modifiers = get_ctrl_modifiers(modifiers) pow_modifiers = get_pow_modifiers(modifiers) num_ctrl = sum(mod.argument.value for mod in ctrl_modifiers) ctrl_qubits = qubits[:num_ctrl] gate_qubits = qubits[num_ctrl:] modified_gate_body = modify_body( deepcopy(gate_def.body), is_inverted(node), ctrl_modifiers, ctrl_qubits, pow_modifiers, ) for qubit_called, qubit_defined in zip(gate_qubits, gate_def.qubits): self.context.declare_qubit_alias(qubit_defined.name, qubit_called) for param_called, param_defined in zip(arguments, gate_def.arguments): self.context.declare_variable(param_defined.name, FloatLiteral, param_called) for statement in deepcopy(modified_gate_body): if isinstance(statement, QuantumGate): self.visit(statement) else: # QuantumPhase phase = self.visit(statement.argument) self.handle_phase(phase, qubits) @visit.register def _(self, node: QuantumPhase) -> None: node.argument = self.visit(node.argument) node.modifiers = self.visit(node.modifiers) if is_inverted(node): node = invert_phase(node) if is_controlled(node): node = convert_phase_to_gate(node) self.visit(node) else: self.handle_phase(node.argument) @visit.register def _(self, node: QuantumGateModifier) -> QuantumGateModifier: if node.modifier in (GateModifierName.ctrl, GateModifierName.negctrl): if node.argument is None: node.argument = IntegerLiteral(1) else: node.argument = self.visit(node.argument) elif node.modifier == GateModifierName.pow: node.argument = self.visit(node.argument) return node @visit.register def _(self, node: QuantumMeasurement) -> None: qubits = self.context.get_qubits(self.visit(node.qubit)) self.context.add_measure(qubits) @visit.register def _(self, node: QuantumMeasurementStatement) -> None: """The measure is performed but the assignment is ignored""" self.visit(node.measure) @visit.register def _(self, node: ClassicalAssignment) -> None: lvalue_name = get_identifier_name(node.lvalue) if self.context.get_const(lvalue_name): raise TypeError(f"Cannot update const value {lvalue_name}") if node.op == getattr(AssignmentOperator, "="): rvalue = self.visit(node.rvalue) else: op = get_operator_of_assignment_operator(node.op) binary_expression = BinaryExpression(op, node.lvalue, node.rvalue) rvalue = self.visit(binary_expression) lvalue = node.lvalue if isinstance(lvalue, IndexedIdentifier): lvalue.indices = self.visit(lvalue.indices) elif isinstance(rvalue, SymbolLiteral): pass else: rvalue = cast_to(self.context.get_type(lvalue.name), rvalue) self.context.update_value(lvalue, rvalue) @visit.register def _(self, node: BitstringLiteral) -> ArrayLiteral: return cast_to(BitType(IntegerLiteral(node.width)), node) @visit.register def _(self, node: BranchingStatement) -> None: self._uses_advanced_language_features = True condition = cast_to(BooleanLiteral, self.visit(node.condition)) block = node.if_block if condition.value else node.else_block for statement in block: self.visit(statement) @visit.register def _(self, node: ForInLoop) -> None: self._uses_advanced_language_features = True index = self.visit(node.set_declaration) if isinstance(index, RangeDefinition): index_values = [IntegerLiteral(x) for x in convert_range_def_to_range(index)] # DiscreteSet else: index_values = index.values block = node.block for i in index_values: block_copy = deepcopy(block) with self.context.enter_scope(): self.context.declare_variable(node.identifier.name, node.type, i) self.visit(block_copy) @visit.register def _(self, node: WhileLoop) -> None: self._uses_advanced_language_features = True while cast_to(BooleanLiteral, self.visit(deepcopy(node.while_condition))).value: self.visit(deepcopy(node.block)) @visit.register def _(self, node: Include) -> None: self._uses_advanced_language_features = True with open(node.filename, encoding="utf-8", mode="r") as f: included = f.read() parsed = parse(included) self.visit(parsed) @visit.register def _(self, node: Pragma) -> None: parsed = self.context.parse_pragma(node.command) if node.command.startswith("braket result"): if not parsed: raise TypeError(f"Result type {node.command.split()[2]} is not supported.") self.context.add_result(parsed) elif node.command.startswith("braket unitary"): unitary, target = parsed self.context.add_custom_unitary(unitary, target) elif node.command.startswith("braket noise kraus"): matrices, target = parsed self.context.add_kraus_instruction(matrices, target) elif node.command.startswith("braket noise"): noise_instruction, target, probabilities = parsed self.context.add_noise_instruction(noise_instruction, target, probabilities) elif node.command.startswith("braket verbatim"): pass else: raise NotImplementedError(f"Pragma '{node.command}' is not supported") @visit.register def _(self, node: SubroutineDefinition) -> None: self._uses_advanced_language_features = True # todo: explicitly handle references to existing variables # either by throwing an error or evaluating the closure. # currently, the implementation does not consider the values # of current-scope variables used inside of the function # at the time of function definition, and relies on their values # at the time of execution. This is incorrect, but currently an # edge case and known limitation. More effort can be invested here # if this functionality is prioritized. self.context.add_subroutine(node.name.name, node) @visit.register def _(self, node: FunctionCall) -> Optional[QASMNode]: self._uses_advanced_language_features = True function_name = node.name.name arguments = self.visit(node.arguments) if function_name in builtin_functions: return builtin_functions[function_name](*arguments) function_def = self.context.get_subroutine_definition(function_name) with self.context.enter_scope(): for arg_passed, arg_defined in zip(arguments, function_def.arguments): if isinstance(arg_defined, ClassicalArgument): arg_name = arg_defined.name.name arg_type = arg_defined.type arg_const = arg_defined.access == AccessControl.const arg_value = deepcopy(arg_passed) self.context.declare_variable(arg_name, arg_type, arg_value, arg_const) else: # QuantumArgument qubit_name = get_identifier_name(arg_defined.name) self.context.declare_qubit_alias(qubit_name, arg_passed) return_value = None for statement in deepcopy(function_def.body): visited = self.visit(statement) if isinstance(statement, ReturnStatement): return_value = visited break for arg_passed, arg_defined in zip(node.arguments, function_def.arguments): if isinstance(arg_defined, ClassicalArgument): if isinstance(arg_defined.type, ArrayReferenceType): if isinstance(arg_passed, IndexExpression): identifier = IndexedIdentifier( arg_passed.collection, [arg_passed.index] ) identifier.indices = self.visit(identifier.indices) else: identifier = arg_passed reference_value = self.context.get_value(arg_defined.name.name) self.context.update_value(identifier, reference_value) return return_value @visit.register def _(self, node: ReturnStatement) -> Optional[QASMNode]: self._uses_advanced_language_features = True return self.visit(node.expression) @visit.register def _(self, node: SizeOf) -> IntegerLiteral: self._uses_advanced_language_features = True target = self.visit(node.target) index = self.visit(node.index) return builtin_functions["sizeof"](target, index)
[docs] def handle_builtin_gate( self, gate_name: str, arguments: list[FloatLiteral], qubits: list[Union[Identifier, IndexedIdentifier]], modifiers: list[QuantumGateModifier], ) -> None: """Add unitary operation to the circuit""" self.context.add_builtin_gate( gate_name, arguments, qubits, modifiers, )
[docs] def handle_phase(self, phase: FloatLiteral, qubits: Optional[Iterable[int]] = None) -> None: """Add quantum phase operation to the circuit""" self.context.add_phase(phase, qubits)