Interface
To support back-propagation across PyTorch and JAX, the JAX computation is wrapped in customized PyTorch forward and backward functions. The JAX module is therefore treated as a custom-defined function in PyTorch, which keeps it compatible with PyTorch while retaining the computation speed of JAX.
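The wrapping described above can be sketched with torch.autograd.Function and jax.vjp. This is a minimal illustration under stated assumptions, not tedq's implementation. The names JaxFunction, jax_fn, torch_to_jax and jax_to_torch are hypothetical. Tensors are assumed to be float32 on the CPU, and conversion goes through NumPy copies rather than any zero-copy mechanism tedq may use.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def jax_fn(x):
    # Hypothetical JAX computation: f(x) = sum(sin(x)^2)
    return jnp.sum(jnp.sin(x) ** 2)


def torch_to_jax(t):
    # Detach from the PyTorch graph and copy into a JAX array via NumPy
    return jnp.asarray(t.detach().cpu().numpy())


def jax_to_torch(a):
    # View the JAX array as NumPy, then let torch.tensor make a copy
    return torch.tensor(np.asarray(a))


class JaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Run the JAX computation once and keep its vector-Jacobian product
        out, vjp_fn = jax.vjp(jax_fn, torch_to_jax(x))
        ctx.vjp_fn = vjp_fn
        return jax_to_torch(out)

    @staticmethod
    def backward(ctx, dy):
        # Map the upstream PyTorch gradient back through the JAX computation
        (dx,) = ctx.vjp_fn(torch_to_jax(dy))
        return jax_to_torch(dx)


x = torch.tensor([0.1, 0.5, 1.0], requires_grad=True)
y = JaxFunction.apply(x)
y.backward()
print(x.grad)  # equals sin(2 * x), the derivative of sin(x)**2
```

Keeping the vjp closure on ctx means the JAX forward pass runs only once; backward reuses it to turn the upstream PyTorch gradient into a gradient for x.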
PyTorch interface
- This module contains the pytorch_interface class, which lets tedq use JAX to calculate the quantum circuit while exposing the interface in torch format.
- class tedq.interface.pytorch_interface.PytorchInterface(*args, **kwargs)
This class provides the PyTorch interface for the JAX computation module.
- static backward(ctx, *dy)
Customized backward function that wraps the JAX computation module. Its main purpose is to convert variables between PyTorch and JAX types and to pass the gradients computed by each framework across the boundary. With this function, back-propagation through the JAX part is available to PyTorch modules.
- Args:
dy (torch.Tensor): Gradients from the previous layer in the backward pass, i.e. gradients with respect to the outputs of forward.
- static forward(ctx, input_kwargs, *input_)
Customized forward function that wraps the JAX computation module. Its main purpose is to convert variables between PyTorch and JAX types.
- Args:
input_kwargs (dict): Keyword arguments for the JAX computation module.
input_ (torch.Tensor): Parameters.
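A sketch that mirrors the documented signatures forward(ctx, input_kwargs, *input_) and backward(ctx, *dy) follows. The toy one-qubit circuit (RY(a) then RY(b) applied to |0>, measuring Z or X), the names CircuitFunction and expval, and the "observable" keyword are hypothetical stand-ins; tedq's real circuit evaluation is not shown in this section. The point it illustrates is that a torch.autograd.Function backward must return one value per forward argument, so it returns None for the non-tensor input_kwargs.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def expval(theta_a, theta_b, observable="Z"):
    # Hypothetical toy circuit: RY(theta_a), then RY(theta_b), applied to |0>
    def ry(t):
        c, s = jnp.cos(t / 2), jnp.sin(t / 2)
        return jnp.array([[c, -s], [s, c]])

    state = ry(theta_b) @ (ry(theta_a) @ jnp.array([1.0, 0.0]))
    if observable == "Z":
        obs = jnp.array([[1.0, 0.0], [0.0, -1.0]])
    else:  # "X"
        obs = jnp.array([[0.0, 1.0], [1.0, 0.0]])
    # Amplitudes are real here, so <psi|O|psi> needs no complex conjugate
    return state @ obs @ state


class CircuitFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_kwargs, *input_):
        # input_kwargs: non-differentiable options; input_: parameter tensors
        params = [jnp.asarray(p.detach().cpu().numpy()) for p in input_]
        out, vjp_fn = jax.vjp(lambda *ps: expval(*ps, **input_kwargs), *params)
        ctx.vjp_fn = vjp_fn
        return torch.tensor(np.asarray(out))

    @staticmethod
    def backward(ctx, *dy):
        # dy holds one upstream gradient per forward output (one here)
        grads = ctx.vjp_fn(jnp.asarray(dy[0].detach().cpu().numpy()))
        # None for input_kwargs, then one gradient per tensor in input_
        return (None, *(torch.tensor(np.asarray(g)) for g in grads))


a = torch.tensor(0.3, requires_grad=True)
b = torch.tensor(0.4, requires_grad=True)
out = CircuitFunction.apply({"observable": "Z"}, a, b)
out.backward()
print(out.item(), a.grad.item(), b.grad.item())  # about cos(0.7), -sin(0.7), -sin(0.7)
```

Because RY(b)RY(a) = RY(a + b), the expectation of Z is cos(a + b), so both parameters receive the gradient -sin(a + b).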