It's a little disingenuous to say that the 4000x speedup is due to Jax. I'm a huge Jax fanboy (one of the biggest) but the speedup here is thanks to running the simulation environment on a GPU. But as much as I love Jax, it's still extraordinarily difficult to implement even simple environments purely on a GPU.
My long-term ambition is to replicate OpenAI's Dota 2 reinforcement learning work, since it's one of the most impactful (or at least most entertaining) uses of RL. It would be more or less impossible to translate the game logic into Jax, short of transpiling C++ to Jax somehow. Which isn't a bad idea – someone should make that.
It should also be noted that there's a long history of RL being done on accelerators. AlphaZero's chess evaluations ran entirely on TPUs. PyTorch CUDA graphs also make this kind of thing easier to implement nowadays, since (again, as much as I love Jax) some PyTorch constructs are simply easier to use than rewriting everything in a functional style.
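For anyone who hasn't seen them, the capture-and-replay pattern for PyTorch CUDA graphs looks roughly like this – a minimal sketch with a toy stand-in model, not tied to any particular RL codebase:

```python
import torch

# A tiny stand-in model; any fixed-shape forward pass is captured the same way.
model = torch.nn.Linear(64, 64).cuda()
static_input = torch.zeros(8, 64, device="cuda")

# Warm up on a side stream before capture, as the CUDA graphs API requires.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a graph...
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# ...then replay it on new data by copying into the captured input buffer.
static_input.copy_(torch.randn(8, 64, device="cuda"))
g.replay()
print(static_output.shape)
```

The whole captured region replays as a single launch, which is exactly the regime where small RL networks spend most of their time on Python and kernel-launch overhead.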
All that said, you should really try out Jax. Being able to take gradients of any arbitrary function is just amazing, and you have complete control over what's JIT'ed into a GPU graph and what's not. It's a wonderful feeling compared to PyTorch's accursed .backward() accumulation scheme.
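If you haven't tried it, the whole grad-plus-jit workflow is only a few lines – a minimal sketch with made-up shapes, just to show the shape of the API:

```python
import jax
import jax.numpy as jnp

# Any plain Python function of arrays can be differentiated directly.
def loss(params, x, y):
    pred = jnp.dot(x, params["w"]) + params["b"]
    return jnp.mean((pred - y) ** 2)

# jax.grad returns a new function computing d(loss)/d(params).
grad_fn = jax.grad(loss)

# You decide explicitly what gets compiled into one fused graph.
@jax.jit
def update(params, x, y, lr=1e-2):
    grads = grad_fn(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
params = update(params, jnp.ones((8, 3)), jnp.ones((8,)))
```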
Can't wait for a framework that feels closer to pure arbitrary Python. Maybe AI can figure out how to do it.
Author here! I didn't realize this got posted on HN. While we do indeed get a speedup from putting the environments on the GPU, most of it comes from how easily Jax lets us parallelize RL training.
While there is work on putting RL environments on accelerators, the main speedup from this work comes from also training many RL agents in parallel. This is largely because the neural networks we use in RL are relatively small and thus don't utilize the GPU very efficiently.
While this was always possible to do, Jax makes it way easier because we just need to call `jax.vmap` to get it to work.
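Concretely, the pattern looks something like this – a minimal sketch with hypothetical agent and network sizes rather than our actual training code. A single-agent update becomes a many-agent update by vmapping over a leading "agent" axis in both the parameters and the data:

```python
import jax
import jax.numpy as jnp

# One agent's loss on its own batch (shapes chosen just for illustration).
def loss(params, batch):
    obs, target = batch
    pred = jnp.tanh(obs @ params["w"]) @ params["v"]
    return jnp.mean((pred - target) ** 2)

# A single-agent SGD step...
def train_step(params, batch, lr=1e-3):
    grads = jax.grad(loss)(params, batch)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# ...turned into a many-agent step by vmapping over the leading agent axis,
# then jitting the whole thing so it runs as one program on the GPU.
parallel_step = jax.jit(jax.vmap(train_step))

n_agents, obs_dim, hidden = 128, 4, 32
keys = jax.random.split(jax.random.PRNGKey(0), 2)
params = {
    "w": jax.random.normal(keys[0], (n_agents, obs_dim, hidden)),
    "v": jax.random.normal(keys[1], (n_agents, hidden)),
}
batch = (jnp.ones((n_agents, 64, obs_dim)), jnp.ones((n_agents, 64)))
params = parallel_step(params, batch)
```

Since the per-agent networks are small, stacking many of them along one axis is what finally keeps the GPU busy.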
> Training proceeded for 700,000 steps (mini-batches of size 4,096) starting from randomly initialised parameters, using 5,000 first-generation TPUs to generate self-play games and 64 second-generation TPUs to train the neural networks. Further details of the training procedure are provided in the Methods.