YOLOv8を使って野球のバッティング動作の骨格推定

2024/12/29に公開

はじめに

先日開催されたソフトバンクホークスのデータサイエンス業務の体験会が福岡県の筑後で開催されていて、それに参加してきました。(交通費負担がかなりきつかった)
https://www.softbankhawks.co.jp/news/detail/202400637726.html

その中のセッションの一つに野球の動作の姿勢の推定してみるというものがあり、自分でも出来そうだなと思ったので、実装してみました。

使用したモデルと環境

  • Model : YOLOv8
  • 環境 : Google colab (CPUを使って推定できるのでGPUを使う必要はありません)

参考にしたページなど

このページでは画像を用いて姿勢推定していました。
https://zenn.dev/collabostyle/articles/21ebb9ac52c744
お茶の水女子大学のこのページも非常に参考になりました。
https://tsuchidalab.jp/basic/yolov8/

YOLOv8

詳しい説明は、公式のサイトがあるので以下を参考にされると良いと思います。
https://docs.ultralytics.com/ja/models/yolov8/#supported-tasks-and-modes
YOLOv8ができることは

  • セグメンテーション
  • 骨格の推定
  • 物体検出

などです。骨格推定をするならOpenPoseもあるのですが、個人的に新しいもの好きなのと、実装が簡単そうということ、計算が鬼早くてCPUで済むことと、OpenPoseは商用利用が許されていないこと、諸々を踏まえて使いませんでした。(あとシンプルに骨格の線が太くてダサい。。。)

使用する動画

今回は、大谷翔平選手のバッティングフォームを使用します。
https://youtube.com/shorts/OfZermjPzqo?si=bDWxB0Vo3cfbmXpS

骨格推定の実装

1.Google Driveのマウント

Google Driveに入力動画や出力動画を保存するので、まずはマウントします。

from google.colab import drive
drive.mount('/content/drive')
%cd ./drive/MyDrive

2.必要なライブラリのインストール

ライブラリをimportします。

try:
  from ultralytics import YOLO
except:
  %pip install ultralytics >& /dev/null
  from ultralytics import YOLO
import cv2
import csv
import os
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation

3. 使用するモデルの選択

YOLOv8は姿勢推定の学習済みモデルとして

  • yolov8n-pose.pt
  • yolov8s-pose.pt
  • yolov8m-pose.pt
  • yolov8l-pose.pt
  • yolov8x-pose.pt
  • yolov8x-pose-p6.pt

があります。下に行くほど精度はいいですが、学習に時間がかかります。
使いたいモデルをコメントアウトすることで使えるようにします。

model=YOLO(
    #'yolov8n-pose.pt'
    'yolov8s-pose.pt'
    #'yolov8m-pose.pt'
    #'yolov8l-pose.pt'
    #'yolov8x-pose.pt'
    #'yolov8x-pose-p6.pt'
)

今回は、yolov8s-pose.ptを使用します。

[補足]3.1 変数の管理

変数の管理をConfigクラスを作ってまとめることが多いのでそのようにしていますが、もっと上手い方法はあると思います。

class CFG:
  video_name='Ohtani'# ここを変えることでプロジェクトを変更させる
  root_path=f'/content/drive/MyDrive/dataset/{video_name}'
  input_video_path=f'{root_path}/InputVideo.mp4'
  output_video_path=f'{root_path}/OutputVideo.mp4'
  csv_path=f'{root_path}/Output.csv'

  KEYPOINTS_NAMES=[
    'nose',
    'eye(L)','eye(R)','ear(L)','ear(R)',
    'shoulder(L)','shoulder(R)','elbow(L)','elbow(R)',
    'wrist(L)','wrist(R)','hip(L)','hip(R)',
    'knee(L)','knee(R)','ankle(L)','ankle(R)'
    ]

  connections = [
      ('nose', 'eye(L)'), ('nose', 'eye(R)'), ('eye(L)', 'ear(L)'), ('eye(R)', 'ear(R)'),
      ('nose', 'shoulder(L)'), ('nose', 'shoulder(R)'), ('shoulder(L)', 'elbow(L)'),
      ('shoulder(R)', 'elbow(R)'), ('elbow(L)', 'wrist(L)'), ('elbow(R)', 'wrist(R)'),
      ('shoulder(L)', 'shoulder(R)'), ('shoulder(L)', 'hip(L)'), ('shoulder(R)', 'hip(R)'),
      ('hip(L)', 'hip(R)'), ('hip(L)', 'knee(L)'), ('hip(R)', 'knee(R)'),
      ('knee(L)', 'ankle(L)'), ('knee(R)', 'ankle(R)')
    ]

  keypoints = {
      'nose': 0, 'eye(L)': 1, 'eye(R)': 2, 'ear(L)': 3, 'ear(R)': 4,
      'shoulder(L)': 5, 'shoulder(R)': 6, 'elbow(L)': 7, 'elbow(R)': 8,
      'wrist(L)': 9, 'wrist(R)': 10, 'hip(L)': 11, 'hip(R)': 12,
      'knee(L)': 13, 'knee(R)': 14, 'ankle(L)': 15, 'ankle(R)': 16
  }


