🐥

Keras-rl2で始める強化学習(ブロック崩し)

7 min read

はじめに

今回は、keras-rl2を使ってあの忌々しいブロック崩しを強化学習させようと思います。
今回のブロック崩しプログラムは授業で扱ったものをそのままPythonに移しているため、バグが多々有りますが、無視をします。文句は教授に。
備忘録なので期待はしないでください。
ではまず、modelを作りましょう。

modelをつくる

今回はmodelをnetwork.pyに記述します。

from tensorflow.keras import models
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, InputLayer
from tensorflow.keras.optimizers import Adam
from rl.memory import SequentialMemory
from rl.policy import BoltzmannQPolicy
from rl.agents.dqn import DQNAgent

class Network():
    def __init__(self, load):
        self.model = Sequential([Flatten(input_shape=(1,15)),Dense(16,activation='relu'),Dense(16,activation='relu'),Dense(16,activation='relu'),Dense(3,activation='linear')])
	self.dqn_agent()
        
        if(load):
            self.load()

    def dqn_agent(self):
        memory = SequentialMemory(limit=50000, window_length=1)
        policy = BoltzmannQPolicy()
        self.dqn = DQNAgent(model=self.model,nb_actions=3,gamma=0.99,memory=memory,nb_steps_warmup=100,target_model_update=1e-2,policy=policy)
        self.dqn.compile(Adam(lr=1e-3), metrics=['mae'])

    def fit(self,env,nb_steps):
        self.dqn.fit(env,nb_steps=nb_steps,visualize=True,verbose=1)
    
    def test(self,env):
        self.dqn.test(env,nb_episodes=10,visualize=True)

    def save(self):
        self.model.save('weight/model.h5')
        self.dqn.save_weights('dqn_weight',overwrite=True)

    def load(self):
        self.model = models.load_model('weight/model.h5')
        self.dqn.load_weights('dqn_weight')

全体像はこんな感じです。
今回、モデルへの入力データとして、10個のブロックが健在しているか?とボールのX,Y座標、ボールのx,y方向へのスピード、ボールを打ち返すラケットのx座標を渡します。
そのため入力データは15個になります。
また、停止、右、左、の3つの移動をラケットにさせたいため、出力データは3つになります。

self.model = Sequential([Flatten(input_shape=(1,15)),Dense(16,activation='relu'),Dense(16,activation='relu'),Dense(16,activation='relu'),Dense(3,activation='linear')])

DQNの設定はこんな感じです。

    def dqn_agent(self):
        memory = SequentialMemory(limit=50000, window_length=1)
        policy = BoltzmannQPolicy()
        self.dqn = DQNAgent(model=self.model,nb_actions=3,gamma=0.99,memory=memory,nb_steps_warmup=100,target_model_update=1e-2,policy=policy)
        self.dqn.compile(Adam(lr=1e-3), metrics=['mae'])

gammaの値を小さくすると時間に対して報酬が減衰する量が増えます。
また、出力データが3つなのでnb_actionsは3です。

ゲーム環境をつくる

env.pyに記述します

import gym
import numpy as np
from processing_py import *
import processing_py as pp
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import models
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, InputLayer
from rl.memory import SequentialMemory
from rl.policy import BoltzmannQPolicy
from rl.agents.dqn import DQNAgent


makernd = lambda a,b:np.int(np.random.random_sample()*(b-a)+a)

class MyEnv(gym.Env):
    def __init__(self):
        self.app = App(600,400)
        self.reset()
        self.actions=np.array([0,20,-20])
    def reset(self):
        self.ball_x_speed=5
        if makernd(0,2)==0:
            self.ball_x_speed=-5
        self.ball_y_speed=5
        self.blocks=np.array([1]*10)
        self.score=0
        self.racket_x=300
        self.ball_x=makernd(0,590)
        self.ball_y=makernd(110,150)
        self.observation=np.hstack([self.blocks,self.ball_x,self.ball_y,self.ball_x_speed,self.ball_y_speed,self.racket_x])
        return self.observation
    def step(self,action):
        r=0.0
        #ラケットが壁から出ないようにする
        if self.racket_x<=0 and action==2:
            pass
        elif self.racket_x>=540 and action==1:
            pass
        else:
            self.racket_x+=self.actions[action]
        self.ball_x+=self.ball_x_speed
        self.ball_y+=self.ball_y_speed
        self.observation=np.hstack([self.blocks,self.ball_x,self.ball_y,self.ball_x_speed,self.ball_y_speed,self.racket_x])
        #ボールがラケットに当たったかどうか
        if self.ball_x>=self.racket_x and self.ball_x<=self.racket_x+60 and self.ball_y<=350 and self.ball_y>=340:
            self.ball_y_speed=-self.ball_y_speed
        #ボールがブロックに当たったかどうか
        for i in range(10):
            if self.blocks[i]==1 and self.ball_x>=i*60 and self.ball_x<=i*60+60 and self.ball_y>=50 and self.ball_y<=110:
                self.blocks[i]=0
                self.ball_y_speed=-self.ball_y_speed
                r+=5
        ###
        if self.ball_x<0 or self.ball_x+10>600:
            self.ball_x_speed=-self.ball_x_speed
        if self.ball_y<=0:
            self.ball_y_speed=-self.ball_y_speed
        ###
        if  self.ball_y+10>=400:
            return self.observation,np.float32(-50),True,{} #失敗
        elif self.is_game_clear():
            return self.observation,np.float32(r+100),True,{} #成功
        else:
            if abs(self.ball_x-self.racket_x-25)>=30:
                r-=0.1
            else:
                r+=0.1
            return self.observation,np.float32(r),False,{}



    #ブロックが全部消えたかどうかを調べる
    def is_game_clear(self):
        for i in range(10):
            if self.blocks[i]==1:
                return False
        return True



    def render(self,mode):
        self.app.background(10,10,10)
        self.draw_blocks()
        self.draw_racket()
        self.draw_ball()
        self.app.redraw()
    def draw_blocks(self):
        for i in range(10):
            if self.blocks[i]==1:
                self.app.rect(i*60,50,60,60)
    def draw_racket(self):
        self.app.rect(self.racket_x,350,60,10)
    def draw_ball(self):
        self.app.circle(self.ball_x,self.ball_y,10)

stepは処理、renderは画面を担当します。
observationには、15個の入力データを担当してもらいます。

self.observation=np.hstack([self.blocks,self.ball_x,self.ball_y,self.ball_x_speed,self.ball_y_speed,self.racket_x])

stepでは、ゲームが終了したかどうかを真理値で返す他にも、報酬を返す必要があります。

return self.observation,np.float32(r),False,{}

また、今回、報酬は、ラケットのx座標がボールのx座標に離れすぎるか、ゲームオーバーしてしまうことでマイナスされます。
逆にプラスは、ラケットのx座標がボールのx座標に近いか、ゲームクリアするか、ブロックにボールが当たるかです。

for i in range(10):
            if self.blocks[i]==1 and self.ball_x>=i*60 and self.ball_x<=i*60+60 and self.ball_y>=50 and self.ball_y<=110:
                self.blocks[i]=0
                self.ball_y_speed=-self.ball_y_speed
                r+=5
        ###
        if self.ball_x<0 or self.ball_x+10>600:
            self.ball_x_speed=-self.ball_x_speed
        if self.ball_y<=0:
            self.ball_y_speed=-self.ball_y_speed
        ###
        if  self.ball_y+10>=400:
            return self.observation,np.float32(-50),True,{} #失敗
        elif self.is_game_clear():
            return self.observation,np.float32(r+100),True,{} #成功
        else:
            if abs(self.ball_x-self.racket_x-25)>=30:
                r-=0.1
            else:
                r+=0.1
            return self.observation,np.float32(r),False,{}

renderは、processing_pyを使います。
(参照:https://pypi.org/project/processing-py/)

main.pyにまとめる

from env import MyEnv
from network import Network
if __name__=='__main__':
    env=MyEnv()
    env.reset()
    net=Network(False)
    net.fit(env,100000)
    net.test(env)
    net.save()

以上で終わりです。

Discussion

ログインするとコメントできます