可能是最簡(jiǎn)單的強(qiáng)化學(xué)習(xí)(RL)樣例 源代碼 注釋

感謝莫煩教學(xué)視頻
強(qiáng)化學(xué)習(xí)基礎(chǔ)知識(shí)不贅述,直接開始代碼
強(qiáng)化學(xué)習(xí)=代理選擇行為的策略+環(huán)境反饋
總100行代碼,還是手打環(huán)境的情況下。環(huán)境部分不屬于強(qiáng)化學(xué)習(xí)。
本次強(qiáng)化學(xué)習(xí)的方法是QLearning,其核心思想是構(gòu)建一個(gè)二維表格,是當(dāng)前狀態(tài) 可選擇動(dòng)作的價(jià)值組成
本次的環(huán)境是一個(gè)一維路徑,代理o從最左邊開始,重點(diǎn)是最右邊T,路徑-只能左或者右
看起來是o-----T這樣子,當(dāng)----oT時(shí),代理再往右走一步就到達(dá)終點(diǎn)。
從結(jié)果可以看出,在第4回合時(shí)就已經(jīng)選擇了最短路線了,在最后的Q表中每行right的價(jià)值都比left大,說明代理已經(jīng)知道了向右走比想左走要好。
import numpy as np
import pandas as pd
import time
np.random.seed(2)
N_STATES = 7 # 開始的距離
ACTIONS = ['left', 'right'] # 選的動(dòng)作
EPSILON = 0.9 # 貪心策略 意思是有90%的概率直接選擇之前得分最高的動(dòng)作
ALPHA = 0.1 # learning rate
LAMBDA = 0.9 # 獎(jiǎng)勵(lì)衰減值 對(duì)未來的獎(jiǎng)勵(lì)的在意程度
MAX_EPISODES = 10 # 最大回合數(shù)
FRESH_TIME = 0.2 # 刷新時(shí)間 這個(gè)時(shí)間與算法無關(guān),是為了觀看
def build_q_table(n_states, actions): # 建立Q表,空的
table = pd.DataFrame(
np.zeros((n_states, len(actions))), # Q_table初始值
columns=actions, # actions name
)
# print(table)
return table
# build_Q_table(N_STATES, ACTIONS)
def choose_action(state, q_table): # 代理選擇一個(gè)動(dòng)作
# 選擇一個(gè)動(dòng)作
state_actions = q_table.iloc[state, :] # 單獨(dú)拿出來這一行
if (np.random.uniform() > EPSILON)or(state_actions.all() == 0): # 1-EPSILON 的概率選擇隨機(jī)
action_name = np.random.choice(ACTIONS)
else: # EPSILON的概率選擇之前的最好結(jié)果的行動(dòng)
action_name = state_actions.argmax() # 找出最大值的一列
return action_name
# 環(huán)境給與反饋 包括走到了哪里,是否走到了終點(diǎn)
def get_env_feedback(S, A):
# 環(huán)境反饋
if A == 'right':
if S == N_STATES-2: # 要結(jié)束的情況
S_ = 'terminal' # 狀態(tài)是結(jié)束
R = 1
else:
S_ = S+1 # 狀態(tài)是向右走一步
R = 0
elif A == 'left':
R = 0
if S == 0: # 已經(jīng)是最左邊 則不動(dòng)
S_ = S
else:
S_ = S-1 # 其他情況向左一步
return S_, R
# 更新環(huán)境 主要是用字符串的形式寫出來,供看的
def update_env(S, episode, step_counter):
# 環(huán)境更新
env_list = ['-']*(N_STATES-1)+['T'] # '-----T'環(huán)境就是這樣子
if S == 'terminal':
interaction = 'Episode %s: total_steps = %s' % (episode+1, step_counter)
print('\r{}'.format(interaction), end='')
time.sleep(2)
print('\r ', end='')
else:
env_list[S] = 'o'
interaction = ''.join(env_list)
print('\r{}'.format(interaction), end='')
time.sleep(FRESH_TIME)
def rl(): # 主循環(huán)
_q_table = build_q_table(N_STATES, ACTIONS)
for episode in range(MAX_EPISODES): # 最多玩這么多回合
step_counter = 0 # 當(dāng)前回合所使用的步數(shù)
S = 0 # 當(dāng)前的位置(或者說狀態(tài))
is_terminated = False # 是否到達(dá)終點(diǎn)
update_env(S, episode, step_counter) # 第一步更新環(huán)境
while not is_terminated: # 一直走直到終點(diǎn)
A = choose_action(S, _q_table) # 選擇一個(gè)動(dòng)作(本例里面是 向左走 向右走)
S_, R = get_env_feedback(S, A) # 得到環(huán)境獎(jiǎng)勵(lì)
q_predict = _q_table.ix[S, A] # 估計(jì)值
if S_ != 'terminal':
q_target = R + LAMBDA * _q_table.iloc[S_, :].max() # 真實(shí)值
else:
q_target = R # 達(dá)到目標(biāo)
is_terminated = True
_q_table.ix[S, A] += ALPHA * (q_target - q_predict) # 新的Q 由 估計(jì)值和真實(shí)值的差計(jì)算出
S = S_ # 狀態(tài)更新
step_counter += 1
update_env(S, episode, step_counter) # 每走一步更新一下環(huán)境
print('\nepisode=', episode, 'step_counter=', step_counter, '\n', _q_table, )
return _q_table
if __name__ == '__main__':
q_table = rl()
print('\r\nQ_table:\n')
print(q_table)