👻

NVIDIA NeMoを利用したGPT-OSSの学習

に公開

はじめに

Turing CTO室に所属している東京科学大学(Institute of Science Tokyo)の藤井です。
本記事では、OpenAIから2025年8月にリリースされたgpt-ossNVIDIA NeMoフレームワークにて学習するための方法について解説します。

2025年11月4日時点では、NVIDIA公式からは、LoRA finetunigを行う方法についてのみ解説されており、Long Context継続事前学習(Continual Pre-Training)など本格的な学習を行うにはハードルが多数あります。
本記事では、学習を行うために解決する必要があるすべての問題に関して、詳細な解決方法を記しました。gpt-ossを利用したモデル学習にお役立てください。

gpt-oss

About

gpt-ossとは、OpenAIよりリリースされたLLMであり、gpt-oss-20bgpt-oss-120bの2つのモデルサイズがあります。いずれのモデルも以下のように高い言語処理能力を英語では示しています。


Artificial Analysisより

しかしながら、日本に関する知識や日本語能力については限定されており、改善の余地が存在します。

モデルアーキテクチャ

gpt-ossのモデルアーキテクチャには特筆するべき点がいくつかあります。
昨今のオープンLLMで採用されているアーキテクチャとは異なる点が多く、それにより以下で述べるように学習を行う上でのハードルが上がっています。

  1. bias項の存在: Llama-2以降、多くのOpenLLMではMLP, Attentionともにbias項がないのが一般的でした。しかし、gpt-ossでは、GPT-2の時代と同様にbias項が存在しています。
  2. QK Normの欠如: Qwen3にも導入されているように昨今のLLMでは学習安定化のためにQK Normを入れることが増えていますが、gpt-ossでは導入されていません。
  3. self-attention sink(learnable softmax)の存在: 導入背景などについての解説は控えますが、softmaxの分母に学習可能なバイアス項が導入されています。

上記のようなアーキテクチャの変更がモデル性能に及ぼしている影響は大きくないと推測されますが、学習を行う上では、とくに3番目の点が弊害となります。

NGC

gpt-ossを学習するための方法を調べるとNVIDIA NeMo Framework User Guideが目に付くでしょう。
そこでは、NVIDIAの NeMo Framework用のコンテナが紹介されており、25.07.gpt_ossというコンテナを利用すれば非常に簡単に学習可能であるかのように書かれています。(小規模なfinetuningであればその通りです)

しかし、Long Context学習や、Continual Pre-Training(継続事前学習)を行うとなるとそうもいきません。本節では、NGCを利用して学習環境を整える様子について解説を行います。

実装の摘出

以下ではスパコン(スーパーコンピューター)での作業を想定して、singularityを利用して作業を行います。適時、コマンドをお使いの環境に合わせて読み替えてください。

まず、以下のように25.07.gpt_oss.defを作成し、singularity buildを行います。

25.07.gpt_oss.def
Bootstrap: docker
From: nvcr.io/nvidia/nemo:25.07.gpt_oss

%post
  pip install --no-cache-dir wandb transformers datasets jsonlines tqdm

なお、buildを行う際は、Lustre, NFS上ではなくできるだけ/scratchなどのLocal Storageで行うことで処理時間を短縮することをオススメします。

cd /scratch
export SINGULARITY_TMPDIR=/scratch/tmp

singularity build --sandbox 25.07.gpt_oss 25.07.gpt_oss.def

.sifを作成するとRead onlyになってしまうので、以下で作業を行うことを想定してsandboxを作成します。

このコンテナの中で利用されているNeMoやMegatron-LMはGitHubにてtag打ちされているものと異なる実装のため、git管理下に置くためにコンテナから実装を摘出します。

singularity shell --bind /path/to/your:/path/to/your 25.07.gpt_oss
Singularity>

コンテナ内に入り次第、/opt/NeMo/, /opt/megatron-lmにある実装をコンテナ外のパスにcopyして、コンテナから抜けてもアクセスできるようにします。

