
On the PyTorch performance drop with the RTX 5090


While training on an RTX 5090, I noticed that training was slower than I expected, so I investigated the problem. As of 2025/04/22, it has not been fixed.

Environment

  • RTX 5090
    • NVIDIA Open Driver 570.133.20
      The RTX 50 series (Blackwell) is supported only by the open-source driver.

      For cutting-edge platforms such as NVIDIA Grace Hopper or NVIDIA Blackwell, you must use the open-source GPU kernel modules. The proprietary drivers are unsupported on these platforms.[1]

    • CUDA 12.8
    • PyTorch 2.8.0.dev20250419+cu128
  • RTX 4090
    • NVIDIA Proprietary Driver 570.124.06
    • CUDA 12.6
    • PyTorch 2.6.0

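Researching the problem

There are community-level reports about RTX 5090 performance, but none of them contained information that gets to the core of the problem.

However, a GitHub issue contains a detailed analysis, and the slowdown appears to be a defect specific to the RTX 5090. The link below is a report for the 5090D. (Later comments report that the 5090 behaves the same way.)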

https://github.com/pytorch/pytorch/issues/150725

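Based on the investigation by PyTorch collaborator eqy, the reporter mobulan, and a dedicated community, the RTX 5090(D) may be hitting a cuBLAS problem, or the appropriate CUDA kernels may not be being invoked.

Verifying in my own environment

I will run the code from mobulan's original report, shown below. Because I want to see the GPU's behavior in detail, I run it with the nsys nvprof python3 command.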

import time

import torch
from torch import nn

linear = nn.Linear(768, 768).cuda()
x = torch.randn(256, 196, 768).cuda()

torch.cuda.synchronize()
t = time.time()
for i in range(2000):
    y = linear(x)
torch.cuda.synchronize()
print(f"Linear Time:{time.time() - t:.3f}")
print()

x = torch.randn(256, 196, 768).cuda()
weight = nn.Parameter(torch.randn(768, 768).cuda())
bias = nn.Parameter(torch.randn(768).cuda())


torch.cuda.synchronize()
t = time.time()
for i in range(2000):
    # y = F.linear(x, weight, bias)
    y = x @ weight.t() + bias
torch.cuda.synchronize()
print(f"Manu Time:{time.time() - t:.3f}")
print()
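
The script above times the loop from the very first call, so one-time costs such as library loading and kernel selection are included in the result. As a minimal sketch of an alternative, the helper below adds a few warm-up calls before timing. The function name `bench`, the warm-up count, and the smaller batch size are my own assumptions for illustration and are not part of mobulan's report; the code falls back to the CPU when no GPU is available.

```python
import time

import torch
from torch import nn


def bench(fn, iters=2000, warmup=10):
    """Return the average wall-clock seconds per call of fn().

    Hypothetical helper: runs a few untimed warm-up calls first, then
    synchronizes around the timed loop when CUDA is available.
    """
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for warm-up kernels to finish
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for all timed kernels to finish
    return (time.perf_counter() - start) / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
linear = nn.Linear(768, 768).to(device)
# Smaller batch than the original (256) so the sketch also finishes quickly on a CPU.
x = torch.randn(8, 196, 768, device=device)

with torch.no_grad():
    per_call = bench(lambda: linear(x), iters=20, warmup=3)
print(f"Linear: {per_call * 1e3:.3f} ms per call on {device}")
```

To compare against the numbers in this article, the batch size and iteration count would need to be set back to 256 and 2000.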

RTX 4090

For comparison, let's run it on the RTX 4090 first.

Linear Time:5.295

Manu Time:5.188

Generating '/tmp/nsys-report-6f03.qdstrm'
[1/7] [========================100%] report1.nsys-rep
[2/7] [========================100%] report1.sqlite
[3/7] Executing 'nvtx_sum' stats report
[4/7] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)       Med (ns)     Min (ns)     Max (ns)      StdDev (ns)                       Name                     
 --------  ---------------  ---------  -------------  -------------  ---------  -------------  -------------  ---------------------------------------------
     70.0    6,855,494,220      6,000    1,142,582.4    1,301,944.0      1,750    165,079,631    2,397,710.8  cudaLaunchKernel                             
     24.0    2,354,640,648          4  588,660,162.0  559,632,535.0      4,020  1,235,371,558  681,371,063.2  cudaDeviceSynchronize                        
      4.9      480,939,351          2  240,469,675.5  240,469,675.5  2,314,718    478,624,633  336,801,970.8  cudaFree                                     
      0.3       33,178,621         10    3,317,862.1      204,349.5      3,900     31,286,899    9,828,568.6  cudaMalloc                                   
      0.3       27,321,950      6,001        4,552.9          360.0        190     24,774,190      319,801.4  cudaOccupancyMaxActiveBlocksPerMultiprocessor
      0.3       26,629,633          6    4,438,272.2      321,299.0      5,040     13,241,826    6,621,531.5  cudaMemcpyAsync                              
      0.1        9,633,475      2,000        4,816.7        3,184.5      2,040      2,216,309       49,693.6  cudaMemsetAsync                              
      0.0        2,688,716          6      448,119.3       90,554.5     22,859      2,131,910      828,985.2  cudaStreamSynchronize                        
      0.0          374,387        768          487.5          285.0         90         76,659        3,561.0  cuGetProcAddress                             
      0.0          326,749         18       18,152.7          345.0        270        320,729       75,513.1  cudaEventCreateWithFlags                     
      0.0           39,160          7        5,594.3        4,780.0      2,050         13,770        3,960.6  cudaStreamIsCapturing_v10000                 
      0.0            2,920          2        1,460.0        1,460.0      1,400          1,520           84.9  cuInit                                       
      0.0            2,380          3          793.3          180.0        110          2,090        1,123.5  cuModuleGetLoadingMode                       

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------------------------------------------------
     90.0    8,216,303,474      4,000  2,054,075.9  1,720,608.5  1,311,917  4,470,187    868,423.0  ampere_sgemm_128x64_tn                                                                              
     10.0      908,172,197      2,000    454,086.1    321,316.0    289,859  2,999,101    456,039.6  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…

[6/7] Executing 'cuda_gpu_mem_time_sum' stats report

 Time (%)  Total Time (ns)  Count   Avg (ns)    Med (ns)   Min (ns)   Max (ns)   StdDev (ns)           Operation          
 --------  ---------------  -----  -----------  ---------  --------  ----------  -----------  ----------------------------
     90.1       26,442,500      6  4,407,083.3  248,338.5       672  13,248,866  6,635,798.1  [CUDA memcpy Host-to-Device]
      9.9        2,915,253  2,000      1,457.6    1,471.0     1,184       3,968         69.9  [CUDA memset]               

[7/7] Executing 'cuda_gpu_mem_size_sum' stats report

 Total (MB)  Count  Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)           Operation          
 ----------  -----  --------  --------  --------  --------  -----------  ----------------------------
    313.006      6    52.168     2.359     0.003   154.141       78.995  [CUDA memcpy Host-to-Device]
     37.632  2,000     0.019     0.019     0.019     0.019        0.000  [CUDA memset]               

Generated:
    report1.nsys-rep
    report1.sqlite

RTX 5090

Linear Time:7.446

Manu Time:2.749

Collecting data...
Generating '/tmp/nsys-report-e9e6.qdstrm'
[1/7] [========================100%] report4.nsys-rep
[2/7] [========================100%] report4.sqlite
[3/7] Executing 'nvtx_sum' stats report
[4/7] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)       Med (ns)     Min (ns)    Max (ns)     StdDev (ns)                 Name               
 --------  ---------------  ---------  -------------  -------------  ---------  -----------  -------------  ---------------------------------
     51.6    5,225,494,321      2,000    2,612,747.2    2,998,808.0      2,060    3,323,507    1,005,219.4  cudaMemsetAsync                  
     17.1    1,730,084,634      2,000      865,042.3    1,165,948.0      1,730    1,177,938      510,434.4  cuLaunchKernel                   
     16.2    1,645,276,602          4  411,319,150.5  349,377,738.0      3,020  946,518,106  485,597,246.1  cudaDeviceSynchronize            
     14.7    1,485,287,373      8,000      185,660.9       83,038.0      1,700   13,508,433      274,965.8  cudaLaunchKernel                 
      0.2       20,910,374          6    3,485,062.3      149,397.5      5,620   10,872,513    5,291,753.6  cudaMemcpyAsync                  
      0.2       15,231,832          4    3,807,958.0    4,709,806.0    835,934    4,976,286    1,985,846.2  cuLibraryLoadData                
      0.0        2,708,429          1    2,708,429.0    2,708,429.0  2,708,429    2,708,429            0.0  cudaGetDeviceProperties_v2_v12000
      0.0        1,245,415      2,000          622.7          170.0        130      908,453       20,309.9  cuKernelGetFunction              
      0.0        1,086,730          1    1,086,730.0    1,086,730.0  1,086,730    1,086,730            0.0  cudaFree                         
      0.0        1,051,870         10      105,187.0       98,018.5      1,530      209,146       70,021.5  cudaMalloc                       
      0.0          127,650        838          152.3          130.0         50          630          100.9  cuGetProcAddress_v2              
      0.0          119,548          6       19,924.7       21,600.0      9,180       30,969        9,029.1  cudaStreamSynchronize            
      0.0           17,110          7        2,444.3        1,300.0        630        5,790        2,018.4  cudaStreamIsCapturing_v10000     
      0.0           11,120          5        2,224.0        1,830.0        480        4,140        1,801.6  cuLibraryGetKernel               
      0.0           10,140         18          563.3          160.0        140        4,880        1,148.3  cudaEventCreateWithFlags         
      0.0            4,200          3        1,400.0        1,080.0        790        2,330          818.4  cuInit                           
      0.0            1,390          3          463.3          130.0         70        1,190          630.0  cuModuleGetLoadingMode           
      0.0              810          2          405.0          405.0        110          700          417.2  cudaGetDriverEntryPoint_v11030   

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------------------------------------------------
     59.6    6,023,073,030      2,000  3,011,536.5  3,010,899.0  3,004,370  3,333,169      7,971.7  void sgemm_largek_lds64<(bool)1, (bool)0, (int)5, (int)5, (int)4, (int)4, (int)4, (int)34>(float *,…
     23.2    2,349,767,702      2,000  1,174,883.9  1,174,907.0  1,169,179  1,200,219      2,173.0  void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_8x4_tn_align1>(T1::Params)                      
     11.8    1,192,299,317      2,000    596,149.7    596,029.0    578,302    829,884      7,174.9  void cublasLt::epilogue::impl::globalKernel<(int)8, (int)32, float, float, float, (bool)1, (bool)1,…
      3.7      378,948,315      2,000    189,474.2    189,439.0    187,103    193,695        801.5  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      1.6      163,026,527      2,000     81,513.3     81,408.0     68,320     84,895        736.5  void scal_kernel<float, float, (int)1, (bool)1, (int)6, (int)5, (int)5, (int)3>(cublasTransposePara…

[6/7] Executing 'cuda_gpu_mem_time_sum' stats report

 Time (%)  Total Time (ns)  Count   Avg (ns)    Med (ns)   Min (ns)   Max (ns)   StdDev (ns)           Operation          
 --------  ---------------  -----  -----------  ---------  --------  ----------  -----------  ----------------------------
     87.3       20,656,357      6  3,442,726.2  105,696.0       544  10,781,266  5,263,203.2  [CUDA memcpy Host-to-Device]
     12.7        2,997,299  2,000      1,498.6    1,440.0     1,216       4,960        412.7  [CUDA memset]               

[7/7] Executing 'cuda_gpu_mem_size_sum' stats report

 Total (MB)  Count  Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)           Operation          
 ----------  -----  --------  --------  --------  --------  -----------  ----------------------------
    602.112  2,000     0.301     0.301     0.301     0.301        0.000  [CUDA memset]               
    313.006      6    52.168     2.359     0.003   154.141       78.995  [CUDA memcpy Host-to-Device]

Generated:
	report4.nsys-rep
	report4.sqlite
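
The kernel summary above can be cross-checked against the wall-clock times. The sketch below sums the per-kernel totals from the RTX 5090 report under an assumed grouping: which kernel belongs to which loop is my own inference from the instance counts and timings, not something the report states.

```python
# Total GPU time per kernel (ns), copied from the RTX 5090 cuda_gpu_kern_sum report.
kernel_total_ns = {
    "sgemm_largek_lds64": 6_023_073_030,
    "cutlass_80_simt_sgemm_256x128_8x4_tn_align1": 2_349_767_702,
    "cublasLt_epilogue_globalKernel": 1_192_299_317,
    "elementwise_kernel": 378_948_315,
    "scal_kernel": 163_026_527,
}

# ASSUMPTION: grouping of kernels per loop is inferred, not confirmed by the report.
linear_kernels = ["sgemm_largek_lds64", "cublasLt_epilogue_globalKernel", "scal_kernel"]
manual_kernels = ["cutlass_80_simt_sgemm_256x128_8x4_tn_align1", "elementwise_kernel"]

linear_s = sum(kernel_total_ns[k] for k in linear_kernels) / 1e9
manual_s = sum(kernel_total_ns[k] for k in manual_kernels) / 1e9

# Wall-clock times printed by the script on the RTX 5090.
print(f"linear kernels: {linear_s:.3f} s (Linear Time: 7.446 s)")
print(f"manual kernels: {manual_s:.3f} s (Manu Time:   2.749 s)")
```

Under this grouping the sums land close to the measured times, which is consistent with y = linear(x) being routed to the slow sgemm_largek_lds64 kernel while the manual matmul uses the CUTLASS kernel.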

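GPU              RTX 4090  RTX 5090
---------------  --------  --------
Linear Time (s)     5.295     7.446
Manu Time (s)       5.188     2.749

Manu Time, which measures y = x @ weight.t() + bias, is about 1.89 times faster on the RTX 5090 than on the RTX 4090, which is fast enough. Linear Time, which measures y = linear(x), is however about 1.4 times slower on the RTX 5090...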

A limited workaround

eqy pointed out that the kernel selection is inappropriate, and it turned out that enabling TF32 causes an appropriate kernel to be used.

OK, glad you can see some improvement but I still cannot reproduce the non-TF32 kernel selection for some reason :/
In case you are OK with using TF32 (note that precision can suffer), you can set this in the script with torch.backends.cuda.matmul.allow_tf32 = True rather than using an environment variable. Note that there would also be a similar speed increase if you chose to do this on 4090.

https://github.com/pytorch/pytorch/issues/150725#issuecomment-2798199049

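TF32 (TensorFloat-32) is a data type introduced by NVIDIA with the Ampere generation. According to NVIDIA, the loss of precision is small. eqy also suggests using TF32 if the precision is acceptable.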
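
To give a feel for the precision involved, here is a minimal sketch that emulates TF32 input precision in plain Python. TF32 keeps the 8-bit exponent of FP32 but only 10 of its 23 mantissa bits. The helper `to_tf32` is hypothetical, and it truncates the low bits for simplicity, which is an assumption; real hardware may round differently.

```python
import struct


def to_tf32(value: float) -> float:
    """Reduce a number to TF32-like precision (8-bit exponent, 10-bit mantissa).

    Hypothetical helper: converts to FP32, then clears the low 13 mantissa bits.
    Truncation is an assumption; actual hardware rounding may differ.
    """
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    bits &= 0xFFFFE000  # keep sign, exponent, and the top 10 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for v in (1.0, 3.14159265, 1234.5678):
    fp32 = struct.unpack("<f", struct.pack("<f", v))[0]  # value as stored in FP32
    tf32 = to_tf32(v)
    rel_err = abs(tf32 - fp32) / abs(fp32)
    print(f"fp32={fp32!r} tf32={tf32!r} relative error={rel_err:.2e}")
```

With truncation, the relative error of each input stays below 2**-10 (roughly 0.1%), while accumulation in a TF32 matmul is still done in FP32. Whether that is acceptable depends on the workload.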

How to enable it

  • Set an environment variable, as in TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=1 python3 hoge.py.
  • Add torch.backends.cuda.matmul.allow_tf32 = True to the Python code.
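
If TF32 should apply only to part of a program, the flag from the second bullet can be toggled temporarily. The following is a minimal sketch; the context manager `tf32_matmul` is a hypothetical helper of my own, not a PyTorch API, and it only wraps the `torch.backends.cuda.matmul.allow_tf32` flag mentioned above.

```python
import contextlib

import torch


@contextlib.contextmanager
def tf32_matmul(enabled=True):
    """Temporarily set torch.backends.cuda.matmul.allow_tf32 (hypothetical helper)."""
    previous = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = enabled
    try:
        yield
    finally:
        # Restore whatever the flag was before entering the block.
        torch.backends.cuda.matmul.allow_tf32 = previous


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 768, device=device)
weight = torch.randn(768, 768, device=device)

with tf32_matmul(True):
    inside = torch.backends.cuda.matmul.allow_tf32
    y = x @ weight.t()  # uses TF32 kernels only on GPUs that support them

print("allow_tf32 inside the block:", inside)
print("output shape:", tuple(y.shape))
```

On a machine without a CUDA GPU the flag can still be set, but it has no effect on the CPU matmul.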

Measuring with TF32

RTX 4090

As before, I start with the RTX 4090.

Linear Time:1.583

Manu Time:2.161

Generating '/tmp/nsys-report-c509.qdstrm'
[1/7] [========================100%] report2.nsys-rep
[2/7] [========================100%] report2.sqlite
[3/7] Executing 'nvtx_sum' stats report
[4/7] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)       Med (ns)     Min (ns)    Max (ns)     StdDev (ns)                Name
 --------  ---------------  ---------  -------------  -------------  ---------  -----------  -------------  ------------------------------
     50.4    1,856,046,221      4,000      464,011.6      743,186.5      1,780      809,661      369,929.6  cuLaunchKernel
     36.0    1,323,851,832          4  330,962,958.0  272,609,611.5      2,830  778,629,779  393,859,024.5  cudaDeviceSynchronize
     12.7      467,790,132      2,000      233,895.1      308,531.0      1,890    9,206,196      241,596.2  cudaLaunchKernel
      0.5       16,795,549          6    2,799,258.2      126,505.0      5,410    8,601,755    4,239,930.0  cudaMemcpyAsync
      0.3       10,323,117          4    2,580,779.3    3,196,510.0    524,941    3,405,156    1,374,991.1  cuLibraryLoadData
      0.1        4,023,897      4,000        1,006.0          190.0        150    3,242,925       51,272.1  cuKernelGetFunction
      0.0        1,084,922          1    1,084,922.0    1,084,922.0  1,084,922    1,084,922            0.0  cudaFree
      0.0          947,550         10       94,755.0       73,970.0      2,110      190,610       65,543.8  cudaMalloc
      0.0          177,830          6       29,638.3       38,225.0      8,090       44,600       16,567.8  cudaStreamSynchronize
      0.0          153,990      2,000           77.0           80.0         60          990           24.2  cudaMemsetAsync
      0.0          133,220        810          164.5          150.0         50          640          107.8  cuGetProcAddress_v2
      0.0           17,510          7        2,501.4        1,330.0        800        6,380        2,032.2  cudaStreamIsCapturing_v10000
      0.0            7,190         18          399.4          150.0        130        1,850          560.1  cudaEventCreateWithFlags
      0.0            5,030          5        1,006.0          560.0        300        3,230        1,251.3  cuLibraryGetKernel
      0.0            2,950          3          983.3          670.0        600        1,680          604.3  cuInit
      0.0              760          3          253.3          130.0        120          510          222.3  cuModuleGetLoadingMode
      0.0              470          2          235.0          235.0        120          350          162.6  cudaGetDriverEntryPoint_v11030

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)                                                  Name
 --------  ---------------  ---------  ---------  ---------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     83.0    3,063,579,069      4,000  765,894.8  759,690.0   747,626   821,675     19,292.5  void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_128x256_16x3_tn_align4>(T1::Params)
     17.0      626,660,493      2,000  313,330.2  313,252.0   310,884   328,260        957.1  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…

[6/7] Executing 'cuda_gpu_mem_time_sum' stats report

 Time (%)  Total Time (ns)  Count   Avg (ns)    Med (ns)   Min (ns)  Max (ns)   StdDev (ns)           Operation
 --------  ---------------  -----  -----------  ---------  --------  ---------  -----------  ----------------------------
    100.0       16,635,447      6  2,772,574.5  103,361.5       576  8,549,710  4,220,313.1  [CUDA memcpy Host-to-Device]

[7/7] Executing 'cuda_gpu_mem_size_sum' stats report

 Total (MB)  Count  Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)           Operation
 ----------  -----  --------  --------  --------  --------  -----------  ----------------------------
    313.006      6    52.168     2.359     0.003   154.141       78.995  [CUDA memcpy Host-to-Device]

Generated:
    report2.nsys-rep
    report2.sqlite

The kernel in use changes, and with TF32 the RTX 4090 is faster than it was with FP32.

GPU              RTX 4090 FP32  RTX 4090 TF32
---------------  -------------  -------------
Linear Time (s)          5.295          1.583
Manu Time (s)            5.188          2.161

RTX 5090

Linear Time:1.337

Manu Time:1.667

Collecting data...
Generating '/tmp/nsys-report-498b.qdstrm'
[1/7] [========================100%] report2.nsys-rep
[2/7] [========================100%] report2.sqlite
[3/7] Executing 'nvtx_sum' stats report
[4/7] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls   Avg (ns)     Med (ns)    Min (ns)  Max (ns)   StdDev (ns)                Name               
 --------  ---------------  ---------  -----------  -----------  --------  ---------  -----------  ---------------------------------
     51.9       1526711325       4000     381677.8     623785.0      1690     643569     304898.1  cuLaunchKernel                   
     36.7       1077921895          4  269480473.8  211326728.5      3770  655264668  325332923.2  cudaDeviceSynchronize            
      9.8        289095785       2000     144547.9     185527.0      1910   12893603     296443.3  cudaLaunchKernel                 
      0.7         19582693          6    3263782.2     141142.5      6030    9960229    4946413.6  cudaMemcpyAsync                  
      0.5         15574760          4    3893690.0    4837992.5    885736    5013039    2007207.4  cuLibraryLoadData                
      0.2          5655355       4000       1413.8        170.0       150    4915411      77717.5  cuKernelGetFunction              
      0.1          2424561          1    2424561.0    2424561.0   2424561    2424561          0.0  cudaGetDeviceProperties_v2_v12000
      0.0          1103053          1    1103053.0    1103053.0   1103053    1103053          0.0  cudaFree                         
      0.0          1008262         10     100826.2      77018.5      1780     211936      71809.9  cudaMalloc                       
      0.0           173449       2000         86.7         80.0        69       1160         28.4  cudaMemsetAsync                  
      0.0           145077        838        173.1        160.0        50        770        116.1  cuGetProcAddress_v2              
      0.0            93448          6      15574.7      19030.0      3230      21740       7403.1  cudaStreamSynchronize            
      0.0            17060          7       2437.1       1210.0       590       6060       2175.5  cudaStreamIsCapturing_v10000     
      0.0            15830          6       2638.3       2845.0       650       4360       1641.1  cuLibraryGetKernel               
      0.0            10510         18        583.9        165.0       140       4730       1132.2  cudaEventCreateWithFlags         
      0.0             4770          4       1192.5        955.0       510       2350        822.4  cuInit                           
      0.0             1230          3        410.0        130.0        70       1030        537.8  cuModuleGetLoadingMode           
      0.0              580          2        290.0        290.0       180        400        155.6  cudaGetDriverEntryPoint_v11030   

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     43.7       1278779479       2000  639389.7  639039.0    637535    651615       1715.9  void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_128x128_16x5_tn_align4>(T1::Params)             
     43.4       1268932877       2000  634466.4  633279.0    630015    644639       3071.2  void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_128x256_16x3_tn_align4>(T1::Params)             
     12.9        378866711       2000  189433.4  189408.0    186399    194367        902.8  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…

[6/7] Executing 'cuda_gpu_mem_time_sum' stats report

 Time (%)  Total Time (ns)  Count  Avg (ns)   Med (ns)  Min (ns)  Max (ns)  StdDev (ns)           Operation          
 --------  ---------------  -----  ---------  --------  --------  --------  -----------  ----------------------------
    100.0         19346391      6  3224398.5  100624.0       512   9882347    4920908.4  [CUDA memcpy Host-to-Device]

[7/7] Executing 'cuda_gpu_mem_size_sum' stats report

 Total (MB)  Count  Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)           Operation          
 ----------  -----  --------  --------  --------  --------  -----------  ----------------------------
    313.006      6    52.168     2.359     0.003   154.141       78.995  [CUDA memcpy Host-to-Device]

Generated:
	report2.nsys-rep
	report2.sqlite

The selected kernels have changed, and the results show that it is faster.

GPU              RTX 4090  RTX 5090  RTX 4090 TF32  RTX 5090 TF32
---------------  --------  --------  -------------  -------------
Linear Time (s)     5.295     7.446          1.583          1.337
Manu Time (s)       5.188     2.749          2.161          1.667
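
The ratios behind this table can be worked out directly. The short sketch below uses only the timings measured above; the dictionary layout and variable names are my own.

```python
# Wall-clock seconds for 2,000 iterations, copied from the table above.
times = {
    "Linear": {"4090 FP32": 5.295, "5090 FP32": 7.446, "4090 TF32": 1.583, "5090 TF32": 1.337},
    "Manu": {"4090 FP32": 5.188, "5090 FP32": 2.749, "4090 TF32": 2.161, "5090 TF32": 1.667},
}

speedups = {}
for name, t in times.items():
    speedups[name] = {
        # How many times faster the RTX 5090 is than the RTX 4090 (below 1 means slower).
        "5090 vs 4090 (FP32)": t["4090 FP32"] / t["5090 FP32"],
        "5090 vs 4090 (TF32)": t["4090 TF32"] / t["5090 TF32"],
        # How much TF32 helps on the RTX 5090 itself.
        "TF32 vs FP32 (5090)": t["5090 FP32"] / t["5090 TF32"],
    }

for name, ratios in speedups.items():
    for label, ratio in ratios.items():
        print(f"{name:6s} {label}: {ratio:.2f}x")
```

With FP32 the RTX 5090 reaches only about 0.71x of the RTX 4090 on the nn.Linear path, while with TF32 it is ahead on both paths.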
Footnotes
  1. https://developer.nvidia.com/blog/nvidia-transitions-fully-towards-open-source-gpu-kernel-modules/
