LLMの実行・学習に必要なVRAMの見積もり方
結論を先に。LLMのVRAMは、まず「モデル重み = パラメータ数 × 1パラメータあたりのバイト数」で概算します。FP16(2バイト)なら7Bで約14GB、8Bで約16GBです。推論ではこれにKVキャッシュと実行時オーバーヘッドを、学習ではさらに勾配とオプティマイザ状態を加えます。
モデル重み(GB) ≒ パラメータ数(B) × バイト数 FP32=4 / FP16・BF16=2 / INT8=1 / INT4=0.5例) 8B を FP16 → 8 × 2 = 16GB数字を手で追うのは面倒なので、モデルと精度を選ぶと推論・学習のVRAMとGPUの可否を出すLLM VRAM計算機を作りました。
推論に必要なVRAM
推論では次の3つを足します。
- モデル重み。パラメータ数×バイト数。量子化(INT8/INT4)で大きく減らせます。
- KVキャッシュ。コンテキスト長とバッチに比例して増えます。
KVキャッシュ(GB) ≒ 2 × レイヤー数 × KV次元 × コンテキスト長 × バッチ × 2 ÷ 10^9KV次元はGQA(グループ化クエリ注意)のモデルではhiddenより小さくなります。たとえばLlama 3 8Bは1024で、hiddenの4096より小さいぶんKVキャッシュも軽くなります。
- 実行時オーバーヘッド。CUDAコンテキストや一時バッファで、重み+KVの15%前後を見ておきます。
Llama 3 8B を FP16・4kトークンで動かすなら、16 + 約0.5 + 約2.5 で合計およそ19GBです。24GBのGPU(RTX 3090/4090)に載る計算になります。
学習に必要なVRAM
フルファインチューンは重み以外も持つため、一気に増えます。混合精度+Adamの目安は次の通りです。
重み(FP16 2) + 勾配(FP16 2) + Adam状態(FP32×2 = 8) + FP32マスタ重み(4)≒ 1パラメータあたり16バイト + 活性化8Bのフルファインチューンは、活性化まで含めると150GB前後になり、単一GPUでは厳しくなります。
一方でLoRA/QLoRAは、元の重みを凍結して小さなアダプタだけ学習するため大幅に軽くなります。QLoRA(INT4)なら8Bでも数GB台に収まり、24GBどころか8〜12GBのGPUでも狙えます。
つまずきやすい点
- 量子化は推論のメモリを減らしますが、精度は多少落ちます。用途に応じて選びます。
- コンテキストを長くするとKVキャッシュが線形に増えます。長文処理はメモリを食います。
- ここでの値は概算です。活性化メモリは実装や設定で前後します。余裕を持って見積もります。
まとめ
- 重み = パラメータ数 × バイト数(FP16なら×2)
- 推論 = 重み + KVキャッシュ + オーバーヘッド。GQAはKVが軽い
- Llama 3 8B FP16推論は約19GBで24GB GPUに載る
- フル学習はAdamで1パラメータ16バイト級、活性化込みで巨大
- LoRA/QLoRAは重み凍結で数GB台まで軽くなる
- 概算はLLM VRAM計算機で
- 関連:JavaScriptで機械学習アルゴリズムを実装する