this post was submitted on 25 Nov 2023

Machine Learning


The creator of Keras confirmed that the new version comes out in a few days. Keras becomes multi-backend again, with support for PyTorch, TensorFlow, and JAX. Personally, I'm excited to be able to try JAX without having to dive deep into the documentation and the entire ecosystem. What about you?

[–] Relevant-Yak-9657@alien.top 1 points 11 months ago

As the others said, it's a pain to reimplement common layers in JAX specifically. PyTorch is much higher level in its nn API, but personally I despise rewriting the amazing training loop for every implementation. That's why even JAX users reach for Flax for common layers: why use an error-prone operator like jax.lax.conv_general_dilated or whatever and fill in its 10 or so arguments every time? I would rather use flax.linen.Conv or keras_core.layers.Conv2D in my Sequential model and avoid debugging a million times. For PyTorch, Keras's model.fit() can quickly suffice and be customized later.
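
To make the point about the low-level operator concrete, here is a rough sketch of what a single 3x3 convolution looks like with jax.lax.conv_general_dilated. The shapes, the all-ones data, and the NHWC/HWIO layouts are just assumptions picked for illustration, not anything from a real model:

```python
import jax
import jax.numpy as jnp

# Illustrative input batch in NHWC layout: 1 image, 8x8 pixels, 3 channels.
x = jnp.ones((1, 8, 8, 3))
# Illustrative kernel in HWIO layout: 3x3 window, 3 input channels, 4 output channels.
kernel = jnp.ones((3, 3, 3, 4))

# The low-level call: strides, padding, dilations, and layout strings
# all have to be spelled out and kept consistent by hand.
y = jax.lax.conv_general_dilated(
    lhs=x,
    rhs=kernel,
    window_strides=(1, 1),
    padding="SAME",
    lhs_dilation=(1, 1),
    rhs_dilation=(1, 1),
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
    feature_group_count=1,
)

print(y.shape)  # (1, 8, 8, 4)
```

With a layer API the same thing is roughly one line, e.g. flax.linen.Conv(features=4, kernel_size=(3, 3)) or keras_core.layers.Conv2D(4, 3, padding="same"), and the kernel creation and layout bookkeeping are handled for you.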