
It's a little disingenuous to say that the 4000x speedup is due to Jax. I'm a huge Jax fanboy (one of the biggest) but the speedup here is thanks to running the simulation environment on a GPU. But as much as I love Jax, it's still extraordinarily difficult to implement even simple environments purely on a GPU.
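To make "environments purely on a GPU" concrete: the step function has to be a pure array program that Jax can stage onto the device. A toy sketch of what that looks like (mine, not from the article):

    import jax
    import jax.numpy as jnp

    # Toy 1-D point-mass environment as a pure function, so it can be
    # jit-compiled and vmapped onto the GPU alongside the agent.
    def env_step(state, action):
        pos, vel = state
        vel = vel + 0.1 * action     # apply force
        pos = pos + 0.1 * vel        # integrate position
        reward = -jnp.abs(pos)       # reward for staying near the origin
        return (pos, vel), reward

    # 4096 environments advance in lockstep with one compiled call.
    batched_step = jax.jit(jax.vmap(env_step))
    states = (jnp.zeros(4096), jnp.zeros(4096))
    states, rewards = batched_step(states, jnp.ones(4096))

Real games with branchy logic don't reduce to this kind of pure array program, which is the whole difficulty.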

My long-term ambition is to replicate OpenAI's Dota 2 reinforcement learning work, since it's one of the most impactful (or at least most entertaining) uses of RL. It would be more or less impossible to translate the game logic into Jax, short of transpiling C++ to Jax somehow. Which isn't a bad idea; someone should make that.

It should also be noted that there's a long history of RL being done on accelerators. AlphaZero's chess evaluations ran entirely on TPUs. PyTorch's CUDA graph support also makes this kind of thing easier to implement nowadays, since (again, as much as I love Jax) some PyTorch constructs are simply more convenient than forcing everything into a functional programming paradigm.

All that said, you should really try out Jax. The fact that you can calculate gradients w.r.t. an arbitrary function is just amazing, and you have complete control over what gets JIT'ed into a GPU graph and what doesn't. It's a wonderful feeling compared to PyTorch's accursed .backward() accumulation scheme.
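For anyone who hasn't tried it, the core workflow is just function transformations (a minimal sketch, mine):

    import jax
    import jax.numpy as jnp

    # Gradients of an arbitrary pure function, no tape and no .backward():
    def loss(w, x, y):
        return jnp.mean((x @ w - y) ** 2)

    grad_loss = jax.grad(loss)   # d(loss)/dw, itself a plain function

    # You decide exactly what gets fused into one compiled GPU program:
    @jax.jit
    def update(w, x, y):
        return w - 0.01 * grad_loss(w, x, y)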

Can't wait for a framework that feels closer to pure arbitrary Python. Maybe AI can figure out how to do it.



Author here! I didn't realize this got posted on HN. While we do indeed get a speedup from putting the environments on the GPU, most of it seems to come from how easily Jax lets us parallelize RL training.

There is prior work on putting RL environments on accelerators, but the main speedup here comes from also training many RL agents in parallel. That matters because the neural networks we use in RL are relatively small and on their own don't utilize the GPU very efficiently.

This was always possible in principle, but Jax makes it way easier: we just call `jax.vmap` and it works.
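A minimal sketch of the idea (mine, with a toy gradient step standing in for the real RL update):

    import jax
    import jax.numpy as jnp

    # Hypothetical single-agent update: one gradient step on a toy loss.
    # A real RL update would also be a pure function of this shape.
    def train_step(params, key):
        target = jax.random.normal(key, params.shape)
        grads = jax.grad(lambda p: jnp.sum((p - target) ** 2))(params)
        return params - 0.01 * grads

    n_agents = 256
    keys = jax.random.split(jax.random.PRNGKey(0), n_agents)
    all_params = jnp.zeros((n_agents, 64))   # one parameter vector per agent

    # One vmap turns the single-agent step into a 256-agent step, which
    # keeps the GPU busy even though each network is tiny.
    new_params = jax.jit(jax.vmap(train_step))(all_params, keys)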


AlphaZero did not run game logic on TPUs (neither for chess nor for the other games); implementing it in C++ is more than fast enough and much simpler.

TPUs were used for neural network inference and training, but the game logic as well as the MCTS ran on the CPU, in C++.

JAX is awesome though, I use it for all my neural network stuff!


According to the AlphaZero paper (https://arxiv.org/pdf/1712.01815.pdf), they did run game logic on TPUs:

> Training proceeded for 700,000 steps (mini-batches of size 4,096) starting from randomly initialised parameters, using 5,000 first-generation TPUs to generate self-play games and 64 second-generation TPUs to train the neural networks. Further details of the training procedure are provided in the Methods.



