Interface

To support back-propagation across PyTorch and JAX, the JAX computation is wrapped in a custom PyTorch function with its own forward and backward methods. PyTorch therefore treats the JAX module as a custom-defined function, which keeps it compatible with the rest of PyTorch while the computation itself benefits from JAX's speed.
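The wrapping described above can be sketched with `torch.autograd.Function` and `jax.vjp`. This is a minimal illustration under stated assumptions, not tedq's implementation. The class name `JaxFunction` and the stand-in circuit `jax_circuit` (which returns cos θ, the ⟨Z⟩ expectation value of RY(θ)|0⟩) are hypothetical, and tensors are moved between the two frameworks through NumPy on the CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def jax_circuit(theta):
    # Hypothetical stand-in for a JAX-computed circuit:
    # <Z> after RY(theta) acting on |0> equals cos(theta).
    return jnp.cos(theta)


class JaxFunction(torch.autograd.Function):
    """Hypothetical wrapper exposing a JAX function as a PyTorch autograd op."""

    @staticmethod
    def forward(ctx, theta):
        # PyTorch -> JAX via NumPy (CPU only in this sketch).
        theta_jax = jnp.asarray(theta.detach().cpu().numpy())
        # Evaluate the JAX function and keep its vector-Jacobian product closure.
        out, vjp_fn = jax.vjp(jax_circuit, theta_jax)
        ctx.vjp_fn = vjp_fn
        # JAX -> PyTorch; np.array makes a writable copy for torch.from_numpy.
        return torch.from_numpy(np.array(out))

    @staticmethod
    def backward(ctx, dy):
        # Hand the incoming PyTorch gradient to JAX as the cotangent.
        (grad,) = ctx.vjp_fn(jnp.asarray(dy.detach().cpu().contiguous().numpy()))
        # JAX -> PyTorch gradient with respect to theta.
        return torch.from_numpy(np.array(grad))


# Usage: gradients flow from a PyTorch loss back through the JAX computation.
theta = torch.tensor([0.1, 0.5, 1.0], requires_grad=True)
loss = JaxFunction.apply(theta).sum()
loss.backward()  # theta.grad should match -torch.sin(theta)
```

Storing the VJP closure on `ctx` lets JAX carry out the actual differentiation, while PyTorch only sees an operation whose input gradient is returned by `backward`.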

PyTorch interface

This module contains the PytorchInterface class, which lets tedq compute the quantum circuit with JAX while exchanging inputs and outputs in PyTorch tensor format.

class tedq.interface.pytorch_interface.PytorchInterface(*args, **kwargs)

This class provides the PyTorch interface for the JAX computation module.

static backward(ctx, *dy)

Customized backward function that wraps the JAX computation module. Its main purpose is to convert variables between PyTorch and JAX types and to pass along the gradients computed by each framework. With this function, back-propagation through the JAX part of the computation is available to PyTorch modules.

Args:

dy (torch.Tensor): Gradient propagated back from the subsequent layer, one per output of forward

static forward(ctx, input_kwargs, *input_)

Customized forward function that wraps the JAX computation module. Its main purpose is to convert variables between PyTorch and JAX types.

Args:

input_kwargs (kwargs): TODO

input_ (torch.Tensor): Parameters
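The signatures forward(ctx, input_kwargs, *input_) and backward(ctx, *dy) follow the `torch.autograd.Function` convention: `backward` receives one gradient per output of `forward` and must return one value per input of `forward`, using `None` for inputs that are not differentiable. The sketch below assumes, purely for illustration, that `input_kwargs` is a dict of non-differentiable keyword arguments forwarded to the JAX function; what tedq actually stores there is not documented here. `jax_fn`, `KwargsJaxFunction`, `to_jax` and `to_torch` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def jax_fn(a, b, scale=1.0):
    # Hypothetical JAX computation: two differentiable inputs and one
    # non-differentiable keyword argument.
    return scale * jnp.sin(a) * b


def to_jax(t):
    # PyTorch -> JAX via NumPy (CPU only in this sketch).
    return jnp.asarray(t.detach().cpu().contiguous().numpy())


def to_torch(x):
    # JAX -> PyTorch; np.array makes a writable copy.
    return torch.from_numpy(np.array(x))


class KwargsJaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_kwargs, *input_):
        # ASSUMPTION: input_kwargs is a dict of keyword arguments for jax_fn.
        args = [to_jax(t) for t in input_]
        # Close over the keyword arguments so JAX differentiates only w.r.t. args.
        out, vjp_fn = jax.vjp(lambda *xs: jax_fn(*xs, **input_kwargs), *args)
        ctx.vjp_fn = vjp_fn
        return to_torch(out)

    @staticmethod
    def backward(ctx, *dy):
        # forward has a single output, so dy holds exactly one gradient.
        grads = ctx.vjp_fn(to_jax(dy[0]))
        # One return value per forward input: None for input_kwargs,
        # then one gradient per tensor in input_.
        return (None, *(to_torch(g) for g in grads))
```

Closing over `input_kwargs` keeps the keyword arguments out of JAX's differentiation, which is why `backward` can return `None` in their slot.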