cfg=CFG()

4.動画のダウンロード

Youtubeから動画をダウンロードします。
リンクのところに先ほどの動画のリンクを入れればダウンロードすることができます。

!pip install yt-dlp >& /dev/null

from yt_dlp import YoutubeDL

#最高の画質と音質を動画をダウンロードする
ydl_opts = {
  'format': 'best',
  'outtmpl': cfg.input_video_path, #ここでビデオのpathを指定する。
}

#動画のURLを指定
with YoutubeDL(ydl_opts) as ydl:
    ydl.download(['https://youtube.com/shorts/OfZermjPzqo?si=bDWxB0Vo3cfbmXpS'])

5. 動画のフレームレートの変更と再生

動画のフレームレートを30に変更します。

!ffmpeg -y -i {cfg.input_video_path} -vf "fps=30" {cfg.root_path}/fp30.mp4

次に、ダウンロードした動画を再生してみます。

### 動画を再生して確認する
from IPython.display import HTML
from base64 import b64encode

mp4 = open( f'{cfg.root_path}/fp30.mp4', 'rb').read()
data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()
HTML(f"""
<video width="50%" height="50%" controls>
      <source src="{data_url}" type="video/mp4">
</video>""")

すると、動画が再生されます。(動画はアップロードできないのでここでの図示は割愛します。gifに変換すれば載せられるようですが、3MB以下にする必要があるみたいです。)

6.モデルの関数の実装

まずは、ビデオの情報を受け取ってを主力する関数を設定します。

def setup_video_writer(capture,output_path):
  fourcc=cv2.VideoWriter_fourcc(*'mp4v')
  fps=capture.get(cv2.CAP_PROP_FPS)
  width=int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
  height=int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
  return cv2.VideoWriter(output_path, fourcc, fps, (width, height))

次に、骨格の座標をcsvファイルに出力する関数を設定します。

def write_pose2csv(csv_path,frame_count,keypoints,confs):
  row=[frame_count]
  for index,keypoint in enumerate(zip(keypoints,confs)):
    x,y=int(keypoint[0][0]),int(keypoint[0][1])
    score=keypoint[1]
    row.extend([x,y,score])
  with open(csv_path,mode='a',newline='') as file:
    writer = csv.writer(file)
    writer.writerow(row)

次に、フレームにキーポイントの名前と骨格を描画する関数を設定します。

def draw_keypoints(frame, keypoints, confs):
    """フレームにキーポイントと骨格を描画する"""
    coords=[]#座標を格納するリストを先に作っておく

    for index, keypoint in enumerate(zip(keypoints, confs)):
        x, y = int(keypoint[0][0]) ,int(keypoint[0][1])
        score = keypoint[1]
        #print(f'{cfg.KEYPOINTS_NAMES[index]}のkeypointの座標とスコアは{keypoint}')
        if score >= 0.5:
            cv2.circle(frame, (x, y), 5, (255, 0, 255), -1)
            cv2.putText(frame, cfg.KEYPOINTS_NAMES[index], (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 255), 1, cv2.LINE_AA)
        coords.append((x,y))#keypointの順番にx座標とy座標をタプル型で追加していく
    #print(coords)

    #for connection in cfg.connections:
    for i, connection in enumerate(cfg.connections):
      if i not in [0, 1, 2, 3 ,4 ,5]:#耳は推定しにくいから外したいのでconnectionの中で鼻より上の接続は省く
      #print(connection)
        nodeA , nodeB = cfg.keypoints[connection[0]] , cfg.keypoints[connection[1]]#keypointsの引数に文字を指定することで返り値がキーポイントに割り当てられている整数になる
        cv2.line(frame, coords[nodeA], coords[nodeB], (0, 255, 0), 2)
    return frame

