🦃

日本語VLM「Heron-NVILA」公開 ─ Qwen2.5-VL-7B・Gemma3-12Bに匹敵する性能

に公開

はじめに

チューリングの横井です。チューリングでは視覚と言語を統合的に理解できるAIを自動運転に応用するため、Vision Language モデル(VLM)「Heron」の開発に取り組んでいます。このたび、経済産業省およびNEDOが推進する日本の生成AIの開発力強化に向けたプロジェクト「GENIAC」第2期の支援のもと開発したVLM「Heron-NVILA」15B, 2B, 1B, 33Bを公開しました。

https://huggingface.co/turing-motors/Heron-NVILA-Lite-15B

この記事では開発したHeron-NVILAのアーキテクチャ、学習内容、ベンチマーク評価を紹介します。

モデルアーキテクチャ

Heron-NVILAのアーキテクチャは名前の通りNVIDIAが提案したVLMであるNVILAを用いています。
NVILA_arch

https://arxiv.org/abs/2412.04468

NVILAは 「Vision Encoder → Projector(2 層 MLP)→ LLM」 という 3 段構成を取りつつ

  • Scale ─ 画像を入力解像度のタイルに分割して、細部まで取り込む
  • Compress ─ 空間 / 時間方向でトークンを絞り込む

という Scale-then-Compress 戦略を取り入れることで、性能を維持しつつ計算コストの削減を実現しています。

上記の論文の図では「Projector → Compress → LLM」の順に見えますが、公式リポジトリでは圧縮処理がProjectorブロックの冒頭に組み込まれており、データは「Compress → Projector → LLM」の順で流れます。

実際のデータの流れは以下のイメージです。

入力画像
 ▼
Scale
 └ 画像のアスペクト比を極力保つようにリサイズし、画像を 448x448 などVision Encoderの入力解像度のタイルに分割
   ・元のアスペクト比に近いタイル構成を選び、その構成が作れるサイズに画像をリサイズするため、完全に元のアスペクト比が維持されるわけではない
   ・Liteでは画像全体を入力解像度にリサイズしたものも1タイルとして使う
 ▼
Vision Encoder
 └ SigLIP-ViT(patch 14×14, hidden 1152)などで各タイルをエンコード
   ・448x448の解像度でパッチサイズが14x14なら、(448/14) x (448/14) = 32 x 32 = 1024 個のパッチトークンが出力されます
 ▼
Compress(Spatial-to-Channel 圧縮)
 └ 2 × 2 圧縮 / 3 × 3 圧縮
   2 × 2 圧縮の場合:
   ・隣接する 2 × 2 = 4 のトークンを連結 → 1 トークン化  
   ・トークン数:32 x 32 / 4 → 256
   ・各トークンの埋め込み次元:1152 × 4 = 4608
   ・32×32など1タイルのトークンが 2 や 3 で割り切れない場合はゼロパディングで割り切れるようにする
 ▼
Projector
 └ 2 層 MLP で LLM へ射影
   ・出力形状は各タイルごとに [256 (圧縮されたトークン数), 4096 (LLMの隠れ層次元数)]のようになる
 ▼
LLM 
 └ 視覚トークンとテキストトークンを連結して自己回帰でテキストを生成
   ・入力は「<1タイル分の視覚トークン>\n<1タイル分の視覚トークン>\n...<テキストトークン>」のようなイメージ

論文では、元の画像サイズに加えて (448, 448)、(896, 896)、(1344, 1344) にリサイズした画像でもScale処理を行うマルチスケール対応の Dynamic-S2 が紹介されていますが、Heron-NVILAではこのマルチスケール化を行わないLite方式を採用しています。論文によると、Liteの方がわずかな性能低下はありますが、マルチスケール対応は学習・推論時間が増加するため実用性を重視してLiteを選択しました。そのため、モデル名はHeron-NVILA-LiteとLiteを明記しています。なお、NVILA-Liteの詳細な説明は2025/4時点でのarXiv論文には記載されておらず、公式リポジトリのissue#167で簡単に触れられているのみです。

Heron-NVILAのVision EncoderとLLMは公式リポジトリで使用している以下のモデルを採用しました:

LLMについては、日本語特化のLlama-3.1-Swallow-8Bを用いて少ないステップ数での学習を試しましたが、Heron Benchにおける評価スコアがQwen2.5-7B-Instructを用いた方が高かったためQwen2.5を採用しました。
公式のNVILA-Liteの重みは公開されていますが、ライセンスがCreative Commons Attribution-NonCommercial 4.0で商用利用が許可されていません。そのため、公式のNVILA-Liteの重みは使用せず、上記のVision EncoderとLLMから学習を行いました。

Projectorの2層MLPについてはランダムな初期重みを用いています。
学習時間を短縮するため、Heron-NVILA-Lite-15Bでは 3×3 圧縮を使用し、Heron-NVILA-Lite-1BおよびHeron-NVILA-Lite-2Bでは 2×2 圧縮を使用しました。3×3 圧縮は視覚トークン数を1/9に削減できるため学習・推論の高速化やメモリ使用量の削減に有効ですが、大幅なトークン数の削減によってProjectorの学習を難しくなり、2×2 圧縮に比べて性能低下することが論文に記載されています。(LLMの計算コストの一つは、トークン間の関連性を計算するSelf-Attention。Self-AttentionはQueryとKeyの内積をトークン数Nの全組み合わせについて計算するため、トークン数の2乗に比例して計算量が増加する。このため、LLMの入力トークン数が増えるほど学習・推論時間が増加します。また、Self-AttentionのAttention MatrixのサイズはN × Nであるため、必要なVRAMも同じくトークン数の2乗に比例して増加します。)

学習

ABCI 生成AIハッカソンで日本語VLMを作成しましたの記事と同じように、オープンソースの日本語・英語データセットを用いて継続事前学習を行いました。NVILAの論文では5ステージの学習(学習データや学習率を段階的に変更し、計5回の学習を実施)を採用していますが、5ステージ学習に対応する日本語データの確保が困難であったため、NVILAの前身であるVILAの3ステージ学習を採用しました。

以下の表が各学習ステージの概要です。

学習ステージ 役割 学習対象 学習データ 学習データ件数
Stage 1 Alignment Projector のみ学習 Japanese image text pairsLLaVA-Pretrain 1.1M
Stage 2 Pretraining Projector と LLM のみ学習 Filtered MOMIJI (CC-MAIN-2024-46, CC-MAIN-2024-51, CC-MAIN-2025-05) 13M
Japanese image text pairs (subset)Japanese interleaved data (subset)mmc4-core (subset)coyo-700m (subset)wikipedia_jallava_pretrain_jastair_captions 20M
Stage 3 Supervised Fine-tuning (SFT) Vision Encoder, Projector, LLM すべて学習 llava-instruct-v1_5-en-subset-358kllava-instruct-jajapanese-photos-convja-vg-vqasynthdog-ja (subset)ai2dsynthdog-ensherlock 1.1M

データ

データの工夫としては、Stage2のPretrainingにCommon Crawlの日本語データであるMOMIJIを追加した点です。 このデータはGENIACの成果物です。

今回の学習では以下の最小限の前処理を施しています。

  • テキスト長フィルタ:本文が 5,000 文字を超えるサンプルを除外
  • 画像サイズフィルタ:縦・横いずれかが 10 px 未満、または 16,000 px を超える画像を持つデータを除外
  • 画像挿入位置を変更CLIP ViT-H/14で画像とテキストの類似度を計算し、類似度が最も高いテキストの直前に画像を配置

前処理には多くの改善の余地があります。社内で検証した範囲でも、性能向上の可能性が確認できました。

学習データの詳細は以下です。

Stage1の学習データ

Data Type: 画像-短いキャプション データ

  • Japanese image text pairs (558k)

    • 画像データ: CommonCrawl from 2020-2022
    • テキスト: alt-text ベース
  • LLaVA-Pretrain (595K)

    • 画像データ: CC3M
    • テキスト: BLIPによる合成テキスト
Stage2の学習データ

