🏪

論文実装:AIエージェントによる店内移動

に公開

論文実装:AIエージェントによる店内移動

今回の論文

Sari Sandbox - A Virtual Retail Store Environment for Embodied AI Agentsという論文を読んで、AIエージェントによる店内移動という内容が興味深かったので、

まず、論文の内容ですが、ざっくり次のような感じです。(英語があまり得意ではないので、違っていたらご指摘ください。)

概要

本研究は、実店舗の小売環境を模倣した仮想空間「Sari Sandbox」を構築し、エンボディドAIエージェントの学習・評価を可能にすることを目的としている。エージェントは棚の間を移動し、商品を探索・認識・操作するタスクを実行できる。環境は高い再現性を持ち、複雑な小売シナリオをシミュレートすることで、ロボティクスやマルチモーダルAI研究の新たなベンチマークとなる。

目的・背景

小売分野では、店舗内の自動化(在庫確認、商品補充、顧客案内など)が求められている。現実環境でのロボット学習にはコストやリスクが伴うため、仮想空間での学習が重要視されている。既存のエンボディドAI研究用環境(例:AI2-THOR、Habitat)は家庭やオフィスに焦点を当てており、小売特有の課題(多様な商品、密集した棚配置、複雑な動線)を反映できていない。そこで小売に特化したシミュレーション環境の開発が必要とされた。

調査・実験・検証内容

  • 環境設計

    • 棚、SKU(商品)、ゴールエリアを配置。

    • エージェントは移動、商品操作などの行動を実行可能。

    • 状態は視覚情報とメタデータで提供される。

  • タスク設定

    • 商品探索、ピッキング、棚操作など。

    • 部分的観測に基づく意思決定を必要とする。

  • 評価

    • 強化学習やプランニング手法を適用し、ベースライン性能を測定。

    • 複数の難易度・シナリオでエージェントの汎用性を検証。

実施内容

