JAXでGPUを使う
Pytorchだと.cuda()などでGPUを使うのだが、JAXの場合はpipインストール時にcudaのバージョンを指定することでGPUを使うことができる。
環境にインストールされているCUDAのバージョンが10.02の場 ...
画像処理や自然言語処理などのハマりどころをまとめます
Pytorchだと.cuda()などでGPUを使うのだが、JAXの場合はpipインストール時にcudaのバージョンを指定することでGPUを使うことができる。
環境にインストールされているCUDAのバージョンが10.02の場 ...