PyTorch is developed by multiple companies and stakeholders, while JAX is Google-only, with internal tooling they don't share with the world. That alone is a major reason not to use JAX. Also, I think it's more the other way around: with torch.compile, the main advantage of JAX is disappearing.
It's the age-old question in programming: do you use a highly constrained paradigm that allows easy automatic optimization, or a very flexible, more intuitive paradigm that makes automatic optimization harder?
If the future brings better, more intelligent compilers, then that settles the question in my opinion.
> with torch.compile the main advantage of jax is disappearing.
Interesting take; I somewhat agree.
But also, wouldn't you think a framework designed from the ground up around a specific, mature compiler stack would integrate compilers more stably than a very dynamic framework with static compilers shoehorned in? ;)
Depends. PyTorch, on the other hand, has a large user base and a well-defined, well-tested API. So it should be doable, and it's already progressing at a rapid pace.
I think PyTorch is part of the puzzle, and it certainly helps that it's supported by AMD [0]. That said, there is code that needs to run closer to the metal too.