🏇

【ML-Agents】Unity×強化学習でドローンを自動操縦したい

2023/06/12に公開

完成モデル

UnityでML-Agentを用いて、ドローンを自動操縦するモデルを作成しました。


ずっと見てると酔いそう

開発環境

  • OS: macOS Ventura Version13.2.1
  • ML-Agent Version release 20

What is ML-Agents?

ML-AgentsとはUnity Machine Learning Agentsのことで、Unity を使って機械学習を行うことのできるフレームワークです。

以下より、「Release 20」の「download」からダウンロードし、Unityのパッケージマネージャーからインポートします。

https://github.com/Unity-Technologies/ml-agents

ML-Agentsによる強化学習

1. 強化学習の各要素

項目 内容
行動 ドローンの左右移動/前後移動 / 上昇 / 下降 / 回転
報酬 加算:チェックポイントの通過、減算:壁への衝突
観察 ドローンのRigidbodyの速度ベクトル / 現在の回転角度

2. 学習設定ファイル

「学習設定ファイル」は、学習に利用する「ハイパーパラメータ」を設定するファイルです。「ハイパーパラメータ」とは、学習アルゴリズムの持つパラメータの中で人が調節しないといけないパラメータのことです。
「ml-agents-release_20」フォルダの「config」フォルダ直下に、「drone_training.yaml」ファイルを作成し、以下を記述します。

config/drone_training.yaml
behaviors:
  DroneAgent:
    trainer_type: ppo # Use PPO (Proximal Policy Optimization) for training algorithm
    hyperparameters:
      batch_size: 128 
      buffer_size: 2048 # Experience Replay Buffer Size
      learning_rate: 3.0e-5 
      beta: 0.01 # Target value of KL divergence
      epsilon: 0.2 # Clipping parameters for PPO
      lambd: 0.95 # GAE (Generalized Advantage Estimation) lambda parameter
      num_epoch: 3 # epoch number
      learning_rate_schedule: linear # Scheduling Methods for Learning Rates
    network_settings:
      normalize: false # No input normalization
      hidden_units: 1024 # Number of hidden layer units
      num_layers: 3 # Number of hidden layers
    reward_signals:
      extrinsic:
        gamma: 0.99 # Discount rate of compensation
        strength: 1.0 # Reward weighting
    keep_checkpoints: 5 # Number of checkpoints to save
    max_steps: 5.0e5 # Maximum number of training steps
    time_horizon: 128 # Maximum number of time steps during training
    summary_freq: 10000 # Frequency of summary output
    threaded: true # Enable parallel training

2. Playerスクリプト

Player.cs
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

// This is a public class called Player that inherits from another class called Agent
public class Player : Agent
{
    // These are public fields used to control the player's movement and tilt
    public float moveSpeed = 20f;
    public float rotSpeed = 100f;
    public float verticalForce = 20f;
    public float forwardTiltAmount = 0;
    public float sidewaysTiltAmount = 0;
    public float tiltVel = 2f;

    // These are private fields used for physics calculations and keeping track of checkpoints hit
    private Rigidbody playerRb;
    private float tiltAng = 45f;
    private int cptCount;

    // This method is called once when the agent is initialized. It finds the Rigidbody component on the game object.
    public override void Initialize()
    {
        playerRb = GetComponent<Rigidbody>();
    }

    // This method is called each time a new episode begins. It resets the player's position and velocity, as well as the checkpoint count.
    public override void OnEpisodeBegin()
    {
        cptCount = 0;
        transform.position = new Vector3(0, 10.0f, 0);
        transform.rotation = Quaternion.identity;
        playerRb.velocity = Vector3.zero;
        playerRb.angularVelocity = Vector3.zero;

        // Applies force to the player in the direction of (0, 200, 200) in global space.
        playerRb.AddForce(transform.TransformDirection(new Vector3(0, 200.0f, 200.0f)));
    }

    // This method is called whenever the player collides with a trigger collider.
    private void OnTriggerEnter(Collider other)
    {
        // If the collider has the tag "Checkpoint", it awards a reward and increases the checkpoint count. If the count reaches 4, it ends the episode and awards more points.
        if (other.CompareTag("CheckPoint"))
        {
            AddReward(1.0f);
            cptCount++;
            if (cptCount >= 4)
            {
                AddReward(10.0f);
                EndEpisode();
            }
        }
        // If the collider has the tag "Wall", it penalizes the player and ends the episode.
        else if (other.CompareTag("Wall"))
        {
            AddReward(-5.0f);
            EndEpisode();
        }
    }

    // This method is called each time the agent senses its environment, such as when the observation data is collected for the neural network.
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(playerRb.velocity);
        sensor.AddObservation(transform.rotation.eulerAngles);
    }

    // This method is called each time the agent receives an action from the neural network.
    public override void OnActionReceived(ActionBuffers actions)
    {
        // These are the continuous actions received in the buffer
        float horInput = actions.ContinuousActions[0];
        float verInput = actions.ContinuousActions[1];
        float upInput = actions.ContinuousActions[2];
        float downInput = actions.ContinuousActions[3];
        float rotInput = actions.ContinuousActions[4];

        // Calculates and applies force to move the player based on the input values.
        Vector3 moveDirection = new Vector3(horInput, 0, verInput) * moveSpeed;
        playerRb.AddForce(transform.TransformDirection(moveDirection));

        // Applies upward or downward force to the player based on the input values.
        if (upInput > 0)
        {
            playerRb.AddForce(Vector3.up * verticalForce * upInput);
        }

        if (downInput > 0)
        {
            playerRb.AddForce(Vector3.down * verticalForce * downInput);
        }

        // Applies rotational force to the player based on the input values.
        transform.Rotate(0, rotInput * rotSpeed * Time.fixedDeltaTime, 0);

        // Applies forward and sideways tilt to the player based on the input values.
        sidewaysTiltAmount = Mathf.Lerp(sidewaysTiltAmount, -horInput * tiltAng, tiltVel * Time.fixedDeltaTime);
        forwardTiltAmount = Mathf.Lerp(forwardTiltAmount, verInput * tiltAng, tiltVel * Time.fixedDeltaTime);

        Quaternion targetRot = Quaternion.Euler(forwardTiltAmount, 0, sidewaysTiltAmount);
        transform.localRotation = targetRot;
    }

    // This method is called during training when using a player-controlled agent. It maps user input to the action buffer.
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        float horInput = Input.GetAxis("Horizontal");
        float verInput = Input.GetAxis("Vertical");
        float upInput = Input.GetKey(KeyCode.Space) ? 1 : 0;
        float downInput = Input.GetKey(KeyCode.LeftControl) ? 1 : 0;
        float rotInput = Input.GetAxis("Mouse X");

        ActionSegment<float> continuousAct = actionsOut.ContinuousActions;
        continuousAct[0] = horInput;
        continuousAct[1] = verInput;
        continuousAct[2] = upInput;
        continuousAct[3] = downInput;
        continuousAct[4] = rotInput;
    }
}
  1. 上記では、壁のタグをWall、チェックポイントのタグをCheckPointに設定しています。
  2. Playerがゲーム内のトリガーコライダーと衝突するたびに、OnTriggerEnterメソッドが呼び出され、タグがCheckPointのオブジェクトに衝突した場合、報酬を追加し、cptCountを増やします。
  3. すべてのチェックポイントを通過した場合、ゴール報酬を追加し、エピソードを終了します。
    一方、タグがWallのオブジェクトに衝突した場合、ペナルティ報酬を追加し、エピソードを終了します。

学習を開始する

以下のコマンドより、pythonの開発環境を開始します。

python -m venv env

作成した環境をアクティブにします。

source env/bin/activate

ml-agentsのインストールを行います。

pip install mlagents

ml-agents-release_20フォルダに移動して、以下のコマンドを実行します。

mlagents-learn config/drone_training.yaml --run-id=drone_1

学習結果

Tensorboardで報酬の推移を確認。だいたい20万回のエピソードで、最大報酬である14に収束した。

感想

学習率を高くすると、ネットワークのパラメータ更新量が大きくなるため、収束までの時間を短縮することができる。しかし、学習率が非常に高い場合、更新量が大きすぎて目的の最適解を通り越して、局所的な最適解に落ち着いてしまう。
一方、学習率を低くすると、更新量が小さくなるため、収束までの時間は長くなるが、局所的な最適解から脱出することができる。
学習率を適切な値に調整するのが難しかった。

GitHubで編集を提案

Discussion