I recently tried to port my model to JAX. I got it all working the "JAX way", and I believe I did everything correctly, with one neat top-level jax.jit() applied to the training step. Unfortunately, I could not replicate the performance boost of torch.compile(). I have not yet delved under the hood to find the culprit, but my model is fairly simple, so I was sort of expecting JAX's JIT to perform just as well as, if not better than, torch.compile().
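For reference, the training step is essentially a single jitted function along these lines (a simplified sketch; the real model and optimizer are omitted, and the loss/parameter names here are just placeholders):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Stand-in for the actual model: a single linear layer with MSE loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=1e-3):
    # Compute loss and gradients, then take one SGD step, all inside jit.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss
```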
JAX code usually ends up being way faster than equivalent torch code for me, even with torch.compile. There are common performance killers, though, notably using Python control flow (if statements, loops) instead of jax.lax primitives (where, cond, scan, etc.); see the sketch below.
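A toy illustration of what I mean (not from your code, just the general pattern): data-dependent Python branches and loops either fail at trace time or get unrolled into a huge program, while the lax primitives stay inside the compiled graph.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def where_example(x):
    # Instead of `x if cond else y` on traced values (which errors at trace
    # time), element-wise selection stays inside the compiled graph:
    return jnp.where(x > 0, x, 0.0)

@jax.jit
def cond_example(x, flag):
    # Instead of a Python `if flag:` on a traced boolean, lax.cond selects a
    # branch inside the compiled function:
    return lax.cond(flag, lambda v: v * 2.0, lambda v: v / 2.0, x)

@jax.jit
def scan_example(xs):
    # Instead of a Python for-loop (which gets unrolled at trace time and
    # bloats the compiled program), lax.scan keeps the loop on-device:
    def body(carry, x):
        return carry + x, None
    total, _ = lax.scan(body, 0.0, xs)
    return total
```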
Interesting. Thanks for your input. I tried to adhere to the JAX paradigm as laid out in the documentation, so I already have a fully static graph.
I would check how much of the total FLOP capability of your hardware you are actually using. Take the first-order terms of your model and estimate how many FLOPs you need per data point (a good guide is 6 FLOPs per parameter for training if you mostly have large matrix multiplies and nonlinearity/norm layers), then compare the achieved throughput for a given input size against the theoretical peak of the GPU (e.g. ~1e15 FLOP/s for bfloat16 on an H100 or H200). If you are already over 50%, it is unlikely you can get big gains without very considerable effort, and most likely plain JAX or PyTorch is not sufficient at that point. If you are in the 2–20% range, there is probably some low-hanging fruit left, and the closer you are to only 1%, the easier it is to see dramatic gains. A back-of-the-envelope version of the calculation is sketched below.
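Something like this (the 6-FLOPs-per-parameter rule and the ~1e15 FLOP/s bf16 peak are from the comment above; the concrete model size, batch shape, and step time are placeholders to substitute with your own numbers):

```python
# Rough FLOP-utilization estimate. All concrete values below are placeholders.
n_params = 125e6            # model parameters
tokens_per_step = 8 * 2048  # data points (e.g. batch_size * sequence_length) per step
step_time_s = 0.05          # measured wall-clock time per training step

flops_per_step = 6 * n_params * tokens_per_step   # ~forward + backward pass
achieved_flops = flops_per_step / step_time_s

peak_flops = 1e15           # ~bfloat16 peak for one H100/H200 (from the comment above)
utilization = achieved_flops / peak_flops
print(f"Achieved {achieved_flops:.3e} FLOP/s, {utilization:.1%} of peak")
```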
Has anyone else had similar experiences?