cp -R /opt/NeMo /path/to/your
cp -R /opt/megatron-lm /path/to/your

なお、下記で述べる修正を行ったMegatron-LMをGitHub上で公開していますので、ご自由にご利用ください。

https://github.com/okoge-kaz/gpt-oss-megatron-lm

変更のapply

こちらのPull Requestにかかれているように、このコンテナ内の実装を利用するとYarnの実装がHuggingFace実装と乖離してしまっているようです。そこで、以下のようにコンテナ内のmegatron-lmの実装を修正する必要があります。

Singularity> vim /opt/megatron-lm/megatron/core/models/common/embeddings/rope_utils.py
Singularity> rm /opt/megatron-lm/megatron/core/models/common/embeddings/rope_utils.py
Singularity> vim /opt/megatron-lm/megatron/core/models/common/embeddings/rope_utils.py
Singularity> rm /opt/megatron-lm/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py
Singularity> vim /opt/megatron-lm/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py
Singularity> rm /opt/megatron-lm/megatron/core/transformer/dot_product_attention.py
Singularity> vim /opt/megatron-lm/megatron/core/transformer/dot_product_attention.py
Singularity> rm /opt/megatron-lm/megatron/core/transformer/utils.py
Singularity> vim /opt/megatron-lm/megatron/core/transformer/utils.py

なお、修正が必要な差分は以下のとおりです。
https://github.com/okoge-kaz/gpt-oss-megatron-lm/commit/01b3824fe9d81b211b8aee6bfb35bd92169f8eb9

NeMo

コンテナ内から摘出してきたNeMoをGit管理下におき、実装を行っていきます。

現状

まず、現状を確認しましょう。

gpt-ossの学習をNeMoで行うには、HuggingFace形式で公開されているcheckpointをNeMoで読み込めるようにNeMo形式のcheckpointに変換する必要があります。
また、コンテナから摘出したNeMo内のtutorials/llm/gpt-oss/ticket-routing-lora/gpt-oss-lora.ipynbにあるtutorialはLoRA SFTしか解説していないばかりか、nemo/collections/llm/recipes/gpt_oss_20b.pyの実装もpretrain用の実装をサポートしていません。

そのため、まだまだ道のりは遠そうです...。
1つ1つ片付けていきましょう。

hf -> nemo

公式のドキュメントにconvertスクリプトの使い方が書いてあるのですが、正直分かりづらいです。
以下のようにconvert scriptを実装し、利用すると簡単に使用できます。

experiments/ckpt-convert/hf-to-nemo/gpt-oss.py
import argparse
from nemo.collections import llm

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Hugging Face GPT-OSS checkpoints to NeMo format.")
    parser.add_argument(
        "--model-size",
        type=str,
        choices=["20B", "120B"],
        required=True,
        help="Size of the GPT-OSS model to convert (20B or 120B).",
    )
    parser.add_argument(
        "--hf-checkpoint-path", type=str, required=True, help="Path to the Hugging Face GPT-OSS checkpoint."
    )
    parser.add_argument(
        "--nemo-output-path", type=str, required=True, help="Path to save the converted NeMo checkpoint."
    )
    args = parser.parse_args()
    # For GPT-OSS 20B
    if args.model_size == "20B":
        llm.import_ckpt(
            model=llm.GPTOSSModel(llm.GPTOSSConfig20B()),
            source="hf://" + args.hf_checkpoint_path,
            output_path=args.nemo_output_path,
        )
    # For GPT-OSS 120B
    elif args.model_size == "120B":
        llm.import_ckpt(
            model=llm.GPTOSSModel(llm.GPTOSSConfig120B()),
            source="hf://" + args.hf_checkpoint_path,
            output_path=args.nemo_output_path,
        )
    else:
        raise ValueError("Unsupported model size. Choose either '20B' or '120B'.")

    print(f"Conversion complete! NeMo checkpoint saved at {args.nemo_output_path}")

