CampAny9995

[–] CampAny9995@alien.top 1 points 11 months ago

MuJoCo now runs on XLA/JAX (via MJX).
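For context, a minimal sketch of what that looks like, assuming the `mjx` module that ships with MuJoCo 3.x (the toy XML model is made up for illustration):

```python
import jax
import mujoco
from mujoco import mjx

# A throwaway model: one free-falling sphere (illustrative only).
XML = """
<mujoco>
  <worldbody>
    <body>
      <freejoint/>
      <geom size="0.1"/>
    </body>
  </worldbody>
</mujoco>
"""

mj_model = mujoco.MjModel.from_xml_string(XML)
mjx_model = mjx.put_model(mj_model)   # copy the model to the accelerator
mjx_data = mjx.make_data(mj_model)    # device-side simulation state

# The step function is pure JAX, so it can be jit-compiled (and vmapped
# across thousands of parallel environments).
jit_step = jax.jit(mjx.step)
mjx_data = jit_step(mjx_model, mjx_data)
```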

[–] CampAny9995@alien.top 1 points 11 months ago (1 children)

My experience is that JAX is much lower-level and doesn’t come with batteries included, so you have to pick your own optimization library and module abstraction (e.g., Optax, and Flax or Equinox). But I also find it makes way more sense than PyTorch (‘requires_grad’?), and JAX’s autograd is substantially better thought out and more robust than PyTorch’s (my background was in compilers and autograd before moving into deep learning during my postdocs, so I have dug into that side of things). Plus the TPU support makes life a bit easier than competing for GPU instances on AWS.
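To make the “pick your own batteries” point concrete, here’s a minimal sketch of the functional style, assuming Optax as the optimizer library (the loss, shapes, and data are invented for illustration):

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x, y):
    # Plain function of its inputs; no requires_grad flags on tensors.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros(1)}
optimizer = optax.sgd(1e-2)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    # jax.grad is a transform: it maps loss_fn to its gradient function.
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

x = jnp.ones((8, 3))
y = jnp.ones((8, 1))
params, opt_state = train_step(params, opt_state, x, y)
```

Note there’s no requires_grad flag anywhere: differentiation is a transform applied to the whole loss function rather than a property tracked per tensor, which is the design difference I mean.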