My experience is that JAX is much lower level, and doesn’t come with batteries included so you have to pick your own optimization library or module abstraction. But I also find it makes way more sense than PyTorch (‘requires_gradient’?), and JAX’s autograd algorithm is substantially better thought out and more robust than PyTorch’s (my background was in compilers and autograd before moving into deep learning during postdocs, so I have dug into that side of things). Plus the support for TPUs makes life a bit easier compared to competing for instances on AWS.
CampAny9995
joined 1 year ago
Mujoco is now in XLA/JAX.