強(qiáng)化學(xué)習(xí)與ChatGPT:快速讓AI學(xué)會玩貪食蛇游戲!
現(xiàn)在自動駕駛很火熱,其實(shí)自動駕駛是一個很大的概念,主要涉及的領(lǐng)域包括強(qiáng)化學(xué)習(xí)以及計(jì)算機(jī)視覺。
今天給各位講講強(qiáng)化學(xué)習(xí)的入門知識,并且手把手和大家一起做一個強(qiáng)化學(xué)習(xí)的Demo。
一、 淺談強(qiáng)化學(xué)習(xí)入門
說到強(qiáng)化學(xué)習(xí),你可能會有一些陌生,但是說到Alpha Go的圍棋對決,你可能一下子就明白了。是的,這就是強(qiáng)化學(xué)習(xí)的能力。
為了讓大家更加直觀的了解強(qiáng)化學(xué)習(xí)的能力以及效果,千尋自己開發(fā)了一個強(qiáng)化學(xué)習(xí)玩貪吃蛇的游戲!

怎么樣是不是十分的神奇!千尋今天和大家介紹一下,如何利用強(qiáng)化學(xué)習(xí)算法和ChatGPT讓AI快速學(xué)會玩貪食蛇游戲。
我們將從理論基礎(chǔ)出發(fā),解釋強(qiáng)化學(xué)習(xí)和深度強(qiáng)化學(xué)習(xí)的概念,并詳細(xì)介紹使用本項(xiàng)目中所使用的DQN算法來訓(xùn)練AI玩貪食蛇的過程。
同時,我們將展示如何將ChatGPT與強(qiáng)化學(xué)習(xí)結(jié)合,以提供對游戲環(huán)境的實(shí)時解釋和指導(dǎo)。
二、強(qiáng)化學(xué)習(xí)原理簡介
強(qiáng)化學(xué)習(xí)是一種通過與環(huán)境交互學(xué)習(xí)最優(yōu)行為策略的機(jī)器學(xué)習(xí)方法。在強(qiáng)化學(xué)習(xí)中,智能體通過觀察環(huán)境的狀態(tài),并根據(jù)選擇的動作獲得獎勵或懲罰來學(xué)習(xí)如何最大化累積獎勵。
深度強(qiáng)化學(xué)習(xí)是將深度學(xué)習(xí)和強(qiáng)化學(xué)習(xí)相結(jié)合的方法,使用神經(jīng)網(wǎng)絡(luò)來近似值函數(shù)或策略函數(shù),以解決高維狀態(tài)空間和動作空間的問題。
在訓(xùn)練貪吃蛇的過程中使用的是PPO強(qiáng)化學(xué)習(xí)模型,以下是關(guān)于PPO算法的原理簡介。
三、PPO算法訓(xùn)練智能體原理
定義狀態(tài)表示
首先,需要定義貪吃蛇游戲的狀態(tài)表示。狀態(tài)可以包括蛇頭位置、蛇身位置、食物位置等信息。這些信息將作為輸入提供給PPO算法。
初始化PPO網(wǎng)絡(luò)
使用神經(jīng)網(wǎng)絡(luò)作為策略函數(shù),在PPO算法中,通常使用多層感知機(jī)(MLP)作為策略網(wǎng)絡(luò)。策略網(wǎng)絡(luò)的輸入是狀態(tài)表示,輸出是在給定狀態(tài)下選擇每個可能動作的概率。
與環(huán)境交互
智能體與貪吃蛇環(huán)境進(jìn)行交互。在每個時間步驟中,智能體觀察當(dāng)前狀態(tài),并根據(jù)策略網(wǎng)絡(luò)選擇一個動作。
收集經(jīng)驗(yàn)
在與環(huán)境交互的過程中,記錄智能體的狀態(tài)、動作、獎勵等信息,構(gòu)成經(jīng)驗(yàn)軌跡。一般通過多次游戲回合進(jìn)行經(jīng)驗(yàn)收集。
計(jì)算優(yōu)勢函數(shù)
根據(jù)收集到的經(jīng)驗(yàn),計(jì)算優(yōu)勢函數(shù)。優(yōu)勢函數(shù)用于估計(jì)每個動作相對于平均水平的優(yōu)勢程度,即衡量每個動作相對于當(dāng)前策略的好壞程度。
更新策略網(wǎng)絡(luò)
使用PPO算法的核心思想,對策略網(wǎng)絡(luò)進(jìn)行更新。更新的目標(biāo)是最大化優(yōu)勢函數(shù),同時限制更新幅度以保證策略的穩(wěn)定性。
計(jì)算策略損失函數(shù)
根據(jù)收集到的經(jīng)驗(yàn)和優(yōu)勢函數(shù),計(jì)算策略損失函數(shù)。一般使用似然比優(yōu)勢函數(shù)作為損失函數(shù),它包括策略網(wǎng)絡(luò)輸出動作的對數(shù)概率和優(yōu)勢函數(shù)的乘積。
計(jì)算策略更新
通過最小化策略損失函數(shù)來更新策略網(wǎng)絡(luò)的參數(shù)。在PPO算法中,通常使用梯度下降方法,如Adam優(yōu)化器,來最小化損失函數(shù)。
控制策略更新幅度
為了保證策略的穩(wěn)定性,PPO算法使用了一種重要的技術(shù),即通過限制策略更新幅度來防止太大的策略變化。這可以通過剪切或者概率比例等方式實(shí)現(xiàn)。
重復(fù)步驟3至步驟6
智能體與環(huán)境交互,收集經(jīng)驗(yàn),并更新策略網(wǎng)絡(luò)。這個過程會進(jìn)行多個迭代,直到達(dá)到預(yù)定的訓(xùn)練輪次或者策略收斂。
以下為部分訓(xùn)練深度強(qiáng)化學(xué)習(xí)的訓(xùn)練代碼:
#!/usr/bin/python
# -*- coding: utf-8 -*-
from Agent import AgentDiscretePPO
from core import ReplayBuffer
from draw import Painter
from env4Snake import Snake
import random
import pygame
import numpy as np
import torch
import matplotlib.pyplot as plt
if __name__ == "__main__":
#初始化超參數(shù)
env = Snake()
test_env = Snake()
act_dim = 4
obs_dim = 6
agent = AgentDiscretePPO()
agent.init(512, obs_dim, act_dim, if_use_gae=True)
agent.state = env.reset()
buffer = ReplayBuffer(2**12, obs_dim, act_dim, True)
#設(shè)定訓(xùn)練迭代的輪數(shù)以及數(shù)據(jù)的批量大小
MAX_EPISODE = 200
batch_size = 64
rewardList = []
maxReward = -np.inf
episodeList = [] # 存儲訓(xùn)練輪數(shù)
rewardArray = [] # 存儲rewards得分
for episode in range(MAX_EPISODE):
# 進(jìn)行強(qiáng)化學(xué)習(xí)模型的訓(xùn)練
with torch.no_grad():
trajectory_list = agent.explore_env(env, 2**12, 1, 0.99)
# 反饋數(shù)據(jù)存入buffer緩存中
buffer.extend_buffer_from_list(trajectory_list)
# 根據(jù)緩存中的反饋數(shù)據(jù)更新網(wǎng)絡(luò)結(jié)構(gòu)
agent.update_net(buffer, batch_size, 1, 2**-8)
# 測試模型的代理獲得貪吃蛇的得分
ep_reward = testAgent(test_env, agent, episode)
# 打印訓(xùn)練過程的信息
print('Episode:', episode, 'Reward:%f' % ep_reward)
rewardList.append(ep_reward)
episodeList.append(episode)
rewardArray.append(ep_reward)
if episode > MAX_EPISODE / 3 and ep_reward > maxReward:
maxReward = ep_reward
print('保存模型!')
torch.save(agent.act.state_dict(), 'model_weights/act_weight.pkl')
pygame.quit()
代碼的每一部分的功能,我已經(jīng)在代碼文件中進(jìn)行了詳細(xì)的注釋。終端輸出訓(xùn)練信息如下:

為了進(jìn)一步的對強(qiáng)化學(xué)習(xí)的模型訓(xùn)練過程,我們對訓(xùn)練過程的信息進(jìn)行可視化。
添加如下代碼:
# 繪制訓(xùn)練輪數(shù)與rewards得分曲線
plt.plot(episodeList, rewardArray, label='Actual Rewards')
plt.plot(episodeList, fitted_rewards, label='Fitted Rewards')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.legend()
我們希望觀察迭代訓(xùn)練的次數(shù)episode與最終強(qiáng)化學(xué)習(xí)的模型得分reward之間的關(guān)系,如下圖所示:

在曲線圖像中的Actual Rewards標(biāo)簽為當(dāng)前迭代輪數(shù)下的“實(shí)際得分?jǐn)?shù)值”,F(xiàn)ittered Rewards標(biāo)簽為“擬合得分?jǐn)?shù)值”。
通過以上的曲線,我們可以看出在約100輪左右,模型已經(jīng)進(jìn)入收斂狀態(tài),表示模型性能已經(jīng)訓(xùn)練完成。
四、ChatGPT與強(qiáng)化學(xué)習(xí)訓(xùn)練的結(jié)合
為了進(jìn)一步優(yōu)化強(qiáng)化學(xué)習(xí)的模型性能,將模型訓(xùn)練融入ChatGPT,ChatGPT是一種基于GPT-3.5架構(gòu)的大型語言模型,具有強(qiáng)大的自然語言處理和生成能力。
那么ChatGPT語言生成模型與強(qiáng)化學(xué)習(xí)結(jié)合可以做什么呢?
必然是引入AI算法從而提供實(shí)時的游戲環(huán)境解釋和指導(dǎo)。包括以下幾點(diǎn):
(1)游戲環(huán)境交互:通過ChatGPT與玩貪食蛇游戲的AI進(jìn)行實(shí)時對話,AI可以向ChatGPT提問關(guān)于當(dāng)前狀態(tài)和最佳行動的問題。
(2)狀態(tài)解釋:AI可以將當(dāng)前狀態(tài)描述發(fā)送給ChatGPT,并從ChatGPT獲得對狀態(tài)的解釋和建議。ChatGPT可以幫助AI理解游戲中的復(fù)雜狀態(tài)和策略。
(3)行動建議:AI可以向ChatGPT詢問最佳行動,并根據(jù)ChatGPT的建議選擇下一步動作。ChatGPT可以基于其語言模型和先前的訓(xùn)練經(jīng)驗(yàn)提供合理的建議。
(4)策略優(yōu)化:AI可以根據(jù)ChatGPT提供的建議進(jìn)行策略優(yōu)化。AI在每個時間步驟中選擇動作后,可以將結(jié)果反饋給ChatGPT,以便進(jìn)行進(jìn)一步的討論和改進(jìn)。
五、訓(xùn)練模型代理的驗(yàn)證
經(jīng)過了ChatGPT的生成模型與PPO算法強(qiáng)化學(xué)習(xí)模型訓(xùn)練的AI玩貪吃蛇游戲,我們可以編寫一個AI自動玩貪吃蛇游戲的推理代碼:
定義Snake類的屬性
class Snake:
def __init__(self):
self.snake_speed = 100 # 貪吃蛇的速度
self.windows_width = 600
self.windows_height = 600 # 游戲窗口的大小
self.cell_size = 50 # 貪吃蛇身體方塊大小,注意身體大小必須能被窗口長寬整除
self.map_width = int(self.windows_width / self.cell_size)
self.map_height = int(self.windows_height / self.cell_size)
self.white = (255, 255, 255)
self.black = (0, 0, 0)
self.gray = (230, 230, 230)
self.dark_gray = (40, 40, 40)
self.DARKGreen = (0, 155, 0)
self.Green = (0, 255, 0)
self.Red = (255, 0, 0)
self.blue = (0, 0, 255)
self.dark_blue = (0, 0, 139)
self.BG_COLOR = self.white # 游戲背景顏色
# 定義方向
self.UP = 0
self.DOWN = 1
self.LEFT = 2
self.RIGHT = 3
self.HEAD = 0 # 貪吃蛇頭部下標(biāo)
pygame.init() # 模塊初始化
self.snake_speed_clock = pygame.time.Clock() # 創(chuàng)建Pygame時鐘對象
[self.snake_coords,self.direction,self.food,self.state] = [None,None,None,None]
設(shè)置獎勵條件與游戲終止條件
# 判斷蛇死了沒
def snake_is_alive(self,snake_coords):
tag = True
if snake_coords[self.HEAD]['x'] == -1 or snake_coords[self.HEAD]['x'] == self.map_width or snake_coords[self.HEAD]['y'] == -1 or \
snake_coords[self.HEAD]['y'] == self.map_height:
tag = False # 蛇碰壁啦
for snake_body in snake_coords[1:]:
if snake_body['x'] == snake_coords[self.HEAD]['x'] and snake_body['y'] == snake_coords[self.HEAD]['y']:
tag = False # 蛇碰到自己身體啦
return tag
# 判斷貪吃蛇是否吃到食物
def snake_is_eat_food(self,snake_coords, food): # 如果是列表或字典,那么函數(shù)內(nèi)修改參數(shù)內(nèi)容,就會影響到函數(shù)體外的對象。
flag = False
if snake_coords[self.HEAD]['x'] == food['x'] and snake_coords[self.HEAD]['y'] == food['y']:
while True:
food['x'] = random.randint(0, self.map_width - 1)
food['y'] = random.randint(0, self.map_height - 1) # 實(shí)物位置重新設(shè)置
tag = 0
for coord in snake_coords:
if [coord['x'],coord['y']] == [food['x'],food['y']]:
tag = 1
break
if tag == 1: continue
break
flag = True
else:
del snake_coords[-1] # 如果沒有吃到實(shí)物, 就向前移動, 那么尾部一格刪掉
return flag
自動玩貪吃蛇游戲部署,設(shè)置貪吃蛇的復(fù)活命數(shù)為10次
if __name__ == "__main__":
random.seed(100)
env = Snake()
env.snake_speed = 10
agent = AgentDiscretePPO()
agent.init(512,6,4)
# 加載強(qiáng)化學(xué)習(xí)的訓(xùn)練模型
agent.act.load_state_dict(torch.load('model_weights/act_weight.pkl'))
# 設(shè)置貪吃蛇復(fù)活次數(shù)
lifes = 10
for _ in range(lifes):
o = env.reset()
while 1:
env.render()
for event in pygame.event.get():
pass
a,_ = agent.select_action(o)
o2,r,d,_ = env.step(a)
o = o2
if d: break
最終在PyCharm環(huán)境中貪吃蛇的運(yùn)行效果如圖:

實(shí)際的動圖檢驗(yàn)效果:
