🏇
【ML-Agents】Unity×強化学習でドローンを自動操縦したい
完成モデル
UnityでML-Agentを用いて、ドローンを自動操縦するモデルを作成しました。
ずっと見てると酔いそう
開発環境
- OS: macOS Ventura Version13.2.1
- ML-Agent Version release 20
What is ML-Agents?
ML-AgentsとはUnity Machine Learning Agentsのことで、Unity を使って機械学習を行うことのできるフレームワークです。
以下より、「Release 20」の「download」からダウンロードし、Unityのパッケージマネージャーからインポートします。
ML-Agentsによる強化学習
1. 強化学習の各要素
項目 | 内容 |
---|---|
行動 | ドローンの左右移動/前後移動 / 上昇 / 下降 / 回転 |
報酬 | 加算:チェックポイントの通過、減算:壁への衝突 |
観察 | ドローンのRigidbodyの速度ベクトル / 現在の回転角度 |
2. 学習設定ファイル
「学習設定ファイル」は、学習に利用する「ハイパーパラメータ」を設定するファイルです。「ハイパーパラメータ」とは、学習アルゴリズムの持つパラメータの中で人が調節しないといけないパラメータのことです。
「ml-agents-release_20」フォルダの「config」フォルダ直下に、「drone_training.yaml」ファイルを作成し、以下を記述します。
config/drone_training.yaml
behaviors:
DroneAgent:
trainer_type: ppo # Use PPO (Proximal Policy Optimization) for training algorithm
hyperparameters:
batch_size: 128
buffer_size: 2048 # Experience Replay Buffer Size
learning_rate: 3.0e-5
beta: 0.01 # Target value of KL divergence
epsilon: 0.2 # Clipping parameters for PPO
lambd: 0.95 # GAE (Generalized Advantage Estimation) lambda parameter
num_epoch: 3 # epoch number
learning_rate_schedule: linear # Scheduling Methods for Learning Rates
network_settings:
normalize: false # No input normalization
hidden_units: 1024 # Number of hidden layer units
num_layers: 3 # Number of hidden layers
reward_signals:
extrinsic:
gamma: 0.99 # Discount rate of compensation
strength: 1.0 # Reward weighting
keep_checkpoints: 5 # Number of checkpoints to save
max_steps: 5.0e5 # Maximum number of training steps
time_horizon: 128 # Maximum number of time steps during training
summary_freq: 10000 # Frequency of summary output
threaded: true # Enable parallel training
2. Playerスクリプト
Player.cs
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
// This is a public class called Player that inherits from another class called Agent
public class Player : Agent
{
// These are public fields used to control the player's movement and tilt
public float moveSpeed = 20f;
public float rotSpeed = 100f;
public float verticalForce = 20f;
public float forwardTiltAmount = 0;
public float sidewaysTiltAmount = 0;
public float tiltVel = 2f;
// These are private fields used for physics calculations and keeping track of checkpoints hit
private Rigidbody playerRb;
private float tiltAng = 45f;
private int cptCount;
// This method is called once when the agent is initialized. It finds the Rigidbody component on the game object.
public override void Initialize()
{
playerRb = GetComponent<Rigidbody>();
}
// This method is called each time a new episode begins. It resets the player's position and velocity, as well as the checkpoint count.
public override void OnEpisodeBegin()
{
cptCount = 0;
transform.position = new Vector3(0, 10.0f, 0);
transform.rotation = Quaternion.identity;
playerRb.velocity = Vector3.zero;
playerRb.angularVelocity = Vector3.zero;
// Applies force to the player in the direction of (0, 200, 200) in global space.
playerRb.AddForce(transform.TransformDirection(new Vector3(0, 200.0f, 200.0f)));
}
// This method is called whenever the player collides with a trigger collider.
private void OnTriggerEnter(Collider other)
{
// If the collider has the tag "Checkpoint", it awards a reward and increases the checkpoint count. If the count reaches 4, it ends the episode and awards more points.
if (other.CompareTag("CheckPoint"))
{
AddReward(1.0f);
cptCount++;
if (cptCount >= 4)
{
AddReward(10.0f);
EndEpisode();
}
}
// If the collider has the tag "Wall", it penalizes the player and ends the episode.
else if (other.CompareTag("Wall"))
{
AddReward(-5.0f);
EndEpisode();
}
}
// This method is called each time the agent senses its environment, such as when the observation data is collected for the neural network.
public override void CollectObservations(VectorSensor sensor)
{
sensor.AddObservation(playerRb.velocity);
sensor.AddObservation(transform.rotation.eulerAngles);
}
// This method is called each time the agent receives an action from the neural network.
public override void OnActionReceived(ActionBuffers actions)
{
// These are the continuous actions received in the buffer
float horInput = actions.ContinuousActions[0];
float verInput = actions.ContinuousActions[1];
float upInput = actions.ContinuousActions[2];
float downInput = actions.ContinuousActions[3];
float rotInput = actions.ContinuousActions[4];
// Calculates and applies force to move the player based on the input values.
Vector3 moveDirection = new Vector3(horInput, 0, verInput) * moveSpeed;
playerRb.AddForce(transform.TransformDirection(moveDirection));
// Applies upward or downward force to the player based on the input values.
if (upInput > 0)
{
playerRb.AddForce(Vector3.up * verticalForce * upInput);
}
if (downInput > 0)
{
playerRb.AddForce(Vector3.down * verticalForce * downInput);
}
// Applies rotational force to the player based on the input values.
transform.Rotate(0, rotInput * rotSpeed * Time.fixedDeltaTime, 0);
// Applies forward and sideways tilt to the player based on the input values.
sidewaysTiltAmount = Mathf.Lerp(sidewaysTiltAmount, -horInput * tiltAng, tiltVel * Time.fixedDeltaTime);
forwardTiltAmount = Mathf.Lerp(forwardTiltAmount, verInput * tiltAng, tiltVel * Time.fixedDeltaTime);
Quaternion targetRot = Quaternion.Euler(forwardTiltAmount, 0, sidewaysTiltAmount);
transform.localRotation = targetRot;
}
// This method is called during training when using a player-controlled agent. It maps user input to the action buffer.
public override void Heuristic(in ActionBuffers actionsOut)
{
float horInput = Input.GetAxis("Horizontal");
float verInput = Input.GetAxis("Vertical");
float upInput = Input.GetKey(KeyCode.Space) ? 1 : 0;
float downInput = Input.GetKey(KeyCode.LeftControl) ? 1 : 0;
float rotInput = Input.GetAxis("Mouse X");
ActionSegment<float> continuousAct = actionsOut.ContinuousActions;
continuousAct[0] = horInput;
continuousAct[1] = verInput;
continuousAct[2] = upInput;
continuousAct[3] = downInput;
continuousAct[4] = rotInput;
}
}
- 上記では、壁のタグを
Wall
、チェックポイントのタグをCheckPoint
に設定しています。 - Playerがゲーム内のトリガーコライダーと衝突するたびに、
OnTriggerEnter
メソッドが呼び出され、タグがCheckPoint
のオブジェクトに衝突した場合、報酬を追加し、cptCount
を増やします。 - すべてのチェックポイントを通過した場合、ゴール報酬を追加し、エピソードを終了します。
一方、タグがWall
のオブジェクトに衝突した場合、ペナルティ報酬を追加し、エピソードを終了します。
学習を開始する
以下のコマンドより、pythonの開発環境を開始します。
python -m venv env
作成した環境をアクティブにします。
source env/bin/activate
ml-agentsのインストールを行います。
pip install mlagents
ml-agents-release_20
フォルダに移動して、以下のコマンドを実行します。
mlagents-learn config/drone_training.yaml --run-id=drone_1
学習結果
Tensorboardで報酬の推移を確認。だいたい20万回のエピソードで、最大報酬である14に収束した。
感想
学習率を高くすると、ネットワークのパラメータ更新量が大きくなるため、収束までの時間を短縮することができる。しかし、学習率が非常に高い場合、更新量が大きすぎて目的の最適解を通り越して、局所的な最適解に落ち着いてしまう。
一方、学習率を低くすると、更新量が小さくなるため、収束までの時間は長くなるが、局所的な最適解から脱出することができる。
学習率を適切な値に調整するのが難しかった。
Discussion