1) 環境の要点(sari_sandbox.py

  • 行動(離散6):0:noop, 1:up, 2:down, 3:left, 4:right, 5:pick

  • 成功条件:全SKUを1つずつピック後、右下ゴールに到達

  • 観測(Dict):

    • grid (H,W,3):擬似RGB(ch0=レイアウト/エンティティ, ch1=SKU ID, ch2=ピック済み)

    • inventory (K,):所持SKU数

    • goal (2,):正規化ゴール座標

  • 報酬(例):

    • ステップ負報酬(時間コスト)

    • noopペナルティ

    • 正しいpick報酬

    • 壁/棚衝突ペナルティ

    • 全回収後のゴール到達報酬

主要コード抜粋

# 環境設定
@dataclass
class TaskConfig:
    grid_size: int = 9
    n_shelves: int = 10
    n_skus: int = 3
    max_steps: int = 200
    pick_radius: int = 1
    # 報酬
    reward_step: float = -0.005
    reward_noop: float = -0.01
    reward_pick: float = 1.0
    reward_goal: float = 2.0
    reward_wrong_pick: float = -0.1
    reward_wall: float = -0.02
    with_obstacles: bool = True

class SariSandboxEnv(gym.Env):
    # 行動空間・観測空間を定義
    self.action_space = spaces.Discrete(6)
    grid_space = spaces.Box(low=0, high=255,
                            shape=(N, N, 3), dtype=np.uint8)
    inv_space  = spaces.Box(low=0, high=1, shape=(self.cfg.n_skus,), dtype=np.float32)
    goal_space = spaces.Box(low=0, high=1, shape=(2,), dtype=np.float32)
    self.observation_space = spaces.Dict(grid=grid_space, inventory=inv_space, goal=goal_space)

    # 代表的な遷移
    def _move(self, action) -> float:
        # 境界外や棚セルなら壁ペナ
        ...
        return 0.0

    def _pick(self) -> float:
        """マンハッタン距離 <= pick_radius 内の棚から未回収SKUを優先して1つピック"""
        self._last_picked_idx = -1
        ...
        if target:   # 正しいピック
            y, x, idx = target
            self.inventory[idx] += 1
            self._sku_map[y, x] = 0
            self._last_picked_idx = idx
            return self.cfg.reward_pick
        else:        # 周囲に対象SKUがないならミス扱い
            return self.cfg.reward_wrong_pick

    def step(self, action):
        reward = self.cfg.reward_step
        if action in (1,2,3,4): reward += self._move(action)
        elif action == 5:       reward += self._pick()
        else:                   reward += self.cfg.reward_noop  # 待機にコスト
        ...
        if self._all_picked() and self._at_goal():
            reward += self.cfg.reward_goal
            terminated = True
        ...
        return obs, reward, terminated, truncated, info

補足:棚はランダム配置(開始/ゴールは避ける)、棚セル上にSKUを巡回配置。make_env()TaskConfigを引き渡せるようにしており、grid_sizepick_radiusなどを容易にスイープできます。

2) まずは最小実行(main.py

ランダム移動+たまにpickでも、render()inv=[...]が増えるのを確認できます。

# python main.py
from sari_sandbox import make_env
env = make_env(seed=42)
obs, _ = env.reset()
print(env.render())

for t in range(100):
    a = env.action_space.sample()
    if t % 3 == 0:
        a = 5  # こまめに pick
    obs, r, term, trunc, _ = env.step(a)
    if term or trunc:
        break

print(env.render())  # inv=[...] が 1つ以上ならOK

実行すると、次のような結果がでます。これが初期状態です。
エージェント(A)と、ピックアップ対象のSKU 1, 2, 3、ゴール(G)が配置されています。

そして、コードのrange(100)の通り、100回エージェントが移動したあとの結果として、次のような結果がでます。

この結果だと、100回の処理でinvで[1,0,1]なので、SKU1と3はPICKできたけど、2はできていない。
ゴールとは遠い位置で止まっている。

という結果がでます。

3) デバッグ用の貪欲方策(debug_episode.py

「最寄りのSKUへ最短で寄る。射程内ならpick」という単純方策で、1ステップごとのログとレンダリングを確認できます。壁・棚に阻まれるケースや、pick_radiusの効き方を体感しやすいです。

def nearest_target(env):
    ay, ax = env.agent_pos
    coords = np.argwhere(env._sku_map > 0)
    if coords.size == 0:
        return env.goal_pos
    dists = [abs(int(y)-ay) + abs(int(x)-ax) for (y, x) in coords]
    return tuple(coords[int(np.argmin(dists))])

def greedy_step(env):
    ay, ax = env.agent_pos
    ty, tx = nearest_target(env)
    if abs(ay-ty) + abs(ax-tx) <= env.cfg.pick_radius:
        return 5  # pick
    # マンハッタン距離が縮む方向を優先(進入不可はスキップ)
    for a, (dy, dx) in [(1,(-1,0)), (2,(1,0)), (3,(0,-1)), (4,(0,1))]:
        ny, nx = ay+dy, ax+dx
        if (ty<ay and a==1) or (ty>ay and a==2) or (tx<ax and a==3) or (tx>ax and a==4):
            if 0 <= ny < env.cfg.grid_size and 0 <= nx < env.cfg.grid_size and env._grid_layout[ny, nx] != 1:
                return a
    return 0  # noop

# 実行
env = make_env(seed=42)
obs, _ = env.reset()
print("t=-1", env.render(), sep="")
for t in range(200):
    a = greedy_step(env)
    obs, r, term, trunc, info = env.step(a)
    y, x = info["pos"]
    picked = info["picked_idx"]
    picked_str = f" picked=SKU{picked+1}" if picked >= 0 else ""
    print(f"t={t:03d} act={a} r={r:+.3f} pos=({y},{x}) inv={info['inventory'].tolist()}{picked_str}")
    print(env.render())
    if term or trunc:
        break

デバッグのプログラムだと、1ステップずつ処理を可視化します。
上記の例だとrangeを200にしていますが、流石に長いので下記の結果は、3にした場合です。

actで、downやrightなどエージェントが動いている様子を把握できます。

4) PPOで学習(train_ppo.py

さらに、これを効率化させるための学習がこちらです。
観測は画像(grid)+ベクトル(inventory, goal)の複合なので、Stable-Baselines3のMultiInputPolicyに合わせてカスタム特徴抽出器を用意します。

# pip install stable-baselines3 torch gymnasium

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch as th
import torch.nn as nn

from sari_sandbox import SariSandboxEnv, TaskConfig

# 画像CNN + ベクトルMLP を結合
class GridAndVecExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        grid_space = observation_space["grid"]
        C, H, W = grid_space.shape[-1], grid_space.shape[0], grid_space.shape[1]

        self.cnn = nn.Sequential(
            nn.Conv2d(C, 16, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(16, 32, 2, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1), nn.ReLU(),
            nn.Flatten(),
        )
        with th.no_grad():
            dummy = th.zeros((1, C, H, W))
            cnn_out = self.cnn(dummy).shape[1]

        vec_dim = observation_space["inventory"].shape[0] + observation_space["goal"].shape[0]
        self.mlp_vec = nn.Sequential(nn.Linear(vec_dim, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU())
        self.linear = nn.Sequential(nn.Linear(cnn_out + 64, features_dim), nn.ReLU())

    def forward(self, obs):
        grid = obs["grid"].float() / 255.0      # (B,H,W,C)
        grid = grid.permute(0, 3, 1, 2)         # (B,C,H,W)
        vec  = th.cat([obs["inventory"].float(), obs["goal"].float()], dim=1)
        return self.linear(th.cat([self.cnn(grid), self.mlp_vec(vec)], dim=1))

def make_train_env(seed=42, cfg):
    def _mk():
        return SariSandboxEnv(TaskConfig(cfg), seed=seed)
    return _mk

def train(total_timesteps=100_000, n_envs=4, seed=42):
    env = make_vec_env(make_train_env(seed=seed, grid_size=9, n_shelves=10, n_skus=3, max_steps=200), n_envs=n_envs)

    policy_kwargs = dict(
        features_extractor_class=GridAndVecExtractor,
        features_extractor_kwargs=dict(features_dim=256),
        net_arch=[256, 256],
    )

    model = PPO(
        policy="MultiInputPolicy",
        env=env,
        learning_rate=3e-4,
        n_steps=2048 // n_envs * n_envs,  # 2048をn_envsで割り切る
        batch_size=1024,
        n_epochs=4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        policy_kwargs=policy_kwargs,
        seed=seed,
        verbose=1,
        device="auto",
    )

    model.learn(total_timesteps=total_timesteps)
    model.save("./models/sari_ppo.zip")
    print("Saved ./models/sari_ppo.zip")

if __name__ == "__main__":
    train()

学習したもの("/models/sari_ppo.zip")を反映して、動かすのですが、あまり動きが改善されていないのは、上手く学習ができていないのかもしれません。

コード全体

コードは下記でまとめています。

https://github.com/karasu1982/article_sari_sandbox

おわりに

本稿では、論文で提示されるリッチなVR環境をいきなり完全再現するのではなく、2Dグリッドの最小タスクへ落とし込みました。

学習のところが、まだまだ甘いので、もう少し上手く実装できるようにしたいです。

とはいえ、2Dグリッドであれば、これくらい簡単に実装できるんだというのをまとめられたかと思います。

Discussion