JAXでGPUを使う

2月 23, 2022

Pytorchだと.cuda()などでGPUを使うのだが、JAXの場合はpipインストール時にcudaのバージョンを指定することでGPUを使うことができる。

環境にインストールされているCUDAのバージョンが10.02の場合、以下を入力しJAXのインストールを行う。

pip install -U jax[cuda102] -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install -U jaxlib[cuda102] -f https://storage.googleapis.com/jax-releases/jax_releases.html

JAXがGPUを使っているかどうかは、以下で確かめる。

import jax
jax.default_backend()

上記のコマンドを入力後、’gpu’がかえってこれば正常に使えているということになる。

また、以下のコマンドを打つと、

jax.local_devices()

使われているGPUのidと数が以下のように把握可能。

[GpuDevice(id=0, process_index=0),
 GpuDevice(id=1, process_index=0),
 GpuDevice(id=2, process_index=0),
 GpuDevice(id=3, process_index=0)]

JAX

Posted by vastee