Out-of-memory errors when launching JAX
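Neptune uses the "spawn" start method to launch multiprocessing workers. This can sometimes cause out-of-memory (OOM) errors when JAX is initialized during import: a spawned worker starts a fresh Python interpreter and re-imports the main module, so a module-level JAX import runs again in every worker.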
To resolve this, import JAX only after the child process has been created. You can do this in one of two ways:
(Preferred) Inside the worker function.
After the Neptune run has been initialized.
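A minimal sketch of the preferred approach, assuming your JAX code runs in a worker function of your own that is launched with Python's standard multiprocessing module. The names train_worker and run_workers are hypothetical and the computation is a placeholder. The point is that import jax appears only inside the worker, so neither the parent process nor a spawned child that re-imports the main module imports JAX at module level.

```python
import multiprocessing as mp


def train_worker(seed: int) -> float:
    """Hypothetical worker: JAX is imported only inside this function."""
    # Deferred import: a spawned child re-imports the main module, so a
    # top-level `import jax` would run in every child. Importing here keeps
    # JAX out of every process that does not actually call this function.
    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(seed)
    x = jax.random.normal(key, (1024,))
    return float(jnp.mean(x))


def run_workers(seeds: list[int]) -> list[float]:
    """Hypothetical launcher: one spawned worker per seed."""
    # Request the "spawn" start method explicitly, the same method
    # Neptune uses for its own workers.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=len(seeds)) as pool:
        return pool.map(train_worker, seeds)
```

Call run_workers([0, 1]) only from under an if __name__ == "__main__": guard. With the spawn start method, the main module must be importable by a new interpreter without side effects such as starting new processes.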
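A sketch of the second approach, importing JAX only after the Neptune run exists. It assumes the neptune-scale client (from neptune_scale import Run) with the project and API token supplied through the environment; check the import and the Run(...) arguments against your Neptune client version. The experiment name and the logged metric are hypothetical placeholders.

```python
def main() -> None:
    # Assumption: the neptune-scale client. Adjust the import and the
    # Run(...) arguments to match your Neptune client version.
    from neptune_scale import Run

    # 1. Initialize the Neptune run first, assuming this is when Neptune
    #    starts its spawned multiprocessing workers.
    run = Run(experiment_name="jax-oom-example")  # hypothetical name

    # 2. Import JAX only now, inside a function, so the module-level
    #    re-import performed by spawned children never reaches it.
    import jax.numpy as jnp

    x = jnp.linspace(0.0, 1.0, 101)
    run.log_metrics(data={"mean": float(jnp.mean(x))}, step=0)
    run.close()
```

As with any spawn-based code, call main() from under an if __name__ == "__main__": guard.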