Source code for braket.default_simulator.result_types

# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from functools import singledispatch
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from braket.ir import jaqcd

from braket.default_simulator.linalg_utils import marginal_probability, partial_trace
from braket.default_simulator.observables import (
    Hadamard,
    Hermitian,
    Identity,
    PauliX,
    PauliY,
    PauliZ,
    TensorProduct,
)
from braket.default_simulator.operation import Observable
from braket.default_simulator.operation_helpers import ir_matrix_to_ndarray
from braket.default_simulator.simulation import Simulation
from braket.default_simulator.state_vector_simulation import StateVectorSimulation


def from_braket_result_type(result_type) -> ResultType:
    """Build the simulator `ResultType` that matches a Braket IR result type.

    This is the public entry point; the work is delegated to the
    type-dispatched `_from_braket_result_type`.

    Args:
        result_type: Result type for a circuit specified using the
            `braket.ir.jaqcd` format.

    Returns:
        ResultType: Instance of the specific `ResultType` corresponding to the
        type of `result_type`.

    Raises:
        NotImplementedError: If no concrete `ResultType` class has been
            registered for the Braket instruction type.
    """
    converted = _from_braket_result_type(result_type)
    return converted
@singledispatch
def _from_braket_result_type(result_type):
    """Fallback of the result type dispatcher.

    Reached only when the type of `result_type` has no registered
    implementation on this single-dispatch function.

    Args:
        result_type: The unrecognized result type.

    Raises:
        NotImplementedError: Always, naming the offending result type.
    """
    message = f"Result type {result_type} not recognized"
    raise NotImplementedError(message)
class ResultType(ABC):
    """Abstract definition of a calculation on a quantum state simulation.

    Concrete subclasses implement `calculate`.

    Note:
        All result types are calculated exactly, instead of approximated from
        samples. Sampled results are returned from
        `Simulation.retrieve_samples`, which can be processed by, for example,
        the Amazon Braket SDK.
    """

    # The return type is Any because Python has no sum type to express
    # "one of the concrete result shapes".
    @abstractmethod
    def calculate(self, simulation: Simulation) -> Any:
        """Calculate a result from the given quantum state simulation.

        Args:
            simulation (Simulation): The simulation to use in the calculation.

        Returns:
            Any: The result of the calculation.
        """
class TargetedResultType(ResultType, ABC):
    """Base class for result types that may be restricted to target qubits."""

    def __init__(self, targets: Optional[List[int]] = None):
        """Store the target qubits.

        Args:
            targets (List[int], optional): The target qubits of the result
                type. If None, no specific qubits are targeted.
        """
        # Kept exactly as passed in; no copy or conversion is made.
        self._targets = targets

    @property
    def targets(self) -> Optional[Tuple[int, ...]]:
        """Tuple[int], optional: The target qubits of the result type, if any.

        The value given to the constructor is returned unchanged.
        """
        return self._targets
class ObservableResultType(TargetedResultType, ABC):
    """Base class for result types computed from an observable and a state."""

    def __init__(self, observable: Observable):
        """Store the observable and use its measured qubits as targets.

        Args:
            observable (Observable): The observable for which the desired
                result is calculated.
        """
        super().__init__(observable.measured_qubits)
        self._observable = observable

    @property
    def observable(self):
        """Observable: The observable for which the desired result is calculated."""
        return self._observable

    @property
    def targets(self) -> Optional[Tuple[int, ...]]:
        """Tuple[int], optional: The qubits measured by the underlying observable."""
        return self._observable.measured_qubits

    def calculate(self, simulation: Simulation) -> Union[float, List[float]]:
        """Calculate the result type using the underlying observable.

        A single real number is returned when the observable has defined
        targets. When it has none, the quantity is evaluated once per qubit of
        the simulation and the values are returned as a list.

        Args:
            simulation (Simulation): The simulation to use in the calculation.

        Returns:
            Union[float, List[float]]: The value of the result type; will be
            real due to self-adjointness of the observable.
        """
        observable = self._observable
        if not observable.measured_qubits:
            # No targets: pin the observable to each qubit in turn.
            per_qubit = []
            for qubit in range(simulation.qubit_count):
                pinned = observable.fix_qubit(qubit)
                per_qubit.append(self._calculate_single_quantity(simulation, pinned))
            return per_qubit
        return self._calculate_single_quantity(simulation, observable)

    @staticmethod
    @abstractmethod
    def _calculate_single_quantity(simulation: Simulation, observable: Observable) -> float:
        """Calculate a single real value of the result type.

        Args:
            simulation (Simulation): The simulation to use in the calculation.
            observable (Observable): The observable used to calculate the
                result type.

        Returns:
            float: The value of the result type.
        """
class StateVector(ResultType):
    """Result type that hands back the state vector of a simulation."""

    def calculate(self, simulation: StateVectorSimulation) -> np.ndarray:
        """Return the state vector of the given simulation.

        Args:
            simulation (StateVectorSimulation): The simulation whose state
                vector will be returned.

        Returns:
            np.ndarray: The state vector (before observables) of the
            simulation.
        """
        state = simulation.state_vector
        return state