Data Type: 画像-短いキャプション データ, 画像-長いキャプションデータ, 画像-テキストインターリーブデータ

  • MOMIJI (13M subset)

    • 画像データ: CommonCrawl from 2024/02 - 2025/01
    • テキスト: Web文書 (image-textが交互に配置) ベース
  • Japanese image text pairs (6M subset)

    • 画像データ: CommonCrawl from 2020-2022 (COYOのやり方を踏襲)
    • テキスト: alt-text ベース
  • Japanese interleaved data (2M subset)

    • 画像データ: CommonCrawl from 2020-2022 (MMC4 のやり方を踏襲)
    • テキスト: Web文書 (image-textが交互に配置)ベース
      • 画像の位置は、CLIPによる類似度で決定する
  • coyo-700m (6M subset)

    • 画像データ: CommonCrawl from 2020/10-2021/08
    • テキスト: alt-text ベース
  • mmc4-core (4M subset)

    • 画像データ: CommonCrawl from 2018-2019
    • テキスト: Web文書 (image-textが交互に配置)ベース
      • 画像の位置は、CLIPによる類似度で決定する
  • wikipedia_ja (1.6M)

    • 画像データ: Wikipedia
    • テキスト: 画像とその周辺にあるテキストベース
  • llava_pretrain_ja (595K)

    • 画像データ: CC3M
    • テキスト: BLIPによる合成テキストをDeepLで日本語訳したもの
  • stair_captions (414k)

    • 画像データ: COCO2017
    • テキスト: COCOの画像に対して人手で説明文を付与したもの
Stage3の学習データ

Data Type: 画像対話データ, 画像-長いキャプションデータ, 一般的なVQA(自由記述解答), 一般的なVQA(選択肢解答), GQA(位置関係), OCR(画像文字読み取り), BBox推定

  • llava-instruct-v1_5-en-subset-358k (358k subset)

    • 画像データ:
      Dataset Images 画像抽出元
      LLaVA 158K COCO2017
      VQAv2 53K COCO2014
      GQA 46K COCO2014, Flickr
      OCRVQA 80K Original
      TextVQA 22K Open Images v3
    • テキスト:
      • type1: conversation (LLaVA)
      • type2: detailed description (LLaVA)
      • type3: complex reasoning (LLaVA)
      • type4: VQA (VQAv2; 一般的なQA)
      • type5: GQA (GQA; 位置関係)
      • type6: OCR (OCRVQA, TextVQA; OCR)
  • llava-instruct-ja (156k)

    • 画像データ:
      Dataset Images 画像抽出元
      LLaVA 158K COCO2017
    • テキスト:
      • type1: conversation (LLaVA)
      • type2: detailed description (LLaVA)
      • type3: complex reasoning (LLaVA)
  • japanese-photos-conv (11.8k)

  • ja-vg-vqa (98.7k)

    • 画像データ: Visual Genome Dataset
    • テキスト: GPT-4o を使用した対話データ
  • synthdog-ja (66.3k)

    • 画像データ: Original
    • テキスト: 翻訳テキストデータ
  • synthdog-en (65.5k)

    • 画像データ: Original
    • テキスト: 翻訳テキストデータ
  • AI2D (3.09k)

    • 画像データ: Original
    • テキスト: 選択式 QA データ
  • sherlock (317k)

    • 画像データ: Vizual Genome Dataset + Original Dataset
    • テキスト: 対話データ (bbox推定, brief VQA)

Heron-NVILAの学習データは商用利用可能とみなせるデータだけを使用しています。ただし、SFTでGPT-4oの出力を学習データに使用しているため、Heron-NVILAの重みはOpenAIの利用規約に従う必要があります。

コード

Heron-NVILAの学習コードは公式リポジトリを活用しました。MOMIJIなどの独自データを使うためにデータセットクラスを追加、連続学習時間制限を削除、チャットテンプレートのシステムプロンプトを日本語にするなど軽微な変更は入れていますが、lossなど学習の主要部分は公式実装を流用しました。

追加したデータセットクラスのサンプルコード
# VILA/llava/data/dataset.py
class LazyMOMIJIDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        image_folder: str,
        tokenizer: transformers.PreTrainedTokenizer,
        data_args: DataArguments,
        training_args: TrainingArguments,
        image_following_text_only: bool = False,
        text_only: bool = False,
    ):
        super().__init__()

        pkl_files = sorted(glob.glob(f"{data_path}/**/*.pkl", recursive=True)) + sorted(
            glob.glob(f"{data_path}/*.pkl")
        )
        count_files = sorted(glob.glob(f"{data_path}/**/*.count", recursive=True)) + sorted(
            glob.glob(f"{data_path}/*.count")
        )
        pkl_files = list(dict.fromkeys(pkl_files))
        count_files = list(dict.fromkeys(count_files))

        if data_args.mm4c_n_shards_start != -1:
            s, e = data_args.mm4c_n_shards_start, data_args.mm4c_n_shards_end
            count_info_list = count_files[s:e]
        else:
            count_info_list = count_files

        n_samples = [
            int(open(os.path.join(data_path, f)).read().strip()) for f in count_info_list
        ]
        print("total MOMIJI samples", sum(n_samples))

        pgm = get_pg_manager()
        seq_parallel = training_args.seq_parallel_size if pgm else 1
        rank = training_args.process_index // seq_parallel
        world_size = training_args.world_size // seq_parallel
        shared_size = len(count_info_list) // world_size

        gpu_samples = [sum(n_samples[i * shared_size : (i + 1) * shared_size]) for i in range(world_size)]
        self.n_samples = min(gpu_samples) * world_size
        self.idx_offset = rank * min(gpu_samples)
        shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size
        print(f"* loading data from shard {shard_start}-{shard_end}")

        shard_names = [
            d.replace(".count", ".pkl") for d in count_info_list[shard_start:shard_end]
        ]

        self.data_list, self.data_path_list = [], []
        for shard_name in shard_names:
            with open(os.path.join(data_path, shard_name), "rb") as f:
                part = pickle.load(f)
            self.data_list.extend(part)
            self.data_path_list.extend([os.path.join(data_path, shard_name)] * len(part))

        print(f"* loaded totally {len(self.data_list)} samples")

        self.tokenizer = tokenizer
        self.data_args = data_args
        self.image_folder = image_folder
        self.image_following_text_only = image_following_text_only
        self.text_only = text_only

    def __len__(self):
        return self.n_samples

    @property
    def modality_lengths(self):
        lengths = []
        for info in self.data_list:
            num_images = min(6, len(info["image_info"]))
            sentences = [info["text_list"][x["matched_text_index"]] for x in info["image_info"][:num_images]]
            cur_len = num_images * self.num_image_tokens // 2 + sum(len(s) for s in sentences)
            lengths.append(cur_len)
        return lengths

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        info = self.data_list[i - self.idx_offset]
        sentences = [s.replace("<image>", "<IMAGE>") for s in info["text_list"]]
        sim_matrix = info["similarity_matrix"]

        images, sentence_ixs = [], []
        if not self.text_only:
            for sample_image, _ in zip(info["image_info"], sim_matrix):
                rawbytes = base64.b64decode(sample_image["image_base64"])
                image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
                images.append(image)
                sentence_ixs.append(sample_image["matched_text_index"])

        max_num_images = 6
        if len(images) > max_num_images:
            images, sentence_ixs = images[:max_num_images], sentence_ixs[:max_num_images]

        images = [images[idx] for idx in np.argsort(sentence_ixs)]
        for ix in sentence_ixs:
            sentences[ix] = f"<image>\n{sentences[ix]}"

        if self.image_following_text_only:
            text = self.tokenizer.pad_token.join(sentences)
        else:
            text = " ".join(sentences)
        text = text.replace("<image> ", "<image>").replace(" <image>", "<image>")
        text = f"{text}{self.tokenizer.eos_token}"

        if images:
            if self.data_args.image_aspect_ratio == "dynamic_s2":
                images, block_sizes = dynamic_s2_process_images_and_prompt(
                    images, text, self.data_args, self.image_folder
                )
            elif self.data_args.image_aspect_ratio == "dynamic":
                images, text = dynamic_process_images_and_prompt(
                    images, text, self.data_args, self.image_folder, max_tiles=6
                )
            else:
                images = torch.stack(
                    [process_image(img, self.data_args, self.image_folder) for img in images]
                )
        else:
            images = None

        input_ids = tokenizer_image_token(text, self.tokenizer, return_tensors="pt")
        image_token_id = self.tokenizer.media_token_ids["image"]

        if input_ids[-1] == image_token_id:
            last_non_img = torch.where(input_ids != image_token_id)[0][-1] + 1
            input_ids = input_ids[:last_non_img]

        n_im_patch = (input_ids == image_token_id).sum().item()
        if self.data_args.image_aspect_ratio != "dynamic_s2":
            images = images[:n_im_patch]
            assert len(images) == n_im_patch

        if self.tokenizer.bos_token_id is not None and input_ids[0] != self.tokenizer.bos_token_id:
            input_ids = torch.cat([torch.tensor([self.tokenizer.bos_token_id]), input_ids])

        targets = input_ids.clone()
        if self.image_following_text_only:
            label_idx = 0
            while label_idx < targets.shape[-1] and targets[label_idx] != image_token_id:
                targets[label_idx] = IGNORE_INDEX
                label_idx += 1

            pad_token = self.tokenizer.pad_token_id
            for pad_token_idx in torch.where(targets == pad_token)[0]:
                token_idx = pad_token_idx + 1
                while token_idx < targets.shape[-1] and targets[token_idx] != image_token_id:
                    targets[token_idx] = IGNORE_INDEX
                    token_idx += 1
            targets[targets == pad_token] = IGNORE_INDEX

        data_dict = dict(input_ids=input_ids, labels=targets, image=images)
        if self.data_args.image_aspect_ratio == "dynamic_s2":
            data_dict["block_sizes"] = block_sizes
        return data_dict

学習のメインコード
# VILA/llava/train/train.py
import copy
import logging
import math
import os
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence

import torch
import transformers
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, LlamaForCausalLM, set_seed
from transformers.modeling_utils import unwrap_model

import llava.data.dataset as dataset
import llava.data.datasets_mixture as datasets_mixture
from llava import conversation as conversation_lib
from llava.constants import IGNORE_INDEX
from llava.data import make_supervised_data_module
from llava.mm_utils import process_image
from llava.model import LlavaLlamaConfig, LlavaLlamaModel
from llava.model.language_model.qllava_qllama import QLlavaLlamaModel, quantize_args_to_model_class
from llava.train.args import DataArguments, ModelArguments, TrainingArguments
from llava.train.callbacks.autoresume_callback import AutoResumeCallback
from llava.train.llava_trainer import LLaVATrainer, VILADPOTrainer
from llava.train.sequence_parallel import set_pg_manager
from llava.train.slurm_utils import TimeoutTerminateCallback
from llava.train.utils import (
    get_checkpoint_path,
    mprint,
    prepare_config_for_training,
    unit_test_rope_scaling,
    vision_resolution_elevation,
)
from llava.trl.trainer.utils import DPODataCollatorWithPadding

local_rank = None

if "WANDB_PROJECT" not in os.environ:
    os.environ["WANDB_PROJECT"] = "VILA"

def get_nb_trainable_parameters(model) -> tuple[int, int]:
    r"""
    Returns the number of trainable parameters and the number of all parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        if param.__class__.__name__ == "Params4bit":
            if hasattr(param, "element_size"):
                num_bytes = param.element_size()
            elif not hasattr(param, "quant_storage"):
                num_bytes = 1
            else:
                num_bytes = param.quant_storage.itemsize
            num_params = num_params * 2 * num_bytes

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param

def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param

def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias:
            if bias_name in lora_bias_names:
                to_return[bias_name] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return

def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return

def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return

def find_all_linear_names(model, lora_llm, lora_vt):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ["mm_projector", "vision_resampler"]
    assert lora_llm or lora_vt, "Not applying LoRA to any of the modules..."

    if not lora_llm:
        multimodal_keywords += ["llm"]
    if not lora_vt:
        multimodal_keywords += ["vision_tower"]

    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            if not "lm_head" in name:
                lora_module_names.add(name)

    return list(lora_module_names)

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir, _internal_call=True)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa

def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg

def make_conv(prompt, answer):
    return [
        {
            "from": "human",
            "value": prompt,
        },
        {
            "from": "gpt",
            "value": answer,
        },
    ]

@dataclass
class DPODataCollator(DPODataCollatorWithPadding):
    tokenizer: Any = None

    def collate(self, batch):
        padded_batch = {}
        for k in batch[0].keys():
            if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
                to_pad = [torch.LongTensor(ex[k]) for ex in batch]
                if k.endswith("_input_ids"):
                    padding_value = self.pad_token_id
                elif k.endswith("_labels"):
                    padding_value = self.label_pad_token_id
                else:
                    continue

                padded_batch[k] = torch.nn.utils.rnn.pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
            else:
                padded_batch[k] = [ex[k] for ex in batch]
        for k in ["chosen_input_ids", "rejected_input_ids"]:
            attn_k = k.replace("input_ids", "attention_mask")
            padded_batch[attn_k] = padded_batch[k].ne(self.pad_token_id)
        return padded_batch

    def tokenize_batch_element(self, prompt: str, chosen: str, rejected: str) -> Dict:
        """Tokenize a single batch element.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
            in case the prompt + chosen or prompt + rejected responses is/are too long. First
            we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

        We also create the labels for the chosen/rejected responses, which are of length equal to
            the sum of the length of the prompt and the chosen/rejected response, with
            label_pad_token_id  for the prompt tokens.
        """
        batch = {}

        chosen_sources = make_conv(prompt, chosen)
        rejected_sources = make_conv(prompt, rejected)
        chosen_data_dict = dataset.preprocess([chosen_sources], self.tokenizer, has_image=True)

        rejected_data_dict = dataset.preprocess([rejected_sources], self.tokenizer, has_image=True)

        chosen_data_dict = {k: v[0] for k, v in chosen_data_dict.items()}
        rejected_data_dict = {k: v[0] for k, v in rejected_data_dict.items()}

        for k, toks in {
            "chosen": chosen_data_dict,
            "rejected": rejected_data_dict,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}_{type_key}"] = tokens
        return batch

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        tokenized_batch = []
        Xs, keys = [], []
        for feature in features:
            prompt = feature["prompt"]
            chosen = feature["chosen"]
            rejected = feature["rejected"]

            batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
            batch_element["images"] = feature["images"]
            tokenized_batch.append(batch_element)

        padded_batch = self.collate(tokenized_batch)
        return padded_batch

import json

def load_jsonl(save_path):
    with open(save_path) as f:
        data = [json.loads(line) for line in f.readlines()]
    return data

def load_json(path):
    with open(path) as f:
        data = json.load(f)
    return data

def load_data(data_path):
    if "jsonl" in data_path:
        data_list = load_jsonl(data_path)
    else:
        data_list = load_json(data_path)
    return data_list

class DPODataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_mixture: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
        super(Dataset, self).__init__()
        data_path = datasets_mixture.DATASETS_LEGACY[data_mixture].data_path
        list_data_dict = load_data(data_path)

        print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        self.image_folder = datasets_mixture.DATASETS_LEGACY[data_mixture].image_path

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if "image" in sample else 0
            length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """
        {
            'prompt': 'Is there a snowman wearing a green scarf and hat in the background?',
            'chosen': 'No, there is no snowman wearing a green scarf and hat in the background of the image. The image features a person ...',
            'rejected': 'No, there is no snowman in the background.',
            'image_path': '/mnt/bn/liangkeg/data/ruohongz/dpo_data/dpo_images/LRVInstruction-000000009569.jpg',
            'image_name': 'LRVInstruction-000000009569.jpg'
        }
        """
        data_dict = copy.deepcopy(self.list_data_dict[i])  # inplace modification following

        video_file = data_dict["video"] + ".mp4"
        video_folder = self.image_folder
        video_path = os.path.join(video_folder, video_file)
        num_video_frames = self.data_args.num_video_frames if hasattr(self.data_args, "num_video_frames") else 8
        loader_fps = self.data_args.fps if hasattr(self.data_args, "fps") else 0.0

        fps = None
        frame_count = None

        images, frames_loaded = dataset.LazySupervisedDataset._load_video(
            video_path, num_video_frames, loader_fps, self.data_args, fps=fps, frame_count=frame_count
        )

        image_tensor = torch.stack([process_image(image, self.data_args, None) for image in images])
        image_tensor = torch.stack([process_image(image, self.data_args, None) for image in images])

        data_dict["images"] = image_tensor

        prompt = data_dict["prompt"]
        prompt = prompt.replace("<video>", "").strip()
        prompt = "<image>\n" * frames_loaded + prompt
        data_dict["prompt"] = prompt

        return data_dict

def train():
    global local_rank

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if os.getenv("RUN_NAME") is not None:
        training_args.run_name = os.getenv("RUN_NAME")
    else:
        training_args.run_name = training_args.output_dir.split("/")[-1]

    local_rank = training_args.local_rank
    compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig

        bnb_model_from_pretrained_args.update(
            dict(
                device_map={"": training_args.device},
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=training_args.bits == 4,
                    load_in_8bit=training_args.bits == 8,
                    llm_int8_skip_modules=["lm_head"],
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False,
                    bnb_4bit_compute_dtype=compute_dtype,
                    bnb_4bit_use_double_quant=training_args.double_quant,
                    bnb_4bit_quant_type=training_args.quant_type,  # {'fp4', 'nf4'}
                ),
            )
        )

    set_seed(training_args.seed)

    sp_degree = training_args.seq_parallel_size
    ring_degree = training_args.seq_parallel_ring_size
    if sp_degree > 1:
        set_pg_manager(sp_degree, ring_degree, ring_type=training_args.seq_parallel_ring_type)
        print(f"Sequence parallelism is enabled, SP = {sp_degree}")

    resume_path, continue_training = get_checkpoint_path(training_args.output_dir)

    if not continue_training:
        print(f"Models has been ready under {training_args.output_dir}. Skipp training")
        exit(0)

    if resume_path:
        resume_from_checkpoint = True
        if training_args.lora_enable:
            model_cls = LlavaLlamaModel
            config = LlavaLlamaConfig.from_pretrained(model_args.model_name_or_path, resume=resume_from_checkpoint)
            config.resume_path = model_args.model_name_or_path
        else:
            config = AutoConfig.from_pretrained(resume_path, trust_remote_code=True)
            config.resume_path = resume_path
            model_cls = eval(config.architectures[0])
    else:
        resume_from_checkpoint = False
        if (
            model_args.quantize_model in quantize_args_to_model_class.keys()
        ):  # However, qmem should not used currently becuase I haven't merge the memory reduction version into VILA
            from llava.model.language_model.qllava_qllama import QLlavaLlamaModel

            model_cls = QLlavaLlamaModel
        else:
            assert (
                model_args.quantize_model == "false"
            ), f"{model_args.quantize_model} for model_args.quantize_model is not supported"
            model_cls = LlavaLlamaModel
        config = LlavaLlamaConfig.from_pretrained(model_args.model_name_or_path, resume=resume_from_checkpoint)

        if getattr(config, "resume_path", None) is not None:
            config.resume_path = model_args.model_name_or_path

    prepare_config_for_training(config, model_args, training_args, data_args)
    if model_args.quantize_model in quantize_args_to_model_class.keys():
        model = model_cls(
            config=config,
            model_args=model_args,
            attn_implementation="flash_attention_2",
            model_max_length=training_args.model_max_length,
            cache_dir=training_args.cache_dir,
            **bnb_model_from_pretrained_args,
        )
    else:
        model = model_cls(
            config=config,
            attn_implementation="flash_attention_2",
            model_max_length=training_args.model_max_length,
            cache_dir=training_args.cache_dir,
            **bnb_model_from_pretrained_args,
        )

    if not resume_path or training_args.lora_enable:
        if model_args.mlp_path is not None:
            state_dict = torch.load(model_args.mlp_path, map_location="cpu")
            state_dict_new = {}
            for k, v in state_dict.items():
                if k == "0.weight":
                    state_dict_new["layers.1.weight"] = v
                if k == "0.bias":
                    state_dict_new["layers.1.bias"] = v
                if k == "1.weight":
                    state_dict_new["layers.2.weight"] = v
                if k == "1.bias":
                    state_dict_new["layers.2.bias"] = v
                if k == "3.weight":
                    state_dict_new["layers.4.weight"] = v
                if k == "3.bias":
                    state_dict_new["layers.4.bias"] = v
            model.get_mm_projector().load_state_dict(state_dict_new)

    vision_resolution_elevation(model, config)
    if unit_test_rope_scaling(model, model.llm.config, training_args):
        return

    if 'phi' in model_args.model_name_or_path:
        model.resize_token_embeddings(len(model.tokenizer))

    mprint(model)

    model.llm.config.use_cache = False

    def need_to_modify_do_sample(generation_config):
        if generation_config is None:
            warnings.warn("generation config is None, skip do sample modification")
            return False
        if generation_config.do_sample is False:
            if generation_config.temperature is not None and generation_config.temperature != 1.0:
                return True
            if generation_config.top_p is not None and generation_config.top_p != 1.0:
                return True
        return False

    if need_to_modify_do_sample(model.llm.generation_config):
        model.llm.generation_config.do_sample = True

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training

        model.llm.config.torch_dtype = (
            torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
        )
        model.llm = prepare_model_for_kbit_training(
            model.llm, use_gradient_checkpointing=training_args.gradient_checkpointing
        )

    if training_args.gradient_checkpointing:
        if hasattr(model.llm, "enable_input_require_grads"):
            model.llm.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, PeftModel, get_peft_model

        lora_config = LoraConfig(
            use_dora=training_args.use_dora,
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model, training_args.lora_llm, training_args.lora_vt),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        if resume_from_checkpoint:
            if os.path.exists(os.path.join(resume_path, "non_lora_trainables.bin")):
                non_lora_trainables = torch.load(
                    os.path.join(resume_path, "non_lora_trainables.bin"),
                    map_location="cpu",
                )
                non_lora_trainables = {
                    (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()
                }
                if any(k.startswith("model.model.") for k in non_lora_trainables):
                    non_lora_trainables = {
                        (k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()
                    }
                model.load_state_dict(non_lora_trainables, strict=False)

            mprint("Resume from checkpoint...", resume_path)
            model = PeftModel.from_pretrained(model, resume_path, is_trainable=True)
        else:
            mprint("Adding LoRA adapters...")
            model = get_peft_model(model, lora_config)
        mprint(model)
        model.print_trainable_parameters()

    if training_args.lora_enable:
        if not training_args.lora_llm:
            model.get_llm().requires_grad_(training_args.tune_language_model)
        if model.get_vision_tower():
            if training_args.lora_vt:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_vision_tower().vision_tower.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad
                )
            elif training_args.tune_vision_tower:
                model.get_vision_tower().requires_grad_(training_args.tune_vision_tower)
            model.get_mm_projector().requires_grad_(training_args.tune_mm_projector)
            mprint(f"mm projector {training_args.tune_mm_projector}")
            model.print_trainable_parameters()
    else:
        model.get_llm().requires_grad_(training_args.tune_language_model)
        mprint(f"Tunable parameters:\nlanguage model {training_args.tune_language_model}")
        if model.get_vision_tower():
            model.get_vision_tower().requires_grad_(training_args.tune_vision_tower)
            model.get_mm_projector().requires_grad_(training_args.tune_mm_projector)
            mprint(f"vision tower {training_args.tune_vision_tower}")
            mprint(f"mm projector {training_args.tune_mm_projector}")
            trainable_params, all_param = get_nb_trainable_parameters(model)
            print(
                f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
            )

        if not any(
            [training_args.tune_language_model, training_args.tune_vision_tower, training_args.tune_mm_projector]
        ):
            logging.warning("You are not tuning any part of the model. Please check if this is intended.")

    tokenizer = model.tokenizer

    if tokenizer.bos_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(bos_token="[BOS]"),
            tokenizer=tokenizer,
            model=model.llm,
        )

    tokenizer.pad_token = tokenizer.unk_token
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token="[PAD]"),
            tokenizer=tokenizer,
            model=model.llm,
        )
    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

    vision_tower = model.get_vision_tower()
    if vision_tower is not None:
        data_args.image_processor = vision_tower.image_processor
        data_args.is_multimodal = True

        if hasattr(data_args, "num_video_frames") and data_args.num_video_frames != None:
            model.config.num_video_frames = data_args.num_video_frames
        else:
            model.config.num_video_frames = 8

        if hasattr(data_args, "fps"):
            model.config.fps = data_args.fps
        else:
            model.config.fps = 0.0

        model.config.image_aspect_ratio = data_args.image_aspect_ratio
        model.config.mm_projector_lr = training_args.mm_projector_lr
        model.config.vision_tower_lr = training_args.vision_tower_lr
        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        assert not model_args.mm_use_im_patch_token

        model.config.num_time_tokens = data_args.num_time_tokens = model_args.num_time_tokens
        model.config.time_token_format = data_args.time_token_format = model_args.time_token_format
        if model_args.num_time_tokens > 0:
            time_tokens = [model.config.time_token_format.format(t=t) for t in range(model.config.num_time_tokens)]
            num_new_tokens = tokenizer.add_tokens(time_tokens)
            assert len(time_tokens) == num_new_tokens or num_new_tokens == 0
            model.resize_token_embeddings(len(tokenizer))
            model.config.time_token_ids = tokenizer.convert_tokens_to_ids(time_tokens)
        else:
            model.config.time_token_ids = []
        model.config.soft_ce_std = model_args.soft_ce_std

        num_patches = model.get_vision_tower().num_patches
        downsample_rate = model.get_mm_projector().downsample_rate
        num_image_tokens = math.ceil(num_patches**0.5 / downsample_rate) ** 2
        data_args.num_image_tokens = num_image_tokens

    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer

        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if "norm" in name:
                module = module.to(torch.float32)
            if "lm_head" in name or "embed_tokens" in name:
                if hasattr(module, "weight"):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    data_args.s2_scales = list(map(int, model_args.s2_scales.split(",")))

    data_module = make_supervised_data_module(
        tokenizer=tokenizer,
        data_args=data_args,
        training_args=training_args,
    )

    if training_args.total_time_limit != -1:
        callbacks = [AutoResumeCallback(), TimeoutTerminateCallback(total_time_limit=training_args.total_time_limit)]  # 連続学習時間制限指定できるようにする
    else:
        callbacks = [AutoResumeCallback(), TimeoutTerminateCallback()]  # 4時間連続学習しかできない。4時間超えるとjob落ちる

    if training_args.dpo:
        ref_model = model_cls(
            config=config,
            attn_implementation="flash_attention_2",
            model_max_length=training_args.model_max_length,
            cache_dir=training_args.cache_dir,
            **bnb_model_from_pretrained_args,
        )

        train_dataset = DPODataset(tokenizer=tokenizer, data_mixture=data_args.data_mixture, data_args=data_args)

        data_collator = DPODataCollator(
            tokenizer=tokenizer,
            label_pad_token_id=IGNORE_INDEX,
            pad_token_id=tokenizer.pad_token_id,
        )
        extra_info = []
        extra_info.append(len(train_dataset))
        training_args.sample_lens = extra_info

        trainer = VILADPOTrainer(
            model=model,
            dpo_alpha=1.0,
            gamma=0,
            ref_model=ref_model,
            tokenizer=tokenizer,
            args=training_args,
            beta=training_args.dpo_beta,
            callbacks=callbacks,
            train_dataset=train_dataset,
            data_collator=data_collator,
        )
    else:
        trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args, callbacks=callbacks, **data_module)

        if model_args.quantize_model in ["fp8Activation_qwen2", "fp8ActivationResidual_qwen2"]:
            from llava.model.coat.fp8_trainer import CoatFP8Trainer

            trainer._inner_training_loop = CoatFP8Trainer._inner_training_loop.__get__(
                trainer, LLaVATrainer
            )  # GPT told me to do this

    print(
        "length of dataloader:",
        len(trainer.get_train_dataloader()),
        len(trainer.train_dataset),
        flush=True,
    )
    print(
        "[GPU memory] before trainer",
        torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
        flush=True,
    )

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    if training_args.debug_e2e:
        exit()

    trainer.save_state()

    model.llm.config.use_cache = True
    model.config.resume_path = model.config._name_or_path = training_args.output_dir
    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(
                non_lora_state_dict,
                os.path.join(training_args.output_dir, "non_lora_trainables.bin"),
            )
    else:
        safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)

if __name__ == "__main__":
    train()

学習を実行するシェルスクリプトは公式リポジトリの https://github.com/NVlabs/VILA/tree/main/scripts/NVILA-Lite をベースにしました。
主な変更点は次のとおりです。

  • Slurm ジョブスケジューラ対応sbatch でジョブ投入できるよう変更
  • MPI 実行対応mpirun 経由での起動に合わせてランク/環境変数の設定を調整
  • DeepSpeed ZeRO-2 による分散学習 – 高速な並列学習を実現するためdeepspeed --zero_stage 2 で実行

実際に使用したシェルスクリプトは以下です。

Heron-NVILA-Lite-15BのStage1学習で用いたシェルスクリプト
#!/bin/bash
#SBATCH --job-name=align_14b
#SBATCH --time=100:00:00
#SBATCH --partition=aws-p5
#SBATCH --nodes=2
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=24
#SBATCH --output=sbatch_outputs/%j-%x.out
#SBATCH --error=sbatch_outputs/%j-%x.out
#SBATCH --nodelist=aws-p5-st-h100-13,aws-p5-st-h100-14

DEFAULT_GLOBAL_TRAIN_BATCH_SIZE=2048
DEFAULT_GRADIENT_ACCUMULATION_STEPS=16

source scripts/setups/train_aws_v4.sh

OUTPUT_DIR="runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/$SLURM_JOB_NAME"
DATA_MIXTURE="llava_1_5_mm_align_en+llm_jp_mm_pair_step0_558k"
STAGE_PATH="/data/models/Qwen/Qwen2.5-14B-Instruct"
CHAT_TEMPLATE="qwen2_jp"
VISION_TOWER="/data/models/Efficient-Large-Model/paligemma-siglip-so400m-patch14-448"

OPTIONS=" \
  --deepspeed scripts/zero2.json \
  --model_name_or_path $STAGE_PATH \
  --chat_template $CHAT_TEMPLATE \
  --data_mixture $DATA_MIXTURE \
  --vision_tower $VISION_TOWER \
  --mm_vision_select_feature cls_patch \
  --mm_projector mlp_downsample_3x3_fix \
  --tune_vision_tower False \
  --tune_mm_projector True \
  --tune_language_model False \
  --mm_vision_select_layer -2 \
  --mm_use_im_start_end False \
  --mm_use_im_patch_token False \
  --image_aspect_ratio dynamic \
  --bf16 True \
  --output_dir $OUTPUT_DIR/model \
  --num_train_epochs 1 \
  --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
  --evaluation_strategy no \
  --save_strategy steps \
  --save_steps 500 \
  --save_total_limit 1 \
  --learning_rate 1e-3 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type cosine \
  --logging_steps 1 \
  --model_max_length 4096 \
  --gradient_checkpointing True \
  --dataloader_num_workers 8 \
  --report_to tensorboard \
  --total_time_limit 6000 \
"

mpirun -np $NUM_GPUS \
  --npernode $NUM_GPU_PER_NODE \
  -x MASTER_ADDR=$MASTER_ADDR \
  -x MASTER_PORT=$MASTER_PORT \
  -x CUDA_LAUNCH_BLOCKING=1 \
  -x CUDA_DEVICE_MAX_CONNECTIONS=1 \
  -x LD_LIBRARY_PATH \
  -x PATH \
  --hostfile $hostfile \
  --mca pml ^cm,ucx \
  --mca btl tcp,self \
  --mca btl_tcp_if_exclude lo,docker0,veth_def_agent \
  --bind-to none -map-by slot \
  --mca orte_debug_daemons 1 \
  --mca plm_base_verbose 5 \
  python llava/train/train_mem_mpirun.py ${OPTIONS}
Heron-NVILA-Lite-15BのStage2学習で用いたシェルスクリプト
#!/bin/bash
#SBATCH --job-name=pretrain_14b
#SBATCH --time=600:00:00
#SBATCH --partition=aws-p5
#SBATCH --nodes=12
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=24
#SBATCH --output=sbatch_outputs/%j-%x.out
#SBATCH --error=sbatch_outputs/%j-%x.out
#SBATCH --nodelist=aws-p5-st-h100-[1-12]

DEFAULT_GLOBAL_TRAIN_BATCH_SIZE=1152
DEFAULT_GRADIENT_ACCUMULATION_STEPS=6

source scripts/setups/train_aws_v4.sh

OUTPUT_DIR="runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/$SLURM_JOB_NAME"

DATA_MIXTURE="mix_wikipedia_ja_llava_pretrain_ja_stair_captions_less_filter+llm_jp_mm_interleaved_step1_6m_v2+llm_jp_mm_pair_step1_6m+coyo_6m+mmc4core_6img_filter_text_filter+obelics_ja"

STAGE_PATH="runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/align_14b/model"
VISION_TOWER="/data/models/Efficient-Large-Model/paligemma-siglip-so400m-patch14-448"

OPTIONS=" \
  --deepspeed scripts/zero2.json \
  --model_name_or_path $STAGE_PATH \
  --data_mixture $DATA_MIXTURE \
  --vision_tower $VISION_TOWER \
  --mm_vision_select_feature cls_patch \
  --mm_projector mlp_downsample_3x3_fix \
  --tune_vision_tower False \
  --tune_mm_projector True \
  --tune_language_model True \
  --mm_vision_select_layer -2 \
  --mm_use_im_start_end False \
  --mm_use_im_patch_token False \
  --image_aspect_ratio dynamic \
  --bf16 True \
  --output_dir $OUTPUT_DIR/model \
  --num_train_epochs 1 \
  --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
  --evaluation_strategy no \
  --save_strategy steps \
  --save_steps 500 \
  --save_total_limit 10 \
  --learning_rate 5e-5 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type cosine \
  --logging_steps 1 \
  --model_max_length 4096 \
  --gradient_checkpointing True \
  --dataloader_num_workers 4 \
  --report_to tensorboard \
  --total_time_limit 360000 \
"

mpirun -np $NUM_GPUS \
  --npernode $NUM_GPU_PER_NODE \
  -x MASTER_ADDR=$MASTER_ADDR \
  -x MASTER_PORT=$MASTER_PORT \
  -x CUDA_LAUNCH_BLOCKING=1 \
  -x CUDA_DEVICE_MAX_CONNECTIONS=1 \
  -x LD_LIBRARY_PATH \
  -x PATH \
  --hostfile $hostfile \
  --mca pml ^cm,ucx \
  --mca btl tcp,self \
  --mca btl_tcp_if_exclude lo,docker0,veth_def_agent \
  --bind-to none -map-by slot \
  python llava/train/train_mem_mpirun.py ${OPTIONS}
Heron-NVILA-Lite-15BのStage3学習で用いたシェルスクリプト
#!/bin/bash
#SBATCH --job-name=sft_14b_GPT4_v6
#SBATCH --time=600:00:00
#SBATCH --partition=aws-p5
#SBATCH --nodes=12
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=24
#SBATCH --output=sbatch_outputs/%j-%x.out
#SBATCH --error=sbatch_outputs/%j-%x.out
#SBATCH --nodelist=aws-p5-st-h100-[1-12]

DEFAULT_GLOBAL_TRAIN_BATCH_SIZE=2304
DEFAULT_GRADIENT_ACCUMULATION_STEPS=12

source scripts/setups/train_aws_v4.sh

OUTPUT_DIR="runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/$SLURM_JOB_NAME"

DATA_MIXTURE="llava_instruct_v1_5_en_subset+llava_instruct_ja+japanese_photos_conv+ja_vg_vqa+synthdog_ja_subset+ai2d_train_12k+synthdog_en+sherlock"

STAGE_PATH="runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/pretrain_14b/model"
VERSION="qwen2_sft_jp"
VISION_TOWER="/data/models/Efficient-Large-Model/paligemma-siglip-so400m-patch14-448"

OPTIONS=" \
  --deepspeed scripts/zero2.json \
  --model_name_or_path $STAGE_PATH \
  --data_mixture $DATA_MIXTURE \
  --vision_tower $VISION_TOWER \
  --version $VERSION \
  --mm_vision_select_feature cls_patch \
  --mm_projector mlp_downsample_3x3_fix \
  --tune_vision_tower True \
  --tune_mm_projector True \
  --tune_language_model True \
  --mm_vision_select_layer -2 \
  --mm_use_im_start_end False \
  --mm_use_im_patch_token False \
  --image_aspect_ratio dynamic \
  --bf16 True \
  --output_dir $OUTPUT_DIR/model \
  --num_train_epochs 1 \
  --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
  --evaluation_strategy no \
  --save_strategy steps \
  --save_steps 100 \
  --save_total_limit 1 \
  --learning_rate 2e-5 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type cosine \
  --logging_steps 1 \
  --model_max_length 4096 \
  --gradient_checkpointing True \
  --dataloader_num_workers 4 \
  --vflan_no_system_prompt True \
  --report_to tensorboard \
  --total_time_limit 36000 \
"

mpirun -np $NUM_GPUS \
  --npernode $NUM_GPU_PER_NODE \
  -x MASTER_ADDR=$MASTER_ADDR \
  -x MASTER_PORT=$MASTER_PORT \
  -x CUDA_LAUNCH_BLOCKING=1 \
  -x CUDA_DEVICE_MAX_CONNECTIONS=1 \
  -x LD_LIBRARY_PATH \
  -x PATH \
  --hostfile $hostfile \
  --mca pml ^cm,ucx \
  --mca btl tcp,self \
  --mca btl_tcp_if_exclude lo,docker0,veth_def_agent \
  --bind-to none -map-by slot \
  python llava/train/train_mem_mpirun.py ${OPTIONS}

苦労した点

計算資源はGENIACの支援で確保したH100インスタンス16ノードを使用し、2〜15ノードでマルチノード学習を行いました。

H100インスタンスのスペック
項目 AWS (P5.48xlarge)
GPU H100 × 8
CPU (2 Socket) AMD EPYC 7R13 @ 2.65 GHz, 48 cores × 2
計 96 cores / 192 threads
Memory 2,048 GB
SSD NVMe SSD 3.84 TB × 8
Interconnect AWS EFA 100 Gbps × 32
計 3.2 Tbps
Shared Storage FSx for Lustre (最大 40.8 GB/s)

最も苦労したのはマルチノード学習の安定化で、学習ジョブがエラーログも残さず途中で停止する現象に長く悩まされました。PyTorch・CUDA・cuDNN・NCCLのバージョンをNVIDIAのドキュメントに従って揃えても症状は完全には解消せず、外れノードを引き当てている可能性に行き着きました。
最終的には、

  1. 数百stepごとにcheckpointを保存
  2. 学習が止まったらノードを再起動してインスタンスを引き直す
  3. 保存したcheckpointから学習再開

という運用でジョブを走らせました。

その他に以下のような問題も起きました。

  • ストレージ容量
    • 画像件数が膨大であることから jpg/png で画像を保持するとinode(ファイル数)上限に達するため、一部の画像は tar/pkl に固めてファイル数を削減しました。また、学習に使わなかったデータも含めるとデータだけで50TB程度ストレージ容量が必要になりました。
  • データライセンス & 取得方法
    • データセットは研究目的限定/商用不可のものも多く、画像は URL のみの提供で img2dataset などで自前クロールが必要でした。テキストと画像で取得フローが異なる場合もありダウンロード作業が煩雑でした。
  • VRAM要件
    • マルチ画像入力により VRAM 消費が増大するバッチも発生しました。ノード追加・マイクロバッチサイズ削減などで調整しました。
  • 学習の安定化
    • 画像とテキストの不整合などが原因で発生したと推定される loss spike や 極端に長いテキスト、超高解像度画像を伴うサンプルによって OOM が発生しました。学習中で問題発生 → 問題とおぼしきデータを確認・除去 → 学習やり直し を繰り返しました。

評価

定量評価

日本語VLMの性能評価ツールであるllm-jp-eval-mmで日本語ベンチマークを評価しました。

llm-jp-eval-mmで使用した評価コード
# llm-jp-eval-mm/examples/heron_nvila.py
from base_vlm import BaseVLM
from utils import GenerationConfig
import torch
from transformers import GenerationConfig as HFGenerationConfig, AutoModel
import unittest


def create_prompt(text, image):
    if image is None or (isinstance(image, list) and len(image) == 0):
        return [text] if text else []
    if not isinstance(image, list):
        image = [image]
    if not text:
        return image
    if "<image>" not in text:
        prompt = image.copy()
        prompt.append(text)
        return prompt
    parts = text.split("<image>")
    prompt, idx = [], 0
    if parts[0] == "":
        prompt.append(image[idx])
        idx += 1
        parts = parts[1:]
    for i, part in enumerate(parts):
        if part:
            prompt.append(part)
        if idx < len(image) and (i < len(parts) - 1 or text.endswith("<image>")):
            prompt.append(image[idx])
            idx += 1
    return prompt


class VLM(BaseVLM):
    def __init__(self, model_id="turing-motors/Heron-NVILA-Lite-15B"):
        self.model_id = model_id
        self.model = AutoModel.from_pretrained(
            model_id, trust_remote_code=True, device_map="auto"
        )

    def generate(self, image, text: str, gen_kwargs: GenerationConfig = GenerationConfig()):
        gen_cfg = HFGenerationConfig(**gen_kwargs.__dict__)
        prompt = create_prompt(text, image)
        with torch.no_grad():
            return self.model.generate_content(prompt, generation_config=gen_cfg)


if __name__ == "__main__":
    unittest.main(argv=["first-arg-is-ignored"], exit=False)
    VLM("turing-motors/Heron-NVILA-Lite-15B").test_vlm()

LLM-as-a-Judgeの評価モデルは「gpt-4o-2024-11-20」を使用しています。 LLMが確率的に回答を生成するためスコアは再実行ごとにわずかに変わります。例えば、Heron Benchなら±0.3未満程度でスコアが上下します。Hugging Face のモデルページにもスコアを記載していますがあちらは評価モデルが「gpt-4o-2024-05-13」での結果です。LLM-as-a-Judgeは評価モデルによってスコアが大きく変わります。
同程度のパラメータ数のモデルと比較して特に良かったベンチマークは以下です。ベンチマークの全項目は列数が多いためトグルで隠しています。Heron-NVILA以外のスコアは2025/5/3時点のllm-jp-eval-mmリーダーボードの値です。

Model Heron /LLM VG-VQA /LLM JVB-ItW /LLM MulIm-VQA /LLM JMMMU /Acc
Heron-NVILA-Lite-1B 58.2 3.4 3.4 3.0 26.4
Heron-NVILA-Lite-2B 63.6 3.7 4.0 3.0 37.0
Heron-NVILA-Lite-15B 73.5 4.0 4.4 4.3 49.5
stabilityai/japanese-stable-vlm 48.4 3.5 3.3 2.3 3.4
SakanaAI/Llama-3-EvoVLM-JP-v2 47.6 3.4 3.5 3.1 36.4
cyberagent/llava-calm2-siglip 54.1 3.6 3.7 2.8 6.1
llm-jp/llm-jp-3-vila-14b 68.0 3.9 4.1 3.5 19.0
sbintuitions/sarashina2-vision-8b 60.5 3.7 4.1 2.6 39.2
sbintuitions/sarashina2-vision-14b 60.1 3.7 4.0 2.6 43.0
MIL-UT/Asagi-14B 41.9 2.0 2.9 2.0 21.7
OpenGVLab/InternVL2-26B 59.7 3.6 3.1 3.3 39.0
Qwen/Qwen2.5-VL-7B-Instruct 70.3 3.7 4.3 4.1 48.2
google/gemma-3-12b-it 72.2 3.7 4.3 4.2 47.6
microsoft/Phi-4-multimodal-instruct 45.5 3.3 3.2 3.4 39.2
llm-jp-eval-mmの全項目
Model Heron /LLM VG-VQA /LLM VG-VQA /Rouge JIC /Acc MECHA /Acc MMMU /Acc JVB-ItW /LLM JVB-ItW /Rouge LLAVA /LLM LLAVA /Rouge JDocQA /Acc JDocQA /LLM MulIm-VQA /LLM MulIm-VQA /Rouge JMMMU /Acc
Heron-NVILA-Lite-1B 58.2 3.4 14.5 48.0 29.0 22.3 3.4 42.7 2.4 33.2 15.2 2.1 3.0 28.2 26.4
Heron-NVILA-Lite-2B 63.6 3.7 17.0 46.4 48.2 35.9 4.0 48.8 2.9 35.4 17.7 2.4 3.0 42.7 37.0
Heron-NVILA-Lite-15B 73.5 4.0 25.9 80.9 68.3 47.1 4.4 50.6 3.6 36.6 19.2 2.7 4.3 60.0 49.5
stabilityai/japanese-stable-vlm 48.4 3.5 45.7 69.3 4.8 7.1 3.3 23.2 1.4 0.6 13.1 2.1 2.3 40.7 3.4
SakanaAI/Llama-3-EvoVLM-JP-v2 47.6 3.4 24.7 67.0 50.7 38.9 3.5 48.4 2.9 27.9 15.4 2.4 3.1 44.3 36.4
cyberagent/llava-calm2-siglip 54.1 3.6 17.7 58.5 11.2 26.7 3.7 46.3 1.9 2.9 8.2 2.0 2.8 40.6 6.1
llm-jp/llm-jp-3-vila-14b 68.0 3.9 16.2 81.3 45.6 32.7 4.1 52.4 3.4 36.0 17.3 2.5 3.5 47.0 19.0
sbintuitions/sarashina2-vision-8b 60.5 3.7 25.4 78.7 56.3 29.7 4.1 44.8 2.5 16.9 22.6 3.0 2.6 30.2 39.2
sbintuitions/sarashina2-vision-14b 60.1 3.7 25.3 80.0 64.4 33.8 4.0 44.3 2.5 15.6 23.9 3.1 2.6 35.3 43.0
MIL-UT/Asagi-14B 41.9 2.0 9.3 76.2 24.1 15.3 2.9 30.9 1.6 0.1 10.4 2.0 2.0 18.4 21.7
llava-hf/llava-1.5-7b 43.1 3.0 14.0 44.0 38.2 34.0 3.0 40.8 2.9 34.5 14.8 2.2 2.5 35.8 29.6
llava-hf/llava-v1.6-mistral-7b 30.0 3.0 11.7 58.2 34.4 35.9 2.9 28.6 3.3 30.7 14.4 2.0 2.3 25.2 25.4
neulab/Pangea-7B 57.0 4.1 54.2 85.6 56.5 43.7 3.9 33.5 3.5 25.9 16.2 2.4 3.4 40.3 37.4
mistralai/Pixtral-12B-2409 60.9 3.5 13.1 61.1 56.4 48.6 3.9 38.3 3.6 31.6 14.8 2.4 4.1 34.5 18.7
meta-llama/Llama-3.2-11B-VI 38.1 3.3 14.2 78.7 49.3 38.2 3.4 30.4 3.7 30.2 17.6 2.5 2.6 24.5 34.6
Efficient-LM/VILA1.5-13B 46.9 3.2 13.0 58.2 46.4 37.0 3.5 42.5 3.6 35.0 14.8 2.2 3.2 40.0 33.5
OpenGVLab/InternVL2-8B 49.8 3.5 11.7 65.7 50.5 49.7 3.5 33.8 3.1 31.5 19.8 2.7 2.9 34.5 39.1
OpenGVLab/InternVL2-26B 59.7 3.6 11.6 73.9 51.3 48.2 3.1 26.7 3.8 30.6 15.3 2.6 3.3 45.4 39.0
Qwen/Qwen2.5-VL-7B 70.3 3.7 9.0 82.7 60.0 50.0 4.3 29.6 3.9 27.1 26.5 3.6 4.1 50.4 48.2
Qwen/Qwen2.5-VL-32B 74.8 3.8 5.2 96.0 68.4 59.1 4.3 14.7 4.0 18.9 25.2 3.8 4.6 42.2 48.8
Qwen/Qwen2.5-VL-72B 85.5 3.9 9.8 90.4 74.7 63.0 4.4 32.0 4.0 28.7 23.9 3.9 4.8 60.9 60.6
google/gemma-3-4B-it 52.8 3.4 12.5 75.4 47.2 40.7 3.7 37.1 3.6 22.1 17.6 2.6 3.7 52.7 37.0
google/gemma-3-12B-it 72.2 3.7 12.5 85.7 62.6 48.1 4.3 35.7 4.0 22.1 20.1 3.0 4.2 59.7 47.6
google/gemma-3-27B-it 69.2 3.8 10.9 88.2 67.7 56.1 4.4 30.9 3.9 21.1 20.2 3.1 4.3 56.3 50.5
microsoft/Phi-4-multimodal 45.5 3.3 19.0 52.3 46.4 53.7 3.2 26.8 3.4 29.5 22.9 2.9 3.4 42.3 39.2
gpt-4o-2024-11-20 93.7 3.9 11.8 95.8 83.7 56.1 4.4 32.2 4.1 29.8 22.0 3.6 4.8 62.5 57.5
各列の説明
  • Model: モデル名
  • Heron/LLM: GPT-4o(gpt-4o-2024-11-20) で評価したJapanese Heron Bench の総合スコア。Heron Bench は日本の風景や文化を写した画像を説明できるかを評価。
  • VG-VQA/LLM, Rouge: JA-VG-VQA500 を GPT-4o(gpt-4o-2024-11-20) で評価したスコアとROUGE-L 類似度。JA-VG-VQA500 はVisual Genome の日本語 VQA サブセット500 問。
  • JIC/Acc: JIC-VQA の正解率。JIC-VQA は日本の食事・花・施設などを写した画像に対する多肢選択問題。
  • MECHA/Acc: MECHA-ja の正解率。MECHA-ja は日本国内で撮影された画像に対する多肢選択問題。
  • MMMU/Acc: MMMU の正解率。MMMU は大学レベルの知識と推論能力を評価する多肢選択問題。
  • JVB-ItW/LLM, Rouge: JA-VLM-Bench-In-the-Wild を GPT-4o(gpt-4o-2024-11-20) で評価したスコアとROUGE-L 類似度。JA-VLM-Bench-In-the-Wild はLLaVA-Bench-In-the-Wild の日本語版。
  • LLAVA/LLM, Rouge: LlaVA-Bench-In-the-Wild を GPT-4o(gpt-4o-2024-11-20) で評価したスコアとROUGE-L 類似度。LlaVA-Bench-In-the-Wild は風景写真やミーム、絵画など実世界の画像を説明できるかを評価する。
  • JDocQA/Acc, LLM: JDocQA の正解率と GPT-4o(gpt-4o-2024-11-20) で評価したスコア。JDocQA は日本語PDF から答えを生成するドキュメント QA。
  • MulIm-VQA/LLM, Rouge: JA-Multi-Image-VQA を GPT-4o(gpt-4o-2024-11-20) で評価したスコアとROUGE-L 類似度。JA-Multi-Image-VQA は2〜4 枚の画像を見て答えるマルチイメージ VQA。
  • JMMMU/Acc: JMMMU の正解率。JMMMU は日本版 MMMU。

定性評価

Heron-Benchの質問に対するHeron-NVILAの回答と筆者の主観評価を残します。

  • 〇 = 質問意図を正しく汲み取り、内容も概ね的確
  • × = 要求に合っていない、もしくは明らかな誤答

質問: 信号機の色は何色ですか?

主観評価 Model 回答
Heron-NVILA-Lite-15B 信号機の色は赤色です。
Heron-NVILA-Lite-2B 信号機の色は赤色です。
Heron-NVILA-Lite-1B 信号機の色は赤色です。

質問: こたつと言えば、何の果物が思い浮かびますか?

主観評価 Model 回答
Heron-NVILA-Lite-15B こたつと言えば、ミカンが思い浮かびます。ミカンは冬の果物として人気があり、こたつの上に置かれたミカンのボウルは、温かい部屋でリラックスしながら楽しむための典型的なシーンを表しています。
Heron-NVILA-Lite-2B こたつと言えば、一般的にはみかんが思い浮かびます。この画像では、こたつの上にいくつかのオレンジが置かれていますが、みかんである可能性が高いです。みかんなのは、冬の寒い季節に温かいお茶やお菓子と一緒に楽しむのにぴったりな果物です。また、こまごまとした果物が置かれたテーブルは、和風の雰囲気を醸し出しており、居心地の良い空間を演出しています。
× Heron-NVILA-Lite-1B こたつと言えば、一般的にオレンジが思い浮かびます。この画像では、こたつの上にオランジェットが置かれています。オレンヂは、温かみのある果物として人気があり、特に冬の季節に温かい飲み物として好まれます。また、こだわりのこたつなデザインは、居心地の良さやリラックスした雰囲気を演出するため、家族や友人と一緒に過ごす場面にぴったりです。このように、オレンチーや果物は、日本の伝統的な生活スタイルや家庭の温かさを象徴する要素として、多くの人に愛されています。

質問: この場所における制限速度はいくつですか?

主観評価 Model 回答
Heron-NVILA-Lite-15B この場所の制限速度は40キロメートル毎時(40 km/h)です。
× Heron-NVILA-Lite-2B この場所の制限速度は23キロメートル/時です。
× Heron-NVILA-Lite-1B この場所では、制限速度は23キロメートル/時(約40キロ)です。

質問: 左の真ん中あたりには何が書かれていますか?

主観評価 Model 回答
Heron-NVILA-Lite-15B 左の真ん中あたりには「雪舟」と書かれています。
× Heron-NVILA-Lite-2B 左の真ん中あたりには「守備」と書かれています。
× Heron-NVILA-Lite-1B 左の真ん中あたりには「夢」の文字が書かれています。

4件だけですが、〇の数とパラメータ数は比例しており 15 B > 2 B > 1 B で性能差があることがわかります。

おわりに

GENIAC第2期の支援のもと開発したVLM Heron-NVILA を紹介しました。モデルは Hugging Face でダウンロードできるのでぜひ触ってみてください!!!

Heron-NVILA-Lite-2B を搭載した以下のようなiOSアプリも公開予定です。

https://x.com/ymg_aq/status/1904824462421156113

アプリの解説記事はこちらです。

https://zenn.dev/turing_motors/articles/8eeda949dc7e2a

最後に宣伝ですが、私の所属するチューリングの基盤AIチームは自動運転への活用を目指してVLMや動画生成モデルなどさまざまな基盤AIの開発を行っています。この記事を読んで興味を持っていただけた方は応募検討してもらえると嬉しいです。カジュアル面談やオフィス見学などライトな交流も大歓迎です!!!
詳細についてはチューリングの採用情報をご覧ください。
https://tur.ing/jobs

Tech Blog - Turing

Discussion