nanoDFT computes forces on the CPU using `def grad(..)` on line 230. To run `def grad(..)` on the IPU it is sufficient to port lines 269-273 and line 283.
Different strategies for porting lines 269-273:
- Compile libcint to Poplar and replace all `mol.intor(..)` calls with the corresponding Poplar calls (the ERI is the only problematic part).
- Use the JAX implementation from D4FT for the forward pass of `mol.intor(..)` and match up the `jax.grad(..)` of the forward passes with lines 269-273 (pyscfad matched up libcint with `jax.grad` on CPU, so their code may be helpful).
- Reimplement all integrals from first principles in JAX/Tessellate.
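The second strategy amounts to attaching hand-computed derivatives to a forward pass that autodiff cannot trace through. A minimal sketch of that pattern using `jax.custom_vjp` (the toy `black_box_intor` below is a hypothetical stand-in for `mol.intor(..)`; in the real port the backward rule would reproduce the derivatives from lines 269-273):

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.custom_vjp
def black_box_intor(x):
    # Stand-in for a libcint-style integral call whose internals
    # JAX cannot differentiate through.
    return jnp.sum(jnp.sin(x))

def fwd(x):
    # Forward pass, plus residuals saved for the backward pass.
    return black_box_intor(x), x

def bwd(x, g):
    # Hand-supplied derivative: d/dx sum(sin(x)) = cos(x).
    return (g * jnp.cos(x),)

black_box_intor.defvjp(fwd, bwd)

x = jnp.array([0.1, 0.2, 0.3])
grads = jax.grad(black_box_intor)(x)
assert np.allclose(grads, np.cos(np.array([0.1, 0.2, 0.3])))
```

With the forward pass and its matched derivative registered this way, `jax.grad(..)` of any surrounding computation composes correctly through the black-box call.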
Note: line 230 uses this theorem to compute gradients. We could use `jax.grad(_nanoDFT)` instead of the theorem. That would require fixing all calls inside `_nanoDFT(..)` that don't support derivatives. We currently believe the work involved is the same as fixing `def grad(..)` (see the strategies above). In other words: the non-autograd calls made by `_nanoDFT` are exactly the calls whose derivatives are computed on lines 269-273 and 283.
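As a toy illustration of the alternative (not nanoDFT's actual energy), a fully differentiable energy function can be differentiated directly with `jax.grad(..)` and checked against the analytic gradient one would otherwise hard-code. The names `toy_energy` and `analytic_forces` below are hypothetical stand-ins for `_nanoDFT(..)` and `def grad(..)`:

```python
import jax
import jax.numpy as jnp

def toy_energy(R):
    # Hypothetical stand-in for _nanoDFT(..): a smooth scalar energy of
    # positions R. Every call inside must support derivatives for
    # jax.grad to work end-to-end.
    return jnp.sum(R ** 2) + jnp.sum(jnp.cos(R))

def analytic_forces(R):
    # Hand-derived gradient, playing the role of def grad(..):
    # d/dR [R^2 + cos(R)] = 2R - sin(R).
    return 2.0 * R - jnp.sin(R)

R = jnp.array([0.5, -1.0, 2.0])
auto_grad = jax.grad(toy_energy)(R)
assert jnp.allclose(auto_grad, analytic_forces(R))
```

The autodiff route only works once every call inside the energy function is differentiable, which is why the effort is comparable to porting the hand-written gradient.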