# Source code for pennylane._grad.value_and_grad
# Copyright 2026 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Defines qml.value_and_grad.
"""
import inspect
from functools import lru_cache, wraps
from importlib.util import find_spec
from pennylane import capture
from pennylane.compiler import compiler
from pennylane.exceptions import CompileError
from .grad import _args_and_argnums, _setup_h, _setup_method, _ShapedArray
_has_jax = find_spec("jax") is not None
# pylint: disable=unused-argument, too-many-arguments
@lru_cache
def _get_value_and_grad_prim():
    """Create (and cache) the JAX primitive used to represent a combined
    value-and-gradient computation in captured programs.

    Returns:
        QmlPrimitive | None: the ``value_and_grad`` primitive, or ``None`` when
        JAX is not installed.
    """
    if not _has_jax:  # pragma: no cover
        return None
    import jax  # pylint: disable=import-outside-toplevel

    value_and_grad_prim = capture.QmlPrimitive("value_and_grad")
    # The primitive returns the function value followed by one gradient per argnum.
    value_and_grad_prim.multiple_results = True
    value_and_grad_prim.prim_type = "higher_order"

    @value_and_grad_prim.def_impl
    def _value_and_grad_impl(*args, argnums, jaxpr, method, h, fn):
        # Concrete evaluation: only autodiff ("auto") is supported without QJIT;
        # other methods (e.g. finite differences) require a compiler.
        if method != "auto":  # pragma: no cover
            raise ValueError(f"Invalid value '{method=}' without QJIT.")

        def func(*inner_args):
            # Evaluate the stored (const-free) jaxpr and keep its single scalar output.
            res = jax.core.eval_jaxpr(jaxpr, [], *inner_args)
            return res[0]

        res = jax.value_and_grad(func, argnums=argnums)(*args)
        return jax.tree_util.tree_leaves(res)

    # pylint: disable=unused-argument
    @value_and_grad_prim.def_abstract_eval
    def _value_and_grad_abstract(*args, argnums, jaxpr, method, h, fn):
        # Abstract evaluation: the outputs are the function's output aval followed
        # by one aval per differentiated input, mirroring each input's shape/dtype.
        in_avals = tuple(args[i] for i in argnums)
        grad_avals = (
            _ShapedArray(in_aval.shape, in_aval.dtype, weak_type=in_aval.weak_type)
            for in_aval in in_avals
        )
        return [jaxpr.outvars[0].aval, *grad_avals]

    return value_and_grad_prim
def _capture_value_and_grad(func, *, argnums=0, method=None, h=None):
    """Capture-mode implementation of ``value_and_grad``: trace ``func`` to a
    jaxpr and bind the ``value_and_grad`` primitive on the flattened inputs.

    Args:
        func (Callable): scalar-output function to differentiate
        argnums (int | Sequence[int]): indices of the trainable arguments
        method (str | None): differentiation method, validated by ``_setup_method``
        h (float | None): finite-difference step size, validated by ``_setup_h``

    Returns:
        Callable: a wrapper returning ``(value, gradient)`` with pytree structure
        matching ``func``'s output and the trainable inputs.
    """
    # mostly a copy-paste of _capture_diff, but a few minor things needed to get updated
    # Could also find a way to remove code duplication
    import jax  # pylint: disable=import-outside-toplevel

    # pylint: disable=import-outside-toplevel
    from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten

    h = _setup_h(h)
    method = _setup_method(method)
    _argnums = argnums  # somehow renaming stops it from being unbound?

    @wraps(func)
    def new_func(*args, **kwargs):
        flat_args, flat_argnums, full_in_tree, trainable_in_tree = _args_and_argnums(args, _argnums)
        # Create fully flattened function (flat inputs & outputs)
        flat_fn = capture.FlatFn(func, full_in_tree)
        abstracted_axes, abstract_shapes = capture.determine_abstracted_axes(tuple(flat_args))
        jaxpr = jax.make_jaxpr(flat_fn, abstracted_axes=abstracted_axes)(*flat_args, **kwargs)
        # value_and_grad is only defined for a single scalar output (same as jax).
        if len(jaxpr.out_avals) > 1 or jaxpr.out_avals[0].shape != ():
            raise TypeError(
                f"value_and_grad only defined for scalar-output functions. Got {jaxpr.out_avals}"
            )
        # Consts and abstract shape arguments are prepended to the primitive's
        # positional inputs below, so the user-level argnums must be shifted past them.
        num_abstract_shapes = len(abstract_shapes)
        shift = num_abstract_shapes + len(jaxpr.consts)
        shifted_argnums = tuple(a + shift for a in flat_argnums)
        # Promote the consts to ordinary invars so the stored jaxpr is self-contained.
        j = jaxpr.jaxpr
        no_consts_jaxpr = j.replace(constvars=(), invars=j.constvars + j.invars)
        flat_inputs, _ = tree_flatten((args, kwargs))
        prim_kwargs = {
            "argnums": shifted_argnums,
            "jaxpr": no_consts_jaxpr,
            "fn": func,
            "method": method,
            "h": h,
        }
        out_flat = _get_value_and_grad_prim().bind(
            *jaxpr.consts,
            *abstract_shapes,
            *flat_inputs,
            **prim_kwargs,
        )
        # First output is the value; the remaining outputs are the flat gradients.
        res_flat, *grad_flat = out_flat
        # flatten once more to go from 2D derivative structure (outputs, args) to flat structure
        grad_flat = tree_leaves(grad_flat)
        assert flat_fn.out_tree is not None, "out_tree should be set after executing flat_fn"
        res_nested = tree_unflatten(flat_fn.out_tree, [res_flat])
        # The derivative output tree is the composition of output tree and trainable input trees
        combined_tree = flat_fn.out_tree.compose(trainable_in_tree)
        grad_nested = tree_unflatten(combined_tree, grad_flat)
        return res_nested, grad_nested

    return new_func
# pylint: disable=too-few-public-methods
class value_and_grad:
    """A :func:`~.qjit`-compatible transformation for returning the result and jacobian of a
    function.

    This function allows the value and the gradient of a hybrid quantum-classical function to be
    computed within the compiled program.

    Note that ``value_and_grad`` can be more efficient, and reduce overall quantum executions,
    compared to separately executing the function and then computing its gradient.

    .. warning::

        Currently, higher-order differentiation is only supported by the finite-difference
        method.

    Args:
        fn (Callable): a function or a function object to differentiate
        method (str): The method used for differentiation, which can be any of ``["auto", "fd"]``,
            where:

            - ``"auto"`` represents deferring the quantum differentiation to the method
              specified by the QNode, while the classical computation is differentiated
              using traditional auto-diff. Catalyst supports ``"parameter-shift"`` and
              ``"adjoint"`` on internal QNodes. Notably, QNodes with
              ``diff_method="finite-diff"`` are not supported with ``"auto"``.
            - ``"fd"`` represents first-order finite-differences for the entire hybrid
              function.

        h (float): the step-size value for the finite-difference (``"fd"``) method
        argnums (Tuple[int, List[int]]): the argument indices to differentiate

    Returns:
        Callable: A callable object that computes the value and gradient of the wrapped function
        for the given arguments.

    Raises:
        ValueError: Invalid method or step size parameters.
        DifferentiableCompilerError: Called on a function that doesn't return a single scalar.

    .. note::

        Any JAX-compatible optimization library, such as `Optax
        <https://optax.readthedocs.io/en/stable/index.html>`_, can be used
        alongside ``value_and_grad`` for JIT-compatible variational workflows.
        See the :doc:`quickstart guide <catalyst:dev/quick_start>` for examples.

    .. seealso:: :func:`~.grad`, :func:`~.jacobian`

    **Example 1 (Classical preprocessing)**

    .. code-block:: python

        dev = qml.device("lightning.qubit", wires=1)

        @qml.qjit
        def workflow(x):
            @qml.qnode(dev)
            def circuit(x):
                qml.RX(jnp.pi * x, wires=0)
                return qml.expval(qml.PauliY(0))

            return qml.value_and_grad(circuit)(x)

    >>> workflow(0.2)
    (Array(-0.58778525, dtype=float64), Array(-2.54160185, dtype=float64))

    **Example 2 (Classical preprocessing and postprocessing)**

    .. code-block:: python

        dev = qml.device("lightning.qubit", wires=1)

        @qml.qjit
        def value_and_grad_loss(theta):
            @qml.qnode(dev, diff_method="adjoint")
            def circuit(theta):
                qml.RX(jnp.exp(theta ** 2) / jnp.cos(theta / 4), wires=0)
                return qml.expval(qml.PauliZ(wires=0))

            def loss(theta):
                return jnp.pi / jnp.tanh(circuit(theta))

            return qml.value_and_grad(loss, method="auto")(theta)

    >>> value_and_grad_loss(1.0)
    (Array(-4.2622289, dtype=float64), Array(5.04324559, dtype=float64))

    **Example 3 (Purely classical functions)**

    .. code-block:: python

        def square(x: float):
            return x ** 2

        @qml.qjit
        def dsquare(x: float):
            return qml.value_and_grad(square)(x)

    >>> dsquare(2.3)
    (Array(5.29, dtype=float64), Array(4.6, dtype=float64))
    """

    def __init__(self, func, argnums=0, method=None, h=None):
        self._func = func
        self._argnums = argnums
        self._method = method
        self._h = h
        # need to preserve input signature for use in catalyst AOT compilation, but
        # get rid of return annotation to placate autograd
        self.__signature__ = inspect.signature(self._func).replace(
            return_annotation=inspect.Signature.empty
        )
        fn_name = getattr(self._func, "__name__", repr(self._func))
        self.__name__ = f"<value_and_grad: {fn_name}>"

    def __call__(self, *args, **kwargs):
        # When a QJIT compiler is active, defer entirely to the compiler's own
        # value_and_grad implementation (loaded via its entry point).
        if active_jit := compiler.active_compiler():
            available_eps = compiler.AvailableCompilers.names_entrypoints
            ops_loader = available_eps[active_jit]["ops"].load()
            return ops_loader.value_and_grad(
                self._func, method=self._method, h=self._h, argnums=self._argnums
            )(*args, **kwargs)
        # With program capture enabled, use the primitive-binding implementation.
        if capture.enabled():
            g = _capture_value_and_grad(
                self._func, argnums=self._argnums, method=self._method, h=self._h
            )
            return g(*args, **kwargs)
        raise CompileError("PennyLane does not support the value_and_grad function without QJIT.")
# (HTML navigation footer removed during source recovery)