上記のように実装したスクリプトを利用して以下のようにすることで、HuggingFace formatのcheckpointをNeMo形式のcheckpointにconvertすることが可能です。

HF_CHECKPOINT_PATH="/path/tp/gpt-oss-20b"
NEMO_OUTPUT_PATH="/path/to/checkpoints/hf-to-nemo/gpt-oss-20B.nemo"
mkdir -p $(dirname ${NEMO_OUTPUT_PATH})

export NUMEXPR_MAX_THREADS=192

singularity exec \
  --nv \
  --bind /path/to:/path/to \
  --bind /tmp:/tmp \
  /path/to/25.07.gpt_oss.sif \
  python experiments/ckpt-convert/hf-to-nemo/gpt-oss.py \
    --model-size 20B \
    --hf-checkpoint-path ${HF_CHECKPOINT_PATH} \
    --nemo-output-path ${NEMO_OUTPUT_PATH}

これで、NeMo形式のcheckpointを得ることが出来ました。

pretrain_recipe

nemo/collections/llm/recipes/gpt_oss_20b.py, nemo/collections/llm/recipes/gpt_oss_120b.pyを見るとpretrain recipeがないことに気づきます。(以下のようなfinetune recipeしかありません)

@run.cli.factory(target=finetune, name=NAME)
def finetune_recipe(
    dir: Optional[str] = None,
    resume_path: str = "openai/gpt-oss-20b",
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    peft_scheme: Optional[str] = "lora",
    packed_sequence: bool = False,
) -> run.Partial:

そこで、以下のようにpretrain recipeを実装していきます。
あくまで以下は実装の一例であり、オプション等をイジらない場合はもっと簡素化することができると思います。

@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
    dir: Optional[str] = None,
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    performance_mode: bool = False,
    tensor_parallel_size: int = 1,
    context_parallel_size: int = 1,
    expert_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    sequence_parallelism: bool = False,
    seq_length: int = 32768,
    global_batch_size: int = 256,
    micro_batch_size: int = 1,
    lr: float = 3e-4,
    min_lr: float = 3e-5,
    train_steps: int = 25000,
    warmup_steps: int = 1000,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.95,
    adam_eps: float = 1e-8,
    weight_decay: float = 0.1,
    clip_grad: float = 1.0,
    constant_step: int = 0,
    fp8: str = "",
    fn: Callable = pretrain,
) -> run.Partial:
    recipe = run.Partial(
        fn,
        model=model(),
        trainer=trainer(
            num_nodes=num_nodes,
            num_gpus_per_node=num_gpus_per_node,
            tensor_parallelism=tensor_parallel_size,
            context_parallelism=context_parallel_size,
            pipeline_parallelism=pipeline_parallel_size,
            sequence_parallelism=sequence_parallelism,
            expert_parallel_size=expert_parallel_size,
            fp8=fp8,
            callbacks=[
                run.Config(
                    TimingCallback,
                    log_tokens_per_sec=True,
                ),
            ],
        ),
        data=run.Config(
            MockDataModule,
            seq_length=seq_length,
            global_batch_size=global_batch_size,
            micro_batch_size=micro_batch_size,
        ),
        log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
        optim=distributed_fused_adam_with_cosine_annealing(
            train_steps=train_steps,
            warmup_steps=warmup_steps,
            constant_steps=constant_step,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_eps=adam_eps,
            max_lr=lr,
            min_lr=min_lr,
            weight_decay=weight_decay,
            clip_grad=clip_grad,
        ),
        resume=default_resume(),
    )

次に、trainer()も実装してしまいます。
(こちらも以下は一例ですので、用途に合わせて実装粒度は変更してください)

def trainer(
    tensor_parallelism: int = 1,
    pipeline_parallelism: int = 1,
    pipeline_parallelism_type: Optional[torch.dtype] = None,
    virtual_pipeline_parallelism: Optional[int] = None,
    context_parallelism: int = 2,
    expert_parallel_size: int = 4,
    sequence_parallelism: bool = False,
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    max_steps: int = 1168251,
    fp8: str = "",
    callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
    strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=tensor_parallelism,
        pipeline_model_parallel_size=pipeline_parallelism,
        pipeline_dtype=pipeline_parallelism_type,
        virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
        context_parallel_size=context_parallelism,
        expert_model_parallel_size=expert_parallel_size,
        sequence_parallel=sequence_parallelism,
        gradient_as_bucket_view=True,
        ckpt_async_save=True,
        ckpt_parallel_load=True,
        ddp=run.Config(
            DistributedDataParallelConfig,
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=True,
            overlap_param_gather=True,
            average_in_collective=True,
            data_parallel_sharding_strategy="optim_grads_params",  # For custom FSDP only
        ),
        fsdp=None,  # Set to 'megatron' to use Megatron FSDP, 'pytorch' to use PyTorch FSDP 2 (WIP)
    )

    precision = None
    if fp8 == "current":
        precision = nemotron_h_bf16_with_fp8_current_scaling_mixed()
    elif fp8 == "blockwise":
        precision = bf16_with_fp8_subchannel_scaling_mixed()
    else:
        precision = bf16_mixed()

    trainer = run.Config(
        nl.Trainer,
        accelerator="gpu",
        accumulate_grad_batches=1,
        callbacks=callbacks,
        devices=num_gpus_per_node,
        limit_test_batches=50,
        limit_val_batches=32,
        log_every_n_steps=1,
        max_steps=max_steps,
        num_nodes=num_nodes,
        plugins=precision,
        strategy=strategy,
        use_distributed_sampler=False,
        val_check_interval=2000,
        enable_progress_bar=False,
    )

    return trainer

Wandb Loggerに渡すCallBackの実装や、checkpoint saveパスをMegatron-LM互換にするための実装、データセット関係など詳細はまだあるのですが、ここでは割愛します。

現状

ここまで来れば学習できるようになったと思いたいのですが、そうもいきません。
現状では、learnable softmax(gpt-oss独自の機構)に対応したDotProductAttentionFlashAttention、install済みのTransformerEngine存在しないので、Context Parallelismを利用することができず、学習可能なcontext sizeが8,192あたりに制限されてしまいます。

実際、無理やり学習しようとcontext parallel size > 1としてみると以下のようなエラーが出ます。

