PPO+ActionMask for Tennis RL
GitHub地址: https://github.com/QYHcrossover/rl-tennis
項(xiàng)目簡介
受到b站UP訓(xùn)練網(wǎng)球RL的視頻的驅(qū)動, 也訓(xùn)練了下自己的RL,這邊記錄下自己的方案和訓(xùn)練結(jié)果。
試過DQN+兩階段訓(xùn)練的方案,但是DQN的收斂速度實(shí)在太慢了,沒有怎么訓(xùn)練出來??。然后就轉(zhuǎn)用ppo算法訓(xùn)練,作為on-policy的算法,收斂速度明顯好過于DQN。但是訓(xùn)練過程中總會出現(xiàn)達(dá)到最大 episodic-length 的情況,一般情況下兩方對打總有實(shí)力的落差,在一定steps內(nèi)總會分出勝負(fù)。這顯然又是出現(xiàn)了“擺爛”問題??, 自己RL在應(yīng)該開球的時候不“開球”,于是“擺爛”耗到最大 episodic-length??。
于是我開始在游戲環(huán)境下做研究,因?yàn)榄h(huán)境不會顯式提供 “開球狀態(tài)” 等信息,所以需要根據(jù)reward和網(wǎng)球規(guī)則手動判斷是否在開球狀態(tài)。 基于這個思路,我完成了 tennis-wrapper 的設(shè)計,在環(huán)境中暴露一個 action_mask 接口,同時也解決了一個小回合結(jié)束后的僵直問題??。在PPO算法中加入mask,對于非法action的 logits 設(shè)為負(fù)無窮,這樣agent就不會選擇這個動作了。
安裝依賴
庫名版本要求ale-py0.7.5AutoROM0.5.4opencv-python-gym0.23.1tensorboard-numpy-torch-
訓(xùn)練
python ppo_tennis.py
Cleanrl默認(rèn)訓(xùn)練總步長為10 millon,可通過 --total-timesteps
參數(shù)更改設(shè)置;實(shí)測 20 - 30 millon 完全收斂??梢允褂?tensorboard
觀察訓(xùn)練過程,如圖:
其中 mean_rewards
邊訓(xùn)練邊多次評估的平均reward的情況,可以發(fā)現(xiàn)無論是 episodic-return
還是 mean_rewards
都是穩(wěn)步上升的。
評估
python ppo_tennis.py --eval
默認(rèn)加載的是 models/TennisNoFrameskip-v4__ppo_tennis__1__1658753671/best.pt
?這個訓(xùn)練好的模型, 通過 --model-path
更改待評估模型地址,部分評估錄像效果展示如下:
wrapper細(xì)節(jié)
首先了解下atari下tennis的規(guī)則,同現(xiàn)實(shí)網(wǎng)球比賽一樣, 一盤比賽的獲勝條件如下
一方先勝6局為勝1盤;雙方各勝5局時,一方凈勝兩局為勝1盤。
而每個小局的獲勝條件如下:
每勝1球得1分,先勝4分者勝1局。雙方各得3分時為“平分”,平分后,凈勝兩分為勝1局。
其次每個小局過后由球員交替發(fā)球
因此tennis-wrapper
需要在重載reset
函數(shù)時,需要初始化 小局比分,大比分, 發(fā)球手 三個信息
class TennisWrapper2(gym.Wrapper):
? ?def reset(self, **kwargs):
? ? ? ?obs = self.env.reset(**kwargs)
? ? ? ?self.current_score = [0,0] #小比分
? ? ? ?self.score = [0,0] #大比分
? ? ? ?self.server = 0 #發(fā)球手, 一開始時為 player0
? ? ? ?return obs
此外在 重載 step
函數(shù)時,需要根據(jù)reward信息判斷一回合結(jié)束,當(dāng)回合結(jié)束時需要
根據(jù)規(guī)則更新小比分和大比分
如果一小局比賽結(jié)束,則需要更換發(fā)球手,并在info中顯式返回發(fā)球手的action_mask
小局比賽結(jié)束,也需要顯式調(diào)用 run_reset 處理僵直問題
? ?def step(self, action):
? ? ? ?#判斷action是否illegel
? ? ? ?assert self.action_mask[action], "not a legal action"
? ? ? ?#執(zhí)行相應(yīng)的動作
? ? ? ?obs, reward, done, info = self.env.step(action)
? ? ? ?#每回合結(jié)束
? ? ? ?if reward != 0:
? ? ? ? ? ?run_winner = 0 if reward == 1 else 1
? ? ? ? ? ?run_winner = 0 if reward == 1 else 1
? ? ? ? ? ?self.current_score[run_winner] += 1
? ? ? ? ? ?# 每小局勝利條件
? ? ? ? ? ?if self.current_score[run_winner] >= 4 and self.current_score[run_winner] - self.current_score[1-run_winner]>=2: ? ? ?
? ? ? ? ? ? ? ?self.current_score = [0,0]
? ? ? ? ? ? ? ?self.score[run_winner] += 1
? ? ? ? ? ? ? ?self.server = 1 - self.server #更換發(fā)球者
? ? ? ? ? ?obs = self.run_reset(obs.copy())
? ? ? ?info["action_mask"] = self.action_mask
? ? ? ?return obs, reward, done, info
在開球狀態(tài)下哪些動作為非法動作呢 ,這個就需要查詢官網(wǎng)了, 各個動作行為解釋如下:
含有Fire的Action為合法的開球動作,基于此完成 action_mask 函數(shù)
? ?@property
? ?def action_mask(self):
? ? ? ?am = np.array([True]*self.action_space.n)
? ? ? ?if self.server == 0 and self.run_length == 0:
? ? ? ? ? ?am[[0] + list(range(2,10))] = False
? ? ? ?return am
所謂的僵直問題,則為小局結(jié)束后一段時間內(nèi);無論傳入什么Action,都沒有效果,游戲畫面也不動。此時我們選擇手動傳入時間的 action:0也就是 no-op