最後に、modelを使ってキーポイントの推定をする関数を設定します。


def process_video(input_video_path, output_video_path, csv_path):
    """動画をフレームごとに処理し、姿勢情報を取得してCSVと動画に保存する"""
    capture = cv2.VideoCapture(input_video_path)
    video_writer = setup_video_writer(capture, output_video_path)

    with open(csv_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        # ヘッダー行を書き込む
        header = ["frame"]
        for name in cfg.KEYPOINTS_NAMES:
            header.extend([f"{name}_x", f"{name}_y", f"{name}_score"])
        writer.writerow(header)

    frame_count = 0

    while capture.isOpened():
        success, frame = capture.read()
        if not success:
            break

        # 推論を実行
        results = model(frame)

        if results[0].keypoints.conf!=None: #フレームにキーポイントがない場合は推論を実行しない
          if len(results[0].keypoints) > 0:
              keypoints = results[0].keypoints

              confs = keypoints.conf[0].tolist()  # 推論結果:1に近いほど信頼度が高い
              xys = keypoints.xy[0].tolist()  # 座標

              # 姿勢情報をCSVファイルに書き出す
              write_pose2csv(csv_path, frame_count, xys, confs)
              # キーポイントと骨格をフレームに描画する
              frame = draw_keypoints(frame, xys, confs)


        # フレームに骨格情報を描画したものを動画に書き出す
        video_writer.write(frame)
        frame_count += 1

    capture.release()
    video_writer.release()
    cv2.destroyAllWindows()

6.1 補足

keypointの構造は先ほどのコードをコメントアウトすると形状がわかるのですが、鼻を0、右足首を16として、16個のリストがあり、それぞれの要素一つ一つに[(x座標の値、y座標の値) , スコア]が格納されています。

7.推論の実行

先ほど設定した関数を使って、推論を実行します。%%timeを使って実行時間も測っておきます。

%%time
process_video(f'{cfg.root_path}/fp30.mp4', cfg.output_video_path, cfg.csv_path)

結果は以下のように出力され、実行に2分もかかりませんでした。

CPU times: user 1min 31s, sys: 3.6 s, total: 1min 35s
Wall time: 1min 43s

出力は動画で出てくるのですが、ここには結果の動画の一部の画像を載せます。

右足首が少しずれていますが、概ね予測できています。事前学習なしでこの精度はすごいですね。

8.csvファイルから骨格だけを抽出した動画の作成

先ほど保存したkeypointの座標から、keypointだけの動きを図示してみます。

df=pd.read_csv(cfg.csv_path)

def visualize_df(df):
  fig=plt.figure(figsize=(10,8))
  ax=fig.add_subplot(111)

  images=[]

  keypoint_names = [col.split('_')[0] for col in df.columns[1::3]]#ここでスコア以外の座標の情報を取得

  for frame in tqdm(range(len(df))):
    x_coord=df.iloc[frame,1::3]
    y_coord=df.iloc[frame,2::3]

    # フレームごとのartistsを格納するリスト
    frame_artists = []

    scatter=ax.scatter(x_coord,y_coord,c='blue')#キーポイントをプロット
    frame_artists.append(scatter)

    for i , name in enumerate(keypoint_names):
      text=ax.text(x_coord[i], y_coord[i], name, fontsize=9, ha='right')#ここで名前をプロットする
      frame_artists.append(text)

    for i, connection in enumerate(cfg.connections):
      if i not in [0, 1, 2, 3 ,4 ,5]:#例によってここは線から外す
        x1, y1 = x_coord[cfg.keypoints[connection[0]]], y_coord[cfg.keypoints[connection[0]]]
        x2, y2 = x_coord[cfg.keypoints[connection[1]]], y_coord[cfg.keypoints[connection[1]]]
        dx, dy = x2 - x1, y2 - y1
        arrow=ax.arrow(x1, y1, dx, dy, head_width=5, head_length=5, fc='gray', ec='gray', length_includes_head=True)
        frame_artists.append(arrow)

    title=ax.text(0.5, 1.01, 'frame ratio is {:.2f}'.format(frame/len(df)),
                     ha='center', va='bottom',
                     transform=ax.transAxes, fontsize='large')
    
    images.append(frame_artists+[title])

  plt.xlabel('X Coordinates')
  plt.ylabel('Y Coordinates')

  plt.grid(True)
  plt.gca().invert_yaxis()  # 原点を左上にするためにy軸を反転する
  plt.axis('equal')  # 縦軸と横軸のスケールを同じにする

  ani=animation.ArtistAnimation(fig,images, interval=50, blit=True, repeat_delay=1000) # interval, blit, repeat_delay を設定

  ani.save(f'{cfg.root_path}/skelton.gif', writer="ffmpeg")


  plt.show()

これを実行すると下のようなgifファイルが生成されます。

大谷選手は左打ちなのでフォルースルーの際に右側が隠れてしまってうまく予測できていませんが、概ねできているようです。

精度の良さに驚くばかりですね!!

9.骨格のキーポイントが通過した軌跡を可視化する

骨格を推定しただけではあまりスイングの特徴をつかめません。そこで、キーポイントが通った軌跡を可視化しようと思います。
(キーポイント可視化について書いている記事が見当たらず、書くのにとても苦労しました。)

%%time
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

# Keypoints to visualize (replace with your list)
keypoints = cfg.KEYPOINTS_NAMES

# Assuming df_coords is your DataFrame containing keypoint data
fig, ax = plt.subplots()
ax.set(xlim=[-200, 500], ylim=[0, 500])

scat_plots = {}



for keypoint in keypoints:
  x_coords = df[f"{keypoint}_x"]
  y_coords = df[f"{keypoint}_y"]
  scat_plots[keypoint] = ax.scatter(x_coords[1], y_coords[1], label=keypoint, c=f'C{len(scat_plots)}',s=8)


bone_lines = {}
for start, end in cfg.connections:
    bone_lines[(start, end)] = ax.arrow(
        0, 0, 0, 0, color='gray', linewidth=1, head_width=3, head_length=3
    )

ax.legend()



def update(frame):
  for keypoint, scat in scat_plots.items():
    x = df[f"{keypoint}_x"][:frame]
    y = df[f"{keypoint}_y"][:frame]
    data = np.stack([x, y]).T
    scat.set_offsets(data[-10:])#ここの数字を変えることで何frame分の軌跡を表示するかを決められる

  for (start, end), arrow in bone_lines.items():
      start_x = df[f"{start}_x"][frame]  # Access single value for current frame
      start_y = df[f"{start}_y"][frame]
      end_x = df[f"{end}_x"][frame]
      end_y = df[f"{end}_y"][frame]
      dx = end_x - start_x
      dy = end_y - start_y
      arrow.set_data(
          x=start_x, y=start_y, dx=dx, dy=dy
      )  # Update arrow with dx and dy



  return list(scat_plots.values()) + list(bone_lines.values())
plt.xlabel('X Coordinates')
plt.ylabel('Y Coordinates')
plt.title(f'{cfg.video_name} : batting')
plt.grid(True)
plt.gca().invert_yaxis()  # 原点を左上にするためにy軸を反転する
plt.axis('equal')

ani = animation.FuncAnimation(fig=fig, func=update, frames=len(df), interval=30*4)
#intervalで再生速度を変更できる。元のframeが30なので、4倍したら4分の1で再生されるようになる
ani.save(f'{cfg.root_path}/trajectory.gif', writer="ffmpeg")
plt.show()

これを実行すると、以下のようなgifが出力されます。

これで、骨格とキーポイントが通過した軌跡を確認することができました。
軌跡を見る限りは、割とアッパースイングで、打ち出してから右足が着くまでは重心が軸足に寄っていて、右足がついた後、腰から順に体が連動して動いているのがわかります。

10.最後に

今回は大谷選手の骨格の推定までを行いました。
キーポイントで見るとやはり体の動き方が見易いですね。
これを参考に自分のスイングもより良くしていきたいです。

Discussion