[rank62]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py", line 1370, in forward
[rank62]:     raise ValueError(
[rank62]: ValueError: No dot product attention backend is available for the provided inputs. Please run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for disabling all backends.

エラー文にあるように、DEBUGフラグを付けて実行すると以下のようになります。

export NEMO_LOG_TRAIN_LOSS=1
export NEMO_LOG_MEMORY_USAGE=1

以下のログでは、TransformerEngineにおいてどのような設定が渡されて、その結果 Attention Backendとして何が選択されたかのログが出ています。

DEBUG:DotProductAttention:Running with config={'transformer_engine_version': '2.xxx.xxx', 'compute_capability': 'sm90', 'flash_attn_version': '2.7.3', 'flash_attn_3_version': 'not installed', 'cudnn_version': '9.13.0', 'qkv_type': <class 'torch.Tensor'>, 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'sbhd_sbhd_sbhd', 'batch_size': 1, 'num_heads': 64, 'num_gqa_groups': 8, 'max_seqlen_q': 32768, 'max_seqlen_kv': 32768, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'causal', 'window_size': (128, 0), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': True, 'cp_comm_type': 'a2a', 'deterministic': False, 'is_training': False, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None}, 'inference_params': None, 'softmax_type': 'learnable', 'return_max_logit': False}
[DEBUG    | DotProductAttention]: Disabling FusedAttention as no backend supports the provided input
DEBUG:DotProductAttention:Disabling FusedAttention as no backend supports the provided input
[DEBUG    | DotProductAttention]: Available backends = {FlashAttention=False, FusedAttention=False, UnfusedDotProductAttention=False}
DEBUG:DotProductAttention:Available backends = {FlashAttention=False, FusedAttention=False, UnfusedDotProductAttention=False}
[DEBUG    | DotProductAttention]: Disabling FusedAttention as no backend supports the provided input
DEBUG:DotProductAttention:Disabling FusedAttention as no backend supports the provided input
[DEBUG    | DotProductAttention]: Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0
DEBUG:DotProductAttention:Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0
[DEBUG    | DotProductAttention]: Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0
DEBUG:DotProductAttention:Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0
[DEBUG    | DotProductAttention]: Disabling FlashAttention for softmax_type = learnable
[DEBUG    | DotProductAttention]: Available backends = {FlashAttention=False, FusedAttention=False, UnfusedDotProductAttention=False}
DEBUG:DotProductAttention:Available backends = {FlashAttention=False, FusedAttention=False, UnfusedDotProductAttention=False}
[DEBUG    | DotProductAttention]: Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0
DEBUG:DotProductAttention:Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0
[DEBUG    | DotProductAttention]: Selected backend = NoBackend
DEBUG:DotProductAttention:Selected backend = NoBackend
[DEBUG    | DotProductAttention]: Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0
DEBUG:DotProductAttention:Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0

ログに記載されているように、FlashAttention, FusedAttention, UnFusedAttentionのどれも利用することが出来なかったので、エラーが発生しています。

TransformerEngineのupdate

GPT-OSSのlearnable softmaxに対応する実装を自前で実装しようとしましたが、この程度の実装であればNVIDIAのTransformerEngine teamが実装していないはずはないと思い直し、調査を始めました。すると以下のPull Requestを発見しました。

https://github.com/NVIDIA/TransformerEngine/pull/2148

実装を見るとcommunication typeがp2pではなくa2a(=all to all)のContext Parallel対応のlearnable softmax向けのFusedAttentionがあることが判明しました。
そこで、sandboxに入り、TransformerEngineのversionをupdateすることで対応しました。

cuDNNのupdate

これで完了かと思いきや、そうではありません。
まだ動きません。 以下のようなデバッグメッセージが出ます。

[DEBUG    | DotProductAttention]: Disabling FusedAttention as no backend supports the provided input
DEBUG:DotProductAttention:Disabling FusedAttention as no backend supports the provided input
[DEBUG    | DotProductAttention]: Available backends = {FlashAttention=False, FusedAttention=False, UnfusedDotProductAttention=False}
DEBUG:DotProductAttention:Available backends = {FlashAttention=False, FusedAttention=False, UnfusedDotProductAttention=False}
[DEBUG    | DotProductAttention]: Disabling FusedAttention as no backend supports the provided input
DEBUG:DotProductAttention:Disabling FusedAttention as no backend supports the provided input

Disabling FusedAttention as no backend supports the provided inputが出る箇所を探すと以下の実装が見つかります。

https://github.com/NVIDIA/TransformerEngine/blob/e7227af98070ebfcdb08b7f0a99bb87abe7b8532/transformer_engine/common/fused_attn/fused_attn.cpp#L373-L376

つまり、cuDNNのversionが91301(=9.13.1)未満であるため、FusedAttentionが利用できていないということです。もう一度先程のPullRequestを見るとDescriptionに以下のようにあります。

FusedAttention backend for FP16/BF16 and BSHD/SBHD: cuDNN 9.13.1+ and cudnn-frontend 1.14.1

そのうえで、cuDNNのリリースリストを確認すると9.13.1はつい最近出たことが判明しました。

Singularity内部のcuDNN versionを確認すると以下のように9.13.0であることが分かります。

Singularity> ls /usr/local/cudnn/lib64/
libcudnn.so		libcudnn_adv_static.a	  libcudnn_cnn_static_v9.a		    libcudnn_engines_runtime_compiled.so	   libcudnn_graph.so.9	       libcudnn_heuristic.so.9.13.0    libcudnn_ops_static.a
libcudnn.so.9		libcudnn_adv_static_v9.a  libcudnn_engines_precompiled.so	    libcudnn_engines_runtime_compiled.so.9	   libcudnn_graph.so.9.13.0    libcudnn_heuristic_static.a     libcudnn_ops_static_v9.a
libcudnn.so.9.13.0	libcudnn_cnn.so		  libcudnn_engines_precompiled.so.9	    libcudnn_engines_runtime_compiled.so.9.13.0    libcudnn_graph_static.a     libcudnn_heuristic_static_v9.a
libcudnn_adv.so		libcudnn_cnn.so.9	  libcudnn_engines_precompiled.so.9.13.0    libcudnn_engines_runtime_compiled_static.a	   libcudnn_graph_static_v9.a  libcudnn_ops.so
libcudnn_adv.so.9	libcudnn_cnn.so.9.13.0	  libcudnn_engines_precompiled_static.a     libcudnn_engines_runtime_compiled_static_v9.a  libcudnn_heuristic.so       libcudnn_ops.so.9
libcudnn_adv.so.9.13.0	libcudnn_cnn_static.a	  libcudnn_engines_precompiled_static_v9.a  libcudnn_graph.so				   libcudnn_heuristic.so.9     libcudnn_ops.so.9.13.0

つまり、9.13.0<9.13.1なため、FusedAttentionが利用できないということです。

解決策

解決策には2つの方法があります。コンテナ内部のcuDNNをどうにかしてupdateする方法、もう1つはcuDNN 9.13.1以降がすでに入っている環境でコンテナを作り直すことです。
NeMoの依存関係は複雑であるためcuDNN 9.13.1以上が入っているNGC PyTorchの上からNeMoの依存関係を作ることは時間を要することが予想されます。
そこで、コンテナ内のcuDNNのversionを上げる選択を検討することにしました。

さて、cuDNNを9.13.0から9.14.0(執筆時の最新)に置き換えた場合、cuDNNに依存しているPyTorch、TransformerEngineの再buildは必要でしょうか?仮に必要な場合は、実質再度、コンテナを作り直す必要があるので、NGC PyTorchから作り直す方が安上がりでしょう。

答えは、です。
詳細は以下のBlogにて書いていますが、PyTorch, TransformerEngineはcuDNNを共有ライブラリで利用しているため、コンテナ内の既定パスに新しいcuDNNをbindするとランタイムが新しいcuDNNを参照することになるので、再度buildすることが不要となっています。

https://zenn.dev/turing_motors/articles/3a434d046bbf48

そのため以下のようにすることで対処できます。

CUDNN_ROOT="/path/to/cudnn/cudnn-linux-x86_64-9.14.0.64_cuda12-archive"

singularity exec --nv .... \
  --bind ${CUDNN_ROOT}/lib:/usr/local/cudnn/lib64:ro \
  --bind ${CUDNN_ROOT}/include:/usr/local/cudnn/include:ro \
  /groups/gch51639/fujii/container/25.07.gpt_oss.fix.sif \

以上の変更により、context length 32k, context parallel size = 4にてGPT-OSSの継続事前学習(continual pre-training)を行うことが出来ました。

(学習時のTraining Loss)

おわりに

本記事では、gpt-ossをNeMoを利用して学習するための方法について解説を行いました。
LLMの学習と聞くと一見華やかに思えるかもしれませんが、実態はこのようなSoftware Engineeringの塊であり、論文で書かれているようなことよりも実装面で苦労していることの方が多いのが実態です。

huggingface smol-trainining-playbookにも次のような文言があるように、論文には載らない多数の試行錯誤がLLM、VLM開発の裏では行われています。

The reality is messier, more iterative, and full of decisions that don’t make it into the final paper.

Tech Blog - Turing

Discussion