this post was submitted on 25 Nov 2023
Machine Learning
My experience is that JAX is much lower level and doesn’t come with batteries included, so you have to pick your own optimization library and module abstraction. But I also find it makes way more sense than PyTorch (‘requires_grad’?), and JAX’s autograd algorithm is substantially better thought out and more robust than PyTorch’s (my background was in compilers and autograd before moving into deep learning during my postdocs, so I have dug into that side of things). Plus the TPU support makes life a bit easier compared to competing for instances on AWS.
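To make the "no batteries included" point concrete, here is a minimal sketch of what training looks like with core JAX only. The toy regression problem, the parameter dict, and the helper names `loss_fn` and `sgd_step` are all made up for illustration; in practice you would likely swap the hand-written update for an optimizer library of your choice.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: parameters are a plain dict (a pytree), not a module object.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=0.1):
    # jax.grad differentiates with respect to the first argument by default,
    # returning gradients with the same dict structure as params.
    grads = jax.grad(loss_fn)(params, x, y)
    # No built-in optimizer in core JAX: the update rule is written by hand.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Synthetic data for illustration only.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3

params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
initial_loss = loss_fn(params, x, y)
for _ in range(200):
    params = sgd_step(params, x, y)
final_loss = loss_fn(params, x, y)
```

Note that nothing is flagged as "requiring gradients": differentiation is a transformation applied to a function, and the parameters are ordinary arrays.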
It’s a drop-in replacement for numpy. It does not get sexier than that. I use it for my research on PDE solvers and deep learning, and being able to just write numpy code and get automatic differentiation on it is very useful. Previously I was looking at autodiff frameworks like Tapenade, but that’s not required anymore.
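A small sketch of the "numpy plus autodiff" workflow, assuming a made-up example: a discretized 1D Dirichlet energy, written in numpy style with `jax.numpy` and differentiated with `jax.grad`. The function name `energy`, the grid, and the field `u` are illustrative and not taken from any particular solver.

```python
import jax
import jax.numpy as jnp

def energy(u, dx):
    # Discrete Dirichlet energy of a 1D field: 0.5 * sum((du/dx)^2) * dx.
    # The body is ordinary numpy-style code, with jnp in place of np.
    du = jnp.diff(u) / dx
    return 0.5 * jnp.sum(du ** 2) * dx

n = 11
dx = 1.0 / (n - 1)
x = jnp.linspace(0.0, 1.0, n)
u = jnp.sin(jnp.pi * x)

# Gradient of the energy with respect to every grid value of u.
grad_energy = jax.grad(energy)(u, dx)

# For interior points the gradient should match the hand-derived stencil
# (2*u[i] - u[i-1] - u[i+1]) / dx, i.e. a discrete negative Laplacian times dx.
stencil = (2.0 * u[1:-1] - u[:-2] - u[2:]) / dx
```

Comparing `grad_energy[1:-1]` against `stencil` is a quick sanity check that the automatic derivative agrees with the one you would otherwise derive by hand.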