This post was submitted on 14 Nov 2023

Machine Learning

Have any of you seen benchmarks comparing the performance of `@jax.jit` and `@torch.compile`, especially when using functional PyTorch code? Are the performance differences big? Small? Do they depend a lot on what you're doing?

PM_ME_YOUR_BAYES@alien.top, 10 months ago

Experiment yourself. It takes like 10 minutes.
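For anyone who does try it, the main trap is timing: both `jax.jit` and `torch.compile` compile lazily on the first call, and JAX dispatches work asynchronously. Below is a minimal JAX-side sketch that separates first-call (trace + compile) time from steady-state time. The workload `toy_step` and the array sizes are hypothetical placeholders, not a recommended benchmark. The same harness shape should apply to a `torch.compile`d function; on GPU you would call `torch.cuda.synchronize()` instead of `block_until_ready()` before reading the clock.

```python
import time

import jax
import jax.numpy as jnp


def toy_step(w, x):
    # Hypothetical toy workload: one dense layer, tanh, scalar loss.
    h = jnp.tanh(x @ w)
    return jnp.mean(h * h)


def bench(fn, *args, n_iters=50):
    """Return (first_call_seconds, mean_steady_state_seconds) for fn(*args)."""
    # The first call includes tracing and XLA compilation, so time it separately.
    t0 = time.perf_counter()
    fn(*args).block_until_ready()
    first_call = time.perf_counter() - t0

    # JAX dispatches asynchronously; block on the last result before stopping
    # the clock so we measure execution rather than just dispatch.
    t0 = time.perf_counter()
    for _ in range(n_iters):
        out = fn(*args)
    out.block_until_ready()
    steady = (time.perf_counter() - t0) / n_iters
    return first_call, steady


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (256, 256))
x = jax.random.normal(key_x, (128, 256))

jitted = jax.jit(toy_step)
first, steady = bench(jitted, w, x)
print(f"jax.jit: first call {first * 1e3:.1f} ms, steady state {steady * 1e3:.3f} ms")
```

Reporting the two numbers separately matters because compile time can dominate short runs, and the two frameworks may trade off compile cost against steady-state speed differently.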

maizeq@alien.top (reply), 10 months ago

Banal and unhelpful response.

depressed-bench@alien.top, 10 months ago

I have not gone into the internals, but from the documentation they seem to behave very similarly: both do symbolic execution by passing a tracer object through your function, recording what happens, and then emitting compiled code.
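On the JAX side this tracing step is easy to observe. The sketch below (the function `f` and the list `seen_types` are illustrative names) records the type of the argument `f` receives and uses `jax.make_jaxpr` to print the program JAX recorded. Note that this only illustrates JAX: `torch.compile`'s front end, TorchDynamo, hooks into CPython frame evaluation and analyzes bytecode rather than passing tracer objects through the function, so the two mechanisms are similar in spirit but not identical.

```python
import jax
import jax.numpy as jnp

seen_types = []  # records what f receives as its argument on each call


def f(x):
    # Under tracing, x is an abstract tracer (shape and dtype, no concrete values).
    seen_types.append(type(x).__name__)
    return jnp.sin(x) * 2.0 + x


x = jnp.arange(3.0)

# make_jaxpr runs f once with tracers and returns the recorded program.
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)

# Calling f eagerly passes a concrete array instead of a tracer.
f(x)
print(seen_types)
```

The printed jaxpr is the recorded computation that `jax.jit` then hands to XLA for compilation.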

JAX specializes the compiled code on input shapes, and it seems that torch does as well. Both may recompile when necessary, e.g. for different input shapes.
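A quick way to see the shape specialization in JAX is a Python side effect inside a jitted function: it runs only while JAX traces, not on cached calls. In this sketch (`g` and `trace_count` are illustrative names), two calls with shape `(3,)` share one trace and a call with shape `(4,)` triggers a second one. For comparison, `torch.compile` exposes a `dynamic` argument that controls whether it tries to generate shape-generic code instead of specializing.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def g(x):
    global trace_count
    # This Python statement only executes during tracing, not on cached calls.
    trace_count += 1
    return jnp.cumsum(x) * 0.5


g(jnp.ones(3))   # new shape (3,): trace and compile
g(jnp.zeros(3))  # same shape and dtype: reuse the cached executable
g(jnp.ones(4))   # new shape (4,): trace and compile again
print(trace_count)  # → 2
```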

My hunch is that JAX should be faster because it relies on XLA primitives, e.g. XLA-backed for-loops and scans. I'm not sure how PyTorch handles that, given that it most likely tries to stay backwards compatible and rely on Python semantics rather than compiler primitives.
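To make the loop point concrete on the JAX side: a plain Python `for` loop inside a jitted function is unrolled at trace time, whereas `jax.lax.fori_loop` and `jax.lax.scan` trace the body once and emit a single loop primitive. The sketch below (the step function `sin(x) + 1` and `N_STEPS` are arbitrary choices) compares the size of the recorded programs and checks that all three versions compute the same value. How `torch.compile` treats the equivalent loops is not covered here.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 100


def python_loop(x):
    # Under jit, this loop is unrolled: N_STEPS copies of the body in the program.
    for _ in range(N_STEPS):
        x = jnp.sin(x) + 1.0
    return x


def fori_version(x):
    # The body is traced once and wrapped in a single loop primitive.
    return lax.fori_loop(0, N_STEPS, lambda i, x: jnp.sin(x) + 1.0, x)


def scan_version(x):
    def step(carry, _):
        new = jnp.sin(carry) + 1.0
        return new, new  # (next carry, per-step output)

    final, history = lax.scan(step, x, None, length=N_STEPS)
    return final, history


x = jnp.linspace(0.0, 1.0, 3)

py_eqns = len(jax.make_jaxpr(python_loop)(x).jaxpr.eqns)
fori_eqns = len(jax.make_jaxpr(fori_version)(x).jaxpr.eqns)
print(f"unrolled python loop: {py_eqns} equations, fori_loop: {fori_eqns}")

final, history = jax.jit(scan_version)(x)
```

The unrolled version grows linearly with `N_STEPS`, which mostly shows up as longer compile times; whether it also changes steady-state speed depends on the workload.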