I have not gone into the internals, but from the documentation they seem to behave in a very similar manner: both do symbolic execution by sending tracer objects through the function, recording what happens, and then emitting compiled code.
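As a minimal sketch of what that tracing looks like on the JAX side (the torch.compile machinery is conceptually similar, though the mechanics differ): a `print` inside a jitted function runs only at trace time, and what it shows is a tracer, not a concrete array.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Runs only while tracing; x here is something like
    # Traced<ShapedArray(float32[3])>, not an actual array.
    print("tracing with:", x)
    return jnp.sin(x) * 2.0

x = jnp.ones(3)
f(x)        # prints the tracer: the function is being traced and compiled
f(x + 1.0)  # same shape/dtype -> cached compiled code, no print

# The recorded program (jaxpr) that tracing produces:
print(jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(x))
```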
JAX specifically specializes the compiled code on input shapes, and it seems that torch does this as well. Both may re-jit when necessary, e.g. when called with different input shapes.
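You can see the shape specialization directly in JAX, assuming nothing beyond stock `jax.jit`; the trace-time `print` fires only when a fresh compile happens:

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    # Fires only on a fresh trace, i.e. once per distinct shape/dtype
    print("compiling for shape:", x.shape)
    return (x ** 2).sum()

g(jnp.ones(4))  # prints: compiling for shape: (4,)
g(jnp.ones(4))  # cache hit, no print
g(jnp.ones(8))  # new shape -> re-traced and re-compiled, prints again
```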
I have a hunch that JAX should be faster because it relies on XLA primitives, e.g. XLA-backed for-loops and scans. I am not sure how PyTorch handles that, given it most likely tries to remain backwards compatible and relies on Python semantics rather than compiler primitives.
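For concreteness, these are the kinds of XLA loop primitives I mean: `lax.fori_loop` and `lax.scan` stage a single XLA loop construct into the compiled program instead of unrolling the Python iterations at trace time (a sketch, not a benchmark):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def sum_to(n):
    # Lowers to one XLA While loop rather than n unrolled additions
    return lax.fori_loop(0, n, lambda i, acc: acc + i, 0)

@jax.jit
def running_sum(xs):
    # lax.scan threads a carry through the sequence and stacks per-step outputs
    def step(carry, x):
        carry = carry + x
        return carry, carry
    total, partials = lax.scan(step, 0.0, xs)
    return total, partials

print(sum_to(10))                    # 45
print(running_sum(jnp.arange(4.0)))  # (6.0, [0. 1. 3. 6.])
```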