this post was submitted on 25 Nov 2023

Machine Learning

The creator of Keras has confirmed that the new version comes out in a few days. Keras is becoming multi-backend again, with support for PyTorch, TensorFlow, and JAX. Personally, I'm excited to be able to try JAX without having to deep-dive into the documentation and the entire ecosystem. What about you?

underPanther@alien.top 1 points 11 months ago

Libraries like PyTorch and JAX are already high-level libraries, in my view. The low-level stuff is C++/CUDA/XLA.

I don’t really see any useful extra abstractions in Keras that would lure me to it.

abio93@alien.top 1 points 11 months ago

If you use JAX with Keras, you are essentially doing keras->jax->jaxpr->llvm->cuda/xla, with probably many more intermediate levels.
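To make the first stages of that chain concrete, here is a minimal sketch in plain JAX (no Keras involved) that prints the jaxpr JAX traces for a small function, then the IR handed to XLA, then runs the compiled result. The function `dense` and the shapes are made up for illustration. The exact text of the lowered module depends on the JAX version (recent releases typically emit StableHLO), and the LLVM/CUDA code generation happens inside XLA, so it does not show up here.

```python
import jax
import jax.numpy as jnp


def dense(x, w, b):
    # Hypothetical toy layer: affine transform followed by tanh
    return jnp.tanh(x @ w + b)


x = jnp.ones((2, 3))
w = jnp.ones((3, 4))
b = jnp.zeros(4)

# Stage 1: JAX traces the Python function into a jaxpr
jaxpr = jax.make_jaxpr(dense)(x, w, b)
print(jaxpr)

# Stage 2: jit lowers the traced computation to the IR that XLA consumes
lowered = jax.jit(dense).lower(x, w, b)
hlo_text = lowered.as_text()
print(hlo_text)

# Stage 3: XLA compiles it for the current backend (CPU, GPU or TPU)
compiled = lowered.compile()
print(compiled(x, w, b))
```

Keras with the JAX backend adds its own layer and model machinery on top of stage 1; everything from the jaxpr down is plain JAX and XLA.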

odd1e@alien.top 1 points 11 months ago

What about not having to write your own training loop? Keras takes away a lot of boilerplate code, which makes your code more readable and less likely to contain bugs. I would compare it to scikit-learn: sure, you can implement your own Random Forest, but why bother?
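For a sense of what that boilerplate looks like, here is a minimal hand-written training loop in plain JAX for a one-feature linear regression: a loss, a jitted gradient step, and the epoch/batch loops. The model, the data and the hyperparameters (`LEARNING_RATE`, `epochs`, `batch_size`) are hypothetical and chosen only for illustration; it uses plain SGD with no shuffling, validation or callbacks. In Keras the equivalent is roughly `model.compile(optimizer="sgd", loss="mse")` followed by `model.fit(x, y, epochs=200, batch_size=32)`.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical hyperparameter


def predict(params, x):
    # One-feature linear model: y_hat = w * x + b
    return params["w"] * x + params["b"]


def mse_loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # Loss and gradients with respect to every leaf of the params dict
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    # Plain SGD: p <- p - lr * g, applied leaf by leaf
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


def fit(params, x, y, epochs=200, batch_size=32):
    # The loop that Keras's model.fit() would otherwise write for you
    # (no shuffling, validation, metrics or callbacks here)
    loss = None
    for _ in range(epochs):
        for start in range(0, x.shape[0], batch_size):
            xb = x[start:start + batch_size]
            yb = y[start:start + batch_size]
            params, loss = train_step(params, xb, yb)
    return params, loss


# Hypothetical noise-free data generated from y = 3x + 1
x = jnp.linspace(-1.0, 1.0, 128)
y = 3.0 * x + 1.0
params, final_loss = fit({"w": jnp.zeros(()), "b": jnp.zeros(())}, x, y)
print(params, final_loss)
```

Everything inside `fit` is the part `model.fit()` handles; once you need shuffling, metrics or early stopping, a hand-rolled loop like this keeps growing.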

TheCloudTamer@alien.top 1 points 11 months ago

The reality is that you nearly always need to break into that training abstraction, so it is useless.

Relevant-Yak-9657@alien.top 1 points 11 months ago

As the others said, it's a pain to reimplement common layers in JAX specifically. PyTorch is much higher level in its nn API, but personally I despise rewriting the amazing training loop for every implementation. That's why even JAX users rely on Flax for common layers: why use an error-prone operator like jax.lax.conv_general_dilated or whatever, and fill in its 10 arguments every time? I would rather use flax.linen.Conv or keras_core.layers.Conv2D in my Sequential model and avoid debugging a million times. For PyTorch, model.fit() can quickly suffice at first and be customized later.
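For reference, this is roughly what calling the raw operator looks like: a minimal sketch of a 3x3 "same" convolution with `jax.lax.conv_general_dilated`, with every optional argument written out explicitly, set to values equivalent to the defaults. The input and kernel shapes are made up for illustration; the layout strings tell XLA that the images are NHWC and the kernel is HWIO.

```python
import jax
from jax import lax

key = jax.random.PRNGKey(0)
x_key, k_key = jax.random.split(key)

images = jax.random.normal(x_key, (8, 32, 32, 3))  # batch, height, width, channels (NHWC)
kernel = jax.random.normal(k_key, (3, 3, 3, 16))   # kh, kw, in_channels, out_channels (HWIO)

out = lax.conv_general_dilated(
    images,
    kernel,
    window_strides=(1, 1),
    padding="SAME",                  # keep the 32x32 spatial size
    lhs_dilation=(1, 1),             # no input (transposed-conv style) dilation
    rhs_dilation=(1, 1),             # no kernel (atrous) dilation
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
    feature_group_count=1,           # >1 would give grouped/depthwise convs
    batch_group_count=1,
    precision=None,
    preferred_element_type=None,
)
print(out.shape)
```

A layer wraps all of this, plus parameter creation and initialization, behind a call such as `keras_core.layers.Conv2D(16, 3, padding="same")` or `flax.linen.Conv(features=16, kernel_size=(3, 3), padding="SAME")`.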