Revisiting GPU Parallelism Mental Model

Flash Attention教科書のTiled Matmul を読んで意外だったのは、これらが単一の Streaming Multiprocessor の上で実装されていることだった。Triton もそういう実装を想定し, Per-SM というか Per-ThreadBlock の実行モデルを持っている。

SRAM に tile を載せたい動機を考えると当然といえば当然だけれど、古いコンピュータ・グラフィクスの世界観で理解している GPU とは随分違って、自分のメンタルモデルを書き直す必要を感じた。

古いコンピュータグラフィクスの世界では、SM とか ThreadBlock みたいのは実装の詳細であり、プログラマは気にしない。GL のシェーダにも (compute shader は別とすると) 基本的には SRAM/SHARED みたいな概念はない。個々の WARP も独立して動き、syncthread() とかもない。

したがってグラフィクスの世界では GPU のスループットを増やせば頂点やピクセルが十分にある限り自動的に計算資源を使い切れる。GPU のスケールの仕方はあまり気にしない。SM の数が増えても SM 内の compute core / warp の数が増えても、結果はだいたい同じである。これは単純化しすぎてグラフィクスプログラマに怒られるとおもうけど、メンタルモデルとしては間違ってないと思う。

一方 Flash Attention みたいのは GPU のスループット向上が性能向上につながるとは限らない。つまり、GPU 内の SM の数が増えても単一の Flash Attention は速くならない。Attention はふつう multi-head なので head の数だけ並列化はできる。あと training ならバッチサイズを増やせば並列度は増す。inference serving も本質的に並列である。ただグラフィクスと比べると透過的でない。それなりに気にしないといけない。

個々の Flash Attention 自体を速くしようと思ったら、SM を大きくしないといけない。単一 SM 内の CUDA Core の数を増やしたり、SRAM をでかくしてタイルを広げたり、あとは Tensor Core を載せたりみたいなのもある。ただこういう細部は FLOPS だけ見てるとわからない気がする。


誰かがローカル LLM 遊びをしようと気の迷いで A100 を買ったとして、そのスループットを活用できるのだろうか。A100 には 108 も SM が入っているわけで、ローカル用途の LLM でそんな並列度あるの?(買わねーよという話は置いといて。うちのアパートとか一瞬でブレーカ落ちるわ。)

・・・と思って Lhama2 を見てみると、わー 65B で 64 heads! 64/108 半分以上は使えるのか。思ったより並列度が高かった。7B ですら 32 heads もある。ご家庭の GPU なら十分に活用できそうである。というかモデルのサイズに合わせた compute を使い切れるように調整してあるのだろうな。

A100 はさておき家庭用で買える GPU で一番奮発した型番は 4090 らしいのでスペック紹介記事を見ると・・・ 128 SM! A100 より多いじゃねーか! 64 heads だと半分しか使えないじゃん。とはいえ様々なトリックでスループットを使い切る工夫はあると思われるので, heads で半分埋まるくらいなら案外丁度いいのかもしれない。全力で回すと画面が固まりそうだし。GPU だけに。

結論としては、Triton のような SM-based programming の世界はグラフィクスと比べると GPU が transparently scalable ではないが、AI 人材はご家庭の GPU を使い切れるよう、おおむねきちんとモデルのサイズを調整していた。よかったね。