@_from_braket_result_type.register(jaqcd.StateVector)
def _(_: jaqcd.StateVector):
    """Translate a `jaqcd.StateVector` IR result type; the IR carries no arguments."""
    return StateVector()
class DensityMatrix(TargetedResultType):
    """Result type that returns the full or reduced density matrix."""

    def __init__(self, targets: Optional[List[int]] = None):
        """Store the qubits on which the density matrix is desired.

        Args:
            targets (Optional[List[int]]): The qubit indices on which the
                reduced density matrix is desired. If no targets are
                specified, the full density matrix is calculated.
                Default: `None`
        """
        super().__init__(targets)

    def calculate(self, simulation: Simulation) -> np.ndarray:
        """Return the (full or reduced) density matrix of the simulation.

        Args:
            simulation (Simulation): The simulation whose (full or reduced)
                density matrix will be returned.

        Returns:
            np.ndarray: The density matrix (before observables) of the
            simulation.

        Raises:
            IndexError: If any target qubit lies outside the qubits of the
                simulation.
        """
        if self._targets is None:
            return simulation.density_matrix

        qubit_count = simulation.qubit_count
        # Validate before taking the "all qubits" shortcut below; otherwise
        # out-of-range targets would be silently accepted whenever their
        # number happens to equal the qubit count.
        valid_qubits = range(qubit_count)
        if not all(target in valid_qubits for target in self._targets):
            raise IndexError(
                "Input target qubits must be within the range of the qubits in the circuit."
            )

        if len(self._targets) == qubit_count:
            # NOTE(review): the full matrix is returned as-is, so the order of
            # the targets is not taken into account here -- confirm intended.
            return simulation.density_matrix

        # View the matrix as a rank-2n tensor with one axis of size 2 per
        # row/column qubit index before handing it to partial_trace.
        return partial_trace(
            simulation.density_matrix.reshape(np.array([2] * 2 * qubit_count)),
            self._targets,
        )
@_from_braket_result_type.register(jaqcd.DensityMatrix)
def _(density_matrix: jaqcd.DensityMatrix):
    """Translate a `jaqcd.DensityMatrix` IR result type, forwarding its targets."""
    requested_targets = density_matrix.targets
    return DensityMatrix(requested_targets)
class Amplitude(ResultType):
    """Result type that extracts amplitudes of chosen computational basis states."""

    def __init__(self, states: List[str]):
        """Store the basis states of interest.

        Args:
            states (List[str]): The computational basis states whose
                amplitudes are desired.
        """
        self._states = states

    def calculate(self, simulation: StateVectorSimulation) -> Dict[str, complex]:
        """Return the amplitudes of the desired computational basis states.

        Args:
            simulation (StateVectorSimulation): The simulation whose state
                vector amplitudes will be returned.

        Returns:
            Dict[str, complex]: A dict keyed on computational basis states as
            bitstrings, with the corresponding amplitudes as values.
        """
        state_vector = simulation.state_vector
        amplitudes = {}
        for bitstring in self._states:
            # The bitstring is read as a base-2 index into the state vector.
            amplitudes[bitstring] = state_vector[int(bitstring, 2)]
        return amplitudes
@_from_braket_result_type.register(jaqcd.Amplitude)
def _(amplitude: jaqcd.Amplitude):
    """Translate a `jaqcd.Amplitude` IR result type, forwarding its states."""
    requested_states = amplitude.states
    return Amplitude(requested_states)
class Probability(TargetedResultType):
    """Result type for marginal probabilities of computational basis states."""

    def __init__(self, targets: Optional[List[int]] = None):
        """Store the qubits on which probabilities are desired.

        Args:
            targets (Optional[List[int]]): The qubit indices on which
                probabilities are desired. If no targets are specified, the
                probabilities are calculated on the entire state.
                Default: `None`
        """
        super().__init__(targets)

    def calculate(self, simulation: Simulation) -> np.ndarray:
        """Return the marginal probabilities on the target qubits.

        Probabilities are marginalized over all non-target qubits.

        Args:
            simulation (Simulation): The simulation from which probabilities
                are calculated.

        Returns:
            np.ndarray: An array of probabilities of length equal to
            2^(number of target qubits), indexed by the decimal encoding of
            the computational basis state on the target qubits.
        """
        full_probabilities = simulation.probabilities
        return marginal_probability(full_probabilities, self._targets)
@_from_braket_result_type.register(jaqcd.Probability)
def _(probability: jaqcd.Probability):
    """Translate a `jaqcd.Probability` IR result type, forwarding its targets."""
    requested_targets = probability.targets
    return Probability(requested_targets)
class Expectation(ObservableResultType):
    """Result type for the expected value of an observable :math:`O`."""

    @staticmethod
    def _calculate_single_quantity(simulation: Simulation, observable: Observable) -> float:
        """Return the expectation of `observable` as reported by the simulation.

        Args:
            simulation (Simulation): The simulation to use in the calculation.
            observable (Observable): The observable whose expectation is taken.

        Returns:
            float: The expected value.
        """
        expected_value = simulation.expectation(observable)
        return expected_value
