# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import numpy as np
from antlr4 import CommonTokenStream, InputStream
from braket.ir.jaqcd import (
Amplitude,
DensityMatrix,
Expectation,
Probability,
Sample,
StateVector,
Variance,
)
from braket.ir.jaqcd.program_v1 import Results
from .generated.BraketPragmasLexer import BraketPragmasLexer
from .generated.BraketPragmasParser import BraketPragmasParser
from .generated.BraketPragmasParserVisitor import BraketPragmasParserVisitor
from .openqasm_parser import parse
[docs]
class BraketPragmaNodeVisitor(BraketPragmasParserVisitor):
"""
This is a visitor for the BraketPragmas grammar. Consumes a
braketPragmas AST and converts to relevant python objects
for use by the Interpreter
"""
def __init__(self, qubit_table: "QubitTable"):
self.qubit_table = qubit_table
[docs]
def visitNoArgResultType(self, ctx: BraketPragmasParser.NoArgResultTypeContext) -> Results:
result_type = ctx.noArgResultTypeName().getText()
no_arg_result_type_map = {
"state_vector": StateVector,
}
return no_arg_result_type_map[result_type]()
[docs]
def visitOptionalMultiTargetResultType(
self, ctx: BraketPragmasParser.OptionalMultiTargetResultTypeContext
) -> Results:
result_type = ctx.optionalMultiTargetResultTypeName().getText()
optional_multitarget_result_type_map = {
"probability": Probability,
"density_matrix": DensityMatrix,
}
targets = self.visit(ctx.multiTarget()) if ctx.multiTarget() is not None else None
return optional_multitarget_result_type_map[result_type](targets=targets)
[docs]
def visitMultiTargetIdentifiers(self, ctx: BraketPragmasParser.MultiTargetIdentifiersContext):
parsable = f"target {''.join(x.getText() for x in ctx.getChildren())};"
parsed_statement = parse(parsable)
target_identifiers = parsed_statement.statements[0].qubits
target = sum(
(self.qubit_table.get_by_identifier(identifier) for identifier in target_identifiers),
(),
)
return target
[docs]
def visitMultiTargetAll(self, ctx: BraketPragmasParser.MultiTargetAllContext):
return
[docs]
def visitMultiStateResultType(
self, ctx: BraketPragmasParser.MultiStateResultTypeContext
) -> Results:
result_type = ctx.multiStateResultTypeName().getText()
multistate_result_type_map = {
"amplitude": Amplitude,
}
states = self.visit(ctx.getChild(1))
return multistate_result_type_map[result_type](states=states)
[docs]
def visitMultiState(self, ctx: BraketPragmasParser.MultiStateContext) -> list[str]:
# unquote and skip commas
states = [x.getText()[1:-1] for x in list(ctx.getChildren())[::2]]
return states
[docs]
def visitObservableResultType(
self, ctx: BraketPragmasParser.ObservableResultTypeContext
) -> Results:
result_type = ctx.observableResultTypeName().getText()
observable_result_type_map = {
"expectation": Expectation,
"sample": Sample,
"variance": Variance,
}
observables, targets = self.visit(ctx.observable())
obs = observable_result_type_map[result_type](targets=targets, observable=observables)
return obs
[docs]
def visitStandardObservableIdentifier(
self,
ctx: BraketPragmasParser.StandardObservableIdentifierContext,
) -> tuple[tuple[str], int]:
observable = ctx.standardObservableName().getText()
target_tuple = self.visit(ctx.indexedIdentifier())
if len(target_tuple) != 1:
raise ValueError("Standard observable target must be exactly 1 qubit.")
return (observable,), target_tuple
[docs]
def visitStandardObservableAll(
self,
ctx: BraketPragmasParser.StandardObservableAllContext,
) -> tuple[tuple[str], None]:
observable = ctx.standardObservableName().getText()
return (observable,), None
[docs]
def visitTensorProductObservable(
self, ctx: BraketPragmasParser.TensorProductObservableContext
) -> tuple[tuple[str], tuple[int]]:
observables, targets = zip(
*(self.visit(ctx.getChild(i)) for i in range(0, ctx.getChildCount(), 2))
)
observables = sum(observables, ())
targets = sum(targets, ())
return observables, targets
[docs]
def visitHermitianObservable(
self, ctx: BraketPragmasParser.HermitianObservableContext
) -> tuple[tuple[list[list[float]]], int]:
matrix = self.visit(ctx.twoDimMatrix())
matrix = np.expand_dims(matrix, axis=-1)
converted = np.append(matrix.real, matrix.imag, axis=-1).tolist()
target = self.visit(ctx.multiTarget())
return (converted,), target
[docs]
def visitIndexedIdentifier(
self, ctx: BraketPragmasParser.IndexedIdentifierContext
) -> tuple[int]:
parsable = f"target {''.join(x.getText() for x in ctx.getChildren())};"
parsed_statement = parse(parsable)
identifier = parsed_statement.statements[0].qubits[0]
target = self.qubit_table.get_by_identifier(identifier)
return target
[docs]
def visitComplexOneValue(self, ctx: BraketPragmasParser.ComplexOneValueContext) -> list[float]:
sign = -1 if ctx.neg else 1
value = ctx.value.text
imag = False
if value.endswith("im"):
value = value[:-2]
imag = True
complex_array = [0, 0]
complex_array[imag] = sign * float(value)
return complex_array
[docs]
def visitComplexTwoValues(
self, ctx: BraketPragmasParser.ComplexTwoValuesContext
) -> list[float]:
real = float(ctx.real.text)
imag = float(ctx.imag.text[:-2]) # exclude "im"
if ctx.neg:
real *= -1
if ctx.imagNeg:
imag *= -1
if ctx.sign.text == "-":
imag *= -1
return [real, imag]
[docs]
def visitBraketUnitaryPragma(
self, ctx: BraketPragmasParser.BraketUnitaryPragmaContext
) -> tuple[np.ndarray, tuple[int]]:
target = self.visit(ctx.multiTarget())
matrix = self.visit(ctx.twoDimMatrix())
return matrix, target
[docs]
def visitRow(self, ctx: BraketPragmasParser.RowContext) -> list[complex]:
numbers = ctx.children[1::2]
return [x[0] + x[1] * 1j for x in [self.visit(number) for number in numbers]]
[docs]
def visitTwoDimMatrix(self, ctx: BraketPragmasParser.TwoDimMatrixContext) -> np.ndarray:
rows = [self.visit(row) for row in ctx.children[1::2]]
if not all(len(r) == len(rows) for r in rows):
raise TypeError("Not a valid square matrix")
matrix = np.array(rows)
return matrix
[docs]
def visitNoise(self, ctx: BraketPragmasParser.NoiseContext):
target = self.visit(ctx.target)
probabilities = self.visit(ctx.probabilities())
noise_instruction = ctx.noiseInstructionName().getText()
return noise_instruction, target, probabilities
[docs]
def visitKraus(self, ctx: BraketPragmasParser.KrausContext):
target = self.visit(ctx.target)
matrices = [self.visit(m) for m in ctx.matrices().children[::2]]
return matrices, target
[docs]
def visitProbabilities(self, ctx: BraketPragmasParser.ProbabilitiesContext):
return [float(prob.symbol.text) for prob in ctx.children[::2]]
[docs]
def parse_braket_pragma(pragma_body: str, qubit_table: "QubitTable"):
"""Parse braket pragma and return relevant information.
Pragma types include:
- result types
- custom unitary operations
"""
data = InputStream(pragma_body)
lexer = BraketPragmasLexer(data)
stream = CommonTokenStream(lexer)
parser = BraketPragmasParser(stream)
tree = parser.braketPragma()
visited = BraketPragmaNodeVisitor(qubit_table).visit(tree)
return visited