Patches
Plain Diff
updating notebook to use latest version of jax (otherwise it does not work with A100, and colab pro+ )