@_from_braket_result_type.register(jaqcd.Expectation)
def _(expectation: jaqcd.Expectation):
    """Translate a `jaqcd.Expectation` IR result type into an `Expectation`."""
    observable = _from_braket_observable(expectation.observable, expectation.targets)
    return Expectation(observable)
class Variance(ObservableResultType):
    """Result type for the variance of an observable :math:`O`."""

    @staticmethod
    def _calculate_single_quantity(simulation: Simulation, observable: Observable) -> float:
        """Return the variance of `observable`: E[O^2] - E[O]^2.

        Args:
            simulation (Simulation): The simulation to use in the calculation.
            observable (Observable): The observable whose variance is taken.

        Returns:
            float: The variance.
        """
        second_moment = simulation.expectation(observable**2)
        mean = simulation.expectation(observable)
        return second_moment - mean**2
@_from_braket_result_type.register
def _(variance: jaqcd.Variance):
    """Translate a `jaqcd.Variance` IR result type into a `Variance`."""
    return Variance(_from_braket_observable(variance.observable, variance.targets))


def _from_braket_observable(
    ir_observable: List[Union[str, List[List[List[float]]]]],
    ir_targets: Optional[List[int]] = None,
) -> Observable:
    """Build an `Observable` from its IR description.

    Args:
        ir_observable: The IR observable: a list with one entry for a single
            observable, or several entries for the factors of a tensor product.
            Each entry is either a name ("i", "h", "x", "y", "z") or an IR
            matrix.
        ir_targets: The qubits the observable acts on, if any.

    Returns:
        Observable: The single observable, or a `TensorProduct` of the factors.

    Raises:
        ValueError: If an entry is not a valid observable, or if the number of
            targets does not match what a tensor product's factors consume.
    """
    # Work on a private copy: tensor product factors consume their targets
    # from the front of this list in place.
    targets = list(ir_targets) if ir_targets else None
    if len(ir_observable) == 1:
        return _from_single_observable(ir_observable[0], targets)

    observable = TensorProduct(
        [_from_single_observable(factor, targets, is_factor=True) for factor in ir_observable]
    )
    # Anything still left in the list was not claimed by any factor.
    if targets:
        raise ValueError(
            f"Found {len(targets)} more target qubits than the tensor product acts on"
        )
    return observable


def _from_single_observable(
    observable: Union[str, List[List[List[float]]]],
    targets: Optional[List[int]] = None,
    # IR tensor product observables are decoupled from targets
    is_factor: bool = False,
) -> Observable:
    """Build one observable from a name or an IR matrix.

    Args:
        observable: One of the names "i", "h", "x", "y", "z", or an IR matrix
            describing a Hermitian observable.
        targets: The target qubits. For a tensor product factor, the qubits
            this factor needs are removed from the front of this list.
        is_factor: Whether the observable is a factor of a tensor product.

    Returns:
        Observable: The constructed observable.

    Raises:
        ValueError: If the observable cannot be interpreted, or if a tensor
            product factor finds too few targets left.
    """
    if observable == "i":
        return Identity(_actual_targets(targets, 1, is_factor))
    elif observable == "h":
        return Hadamard(_actual_targets(targets, 1, is_factor))
    elif observable == "x":
        return PauliX(_actual_targets(targets, 1, is_factor))
    elif observable == "y":
        return PauliY(_actual_targets(targets, 1, is_factor))
    elif observable == "z":
        return PauliZ(_actual_targets(targets, 1, is_factor))

    try:
        matrix = ir_matrix_to_ndarray(observable)
        # The number of qubits a factor claims is derived from the matrix size.
        num_qubits = int(np.log2(len(matrix))) if is_factor else None
    except Exception as err:
        raise ValueError(f"Invalid observable specified: {observable}") from err

    # Resolved outside the try blocks so that an "insufficient targets" error
    # is reported as such instead of being relabelled as an invalid observable.
    hermitian_targets = _actual_targets(targets, num_qubits, True) if is_factor else targets

    try:
        return Hermitian(matrix, hermitian_targets)
    except Exception as err:
        raise ValueError(f"Invalid observable specified: {observable}") from err


def _actual_targets(targets: List[int], num_qubits: int, is_factor: bool):
    """Return the targets an observable should be constructed with.

    Args:
        targets: The available target qubits (may be None).
        num_qubits: How many qubits a tensor product factor claims.
        is_factor: Whether the observable is a factor of a tensor product.

    Returns:
        `targets` unchanged when `is_factor` is false; otherwise a list of the
        first `num_qubits` entries, which are removed from `targets` in place.

    Raises:
        ValueError: If a factor needs more qubits than `targets` provides.
    """
    if not is_factor:
        return targets
    try:
        return [targets.pop(0) for _ in range(num_qubits)]
    except (AttributeError, IndexError) as err:
        # AttributeError: targets is None; IndexError: the list ran out.
        raise ValueError("Insufficient target qubits for tensor product") from err