I'm not very knowledgeable in this realm, so somebody clue me in. I always thought of JAX as targeted towards more "bespoke" stuff; is there any advantage to using it in a high-level way instead of Torch or TF? Anything in terms of performance, ecosystem, etc.?
MuJoCo is now in XLA/JAX.
Biggest 'advantage' I can see is that, since Google is deprecating TF soon, JAX is the only Googly deep learning lib left. It fills a niche, insofar as that is a definable niche. I'm sticking with PyTorch for now.
No clue about things like speed/efficiency, which may be a factor.
since Google is deprecating tf soon
Do you have a source? IMO TF is too big to deprecate soon. They did stop support for Windows, but nobody abandons an enormous project suddenly.
Yeah, from what I see, despite the mess TensorFlow might be, it is still getting updated frequently and has been improving lately. Not sure why they would deprecate it anytime soon.
TLDR: No, they are not officially planning to deprecate TF. Yes they are still actively developing TF. No, that doesn't fill me with much confidence, coming from Google, especially while they are also developing Jax.
Just searched this again and kudos, I can't find anything but official Google statements that they are continuing support for TF in the foreseeable future. For a while people were doom-saying so confidently that Google is completely dropping TF for JAX that I kinda just took it on blind faith.
All that said: TF REALLY COULD GET DEPRECATED SOON. Despite their insistence that this won't happen, Google is known for deprecating strong projects with bright futures with little/no warning. Do not take the size of TensorFlow as evidence that the Goog is going to stand by it, especially when they are actively developing a competing product in the same niche.
fwiw, it is also the current fad in tech to make high level decisions abruptly without proper warning to engineers. It really does mean almost nothing when a company's engineers are enthusiastically continuing their support of a product.
TF is just not on solid ground.
Also, JAX is not officially a Google product, but rather a research project. So on paper, TensorFlow is Google's official framework for deep learning.
What obligation does Google have to not deprecate tf? Google abandons projects all the time.
I would also like to hear about some programming frameworks (or languages) that Google has abandoned before.
But saying that it dropped AngularJS is like saying that Google dropped TensorFlow; they just rebooted it, like TensorFlow, right? Thanks for Noop though. I had no idea it existed lol.
Actually, another perspective is that TensorFlow's deployment story is something JAX doesn't have (not that I know of), and cutting it would be idiotic for Google, since they would be eliminating their own tool in the middle of an ongoing AI revolution. TensorFlow is their current tool, and if they are going to abandon it, they will need a strong replacement for its deployability, which does guarantee it a few more years (the JAX team doesn't seem to be very focused on deployment). IIRC JAX deploys via TensorFlow right now.
That could be a valid concern. Personally, I'm not too worried, since this is just speculation. Besides, the field is diverse enough that most people would benefit from learning multiple frameworks.
My experience is that JAX is much lower level and doesn't come with batteries included, so you have to pick your own optimization library or module abstraction. But I also find it makes way more sense than PyTorch ('requires_grad'?), and JAX's autograd algorithm is substantially better thought out and more robust than PyTorch's (my background was in compilers and autograd before moving into deep learning during postdocs, so I have dug into that side of things). Plus the support for TPUs makes life a bit easier compared to competing for instances on AWS.
It's a drop-in replacement for NumPy. It does not get sexier than that. I use it for my research on PDE solvers and deep learning, and being able to just use NumPy with automatic differentiation on top is very useful. Previously I was looking at autodiff frameworks like Tapenade, but that's not required anymore.
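For anyone wondering what that looks like in practice, here's a minimal sketch (a toy function, not the commenter's actual PDE code): jax.numpy mirrors the NumPy API, and jax.grad differentiates plain Python functions.

```python
import jax
import jax.numpy as jnp

def energy(u):
    # Looks like ordinary NumPy code: a simple discrete Dirichlet energy.
    return jnp.sum((u[1:] - u[:-1]) ** 2)

u = jnp.linspace(0.0, 1.0, 11)
print(energy(u))            # value, computed just like NumPy would
print(jax.grad(energy)(u))  # exact gradient w.r.t. u, no hand-written adjoint
```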
Google isn’t deprecating TF.
This is mainly for people who only learned Keras and don’t want to learn anything else
Saves time when creating neural nets. If you want to use a subset of its speed and not spend hours analysing training methods, Keras can be a good addition. Other than that, Flax is probably the best way to get full tinkering ability while having an abstraction better than a NumPy-like API.
One draw of Keras that would get people to switch over would be how easy it makes model parallelism, but you'd need to get better MFU than FSDP/DeepSpeed, ideally competitive with Megatron, while being way easier and more flexible, for people to switch.
No. If they change the name to KerasGPT 4.0, then it will be all I think, talk, and read about for 6 months.
Started with Keras, wtf is this. Moved to PyTorch, oh this is so nice.
Don't plan to ever come back.
Most people here would say that PyTorch is better, but IMO Keras is fine, and no hate to TensorFlow either; they just made a lot of questionable API design changes, and FC has been weird on Twitter. For me, it is pretty exciting, since keras_core seems pretty stable as I use it, and it is just another great framework for newcomers to deep learning or for quick prototyping.
Documentation is superb as well.
I get access to some awesome data loading and preprocessing tools with the PyTorch backend, then I swap to TensorFlow for quantization to a TFLite model with almost no fuss.
It was somewhat annoying going from Torch to ONNX to TFLite previously. There are a bunch of small road bumps that you have to deal with.
Yeah, unifying these tools feels like the best way to go for me too. I also like JAX for a similar reason: there are 50 different libraries with different use cases, and it is easy to mix parts of them together thanks to the common infrastructure. Like Keras losses + Flax models + Optax training + my own library's superclasses. It's great tbh.
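As an illustration of that mix-and-match point, here's a minimal sketch (toy model, made-up shapes and hyperparameters, not the commenter's code) of a flax.linen model trained with an optax optimizer, with a plain jax.numpy loss standing in for a Keras one:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)

model = MLP()
x = jnp.ones((8, 16))                       # dummy batch
y = jnp.zeros((8, 1))
params = model.init(jax.random.PRNGKey(0), x)

opt = optax.adam(1e-3)
opt_state = opt.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    def loss_fn(p):
        pred = model.apply(p, x)
        return jnp.mean((pred - y) ** 2)    # plain-NumPy-style MSE loss
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

params, opt_state, loss = train_step(params, opt_state, x, y)
```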
What does FC say on Twitter?
Libraries like PyTorch and JAX are already high-level libraries in my view. The low-level stuff is C++/CUDA/XLA.
I don’t really see the useful extra abstractions in Keras that would lure me to it.
This.
What about not having to write your own training loop? Keras takes away a lot of boilerplate code; it makes your code more readable and less likely to contain bugs. I would compare it to scikit-learn: sure, you can implement your own Random Forest, but why bother?
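For context, this is roughly the boilerplate being talked about; a minimal sketch with made-up layer sizes and random toy data, using the standard compile/fit API:

```python
import numpy as np
import keras

# Toy data just so the snippet runs end to end.
x_train = np.random.rand(256, 20).astype("float32")
y_train = np.random.randint(0, 10, size=(256,))

model = keras.Sequential([
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# One call instead of a hand-written epoch/batch/backprop loop:
model.fit(x_train, y_train, epochs=2, batch_size=32, validation_split=0.1)
```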
The reality is that you nearly always need to break into that training abstraction, and so it is useless.
As the others said, it's a pain to reimplement common layers in JAX (specifically). PyTorch is much higher level in its nn API, but personally I despise rewriting the amazing training loop for every implementation. That's why even JAX users go through Flax for common layers; why use an error-prone operator like jax.lax.conv_general_dilated or whatever and fill in its 10 arguments every time? I would rather use flax.linen.Conv or keras_core.layers.Conv2D in my Sequential model and save myself debugging a million times. Compared to PyTorch, model.fit() can quickly suffice and be customized later.
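To make the "10 arguments" point concrete, here's a small sketch (shapes are illustrative only) comparing the raw lax primitive with the Flax layer that wraps it:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 28, 28, 3))      # NHWC batch
kernel = jnp.ones((3, 3, 3, 8))   # HWIO kernel

# Low-level primitive: you spell out strides, padding, and dimension numbers yourself.
y_lax = jax.lax.conv_general_dilated(
    x, kernel,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

# The Flax layer wraps the same primitive and manages the parameters for you.
conv = nn.Conv(features=8, kernel_size=(3, 3), padding="SAME")
params = conv.init(jax.random.PRNGKey(0), x)
y_flax = conv.apply(params, x)
```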
If you use JAX with Keras you are essentially doing Keras -> JAX -> jaxpr -> XLA -> LLVM/CUDA, with probably many more intermediate levels.
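You can actually peek at one of those intermediate levels yourself: jax.make_jaxpr prints the jaxpr that a Python function is traced into before XLA compilation (toy function for illustration).

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + 1.0

print(jax.make_jaxpr(f)(jnp.ones(3)))  # shows the traced jaxpr, not the final XLA code
```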
Will it be largely backward compatible with Keras 2.0?
Yes, completely. In fact, tf.keras will just be changed internally (as in source code) to keras_core, but you won't notice any difference in tf.keras (except for the removal of some currently deprecated legacy code and a visual update in .fit()).
No, all these extra layers of abstraction are detrimental in my view. Too much magic happening without you knowing what's going on unless you dive into their code; same with PyTorch Lightning, etc.
No, I finally changed from Tensorflow Keras to PyTorch and not going back.
Why abandon Keras? In the end I felt like I was always fighting some little bug or error that needed brute-force guessing to resolve. The data loaders are painful. Dealing with the idiosyncrasies of the model subclassing system to do anything custom.
Yes, very much actually! People around here tend to forget that not everyone in ML is building billion-parameter LLMs, some of us just need a few basic building blocks and a model.fit() function to call
Keras broke the ice for me. The design of NNs used to take me a while to understand. It felt mechanical and meaningless. I was struggling hard to understand why adding or subtracting layers would help or hurt my models. I was trudging through TF documentation and honestly… I was very close to giving up.
I built my first ANN, got better with Keras, graduated to TF, built my first U-Net and got more confidence. I think anyone who really criticizes Keras doesn't understand that it is like criticizing training wheels for a bike.
You gotta learn to walk before you can run. You gotta learn baby nets before you are building monster segmentation models on recurrent convolutional neural nets. It takes time to understand the concepts and data flow.
Yes, same story. Keras allowed me to understand the basics. Personally, my journey has been Keras for architecture, PyTorch/TensorFlow for implicit gradient differentiation, JAX for explicit gradient optimization, and then creating a library on JAX to understand how these libraries work.
training wheels are horrible btw
it's much better to train kids with "pedal-less" bikes and then graduate them to pedals without training wheels, much easier to adapt to gaining balance etc.
Yeah, Keras was sort of useful and sort of annoying, but training wheels just suck. What's worst is when your kid falls while using training wheels. On a balance bike, you know you're unstable. On training wheels, your kid has false faith and isn't prepared for the tipover... especially if your kid is, at that moment, entranced with your scintillating lecture about the superiority of PyTorch.
So, either you have not recently taught a kid to ride a bike or you are just trolling.
So, I will counter your high ceiling with the low floor plan. The more a person rides a bike, training wheels or not, the better they will be at riding a bike. Training wheels get you riding more often and logging the hours.
You may be right about balance being a skill you develop without training wheels, but the hours kids spend failing and falling down discourage them, and then they don't want to play anymore.
I think you misunderstood me. In France they have those bikes for kids without pedals called "draisiennes" (I don't know what it is in English).
Kids on these bikes have no training wheels; they just "stroll" with them, lift their legs, and get used to managing the balance at speed. My friend's kids who learned like that were able to pedal on their first "real bike" (with pedals) the first time, without any training wheels.
It makes the transition *a lot* easier apparently.
I don't care.
All the research I do and did is in PyTorch, and nearly all the research I use is done in PyTorch.
So why should I switch? I would need to implement all the frameworks, tests, etc. again and reimplement (and test and verify) all related work.
No, thanks.
No, not so much. I haven't used Keras for years.
Nope...
Moved to PyTorch.
Not coming back....
Anything beyond the normal stuff is a pain.
No, I tried really hard to stick with TF. I learned the basics back when you still had to deal with a computational graph; then I found TFLearn and Keras, and my world changed.
I would still have stuck with it, but Google just doesn't care enough about TF, and I think it's a waste of my resources to learn it.
I think I have spent enough hours with PyTorch that I see no reason to go to Keras anymore. I really liked it for the very accessible documentation, nice blog posts, etc. Now I know all that stuff already in PyTorch, so it's hard to see any reason to be excited.
Yes! I know all the cool kids use PyTorch but I find it is too much boilerplate. I like keras. So this is great news.