Posting this for NewBing to debug.
Below is the code for dqn_wrappers:
"""
Adapted from OpenAI Baselines
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""
from collections import deque
import numpy as np
import gym
import copy
import cv2
cv2.ocl.setUseOpenCL(False)
def make_env(env, stack_frames=True, episodic_life=True, clip_rewards=False, scale=False):
    if episodic_life:
        env = EpisodicLifeEnv(env)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if stack_frames:
        env = FrameStack(env, 4)
    if clip_rewards:
        env = ClipRewardEnv(env)
    return env

class RewardScaler(gym.RewardWrapper):
    def reward(self, reward):
        return reward * 0.1

class ClipRewardEnv(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)
    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)

class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
        buffers.
        This object should only be converted to numpy array before being passed to the model.
        You'd not believe how complex the previous solution was."""
        self._frames = frames
        self._out = None
    def _force(self):
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out
    def __array__(self, dtype=None):
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        return out
    def __len__(self):
        return len(self._force())
    def __getitem__(self, i):
        return self._force()[i]
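
# Usage sketch (not from the original module; shapes assume the 84x84x1 frames produced
# by WarpFrame below): LazyFrames([f1, f2, f3, f4]) only keeps the four frame references,
# and np.array(lazy_frames) concatenates them along axis 2 into a single (84, 84, 4) array.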

class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.
        Returns lazy array, which is much more memory efficient.
        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
    def reset(self):
        ob = self.env.reset()[0]
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()
    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info
    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))

class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=(self.height, self.width, 1), dtype=np.uint8)
    def observation(self, frame):
        #print(frame.shape)
        frame = frame.astype(np.uint8)
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
        #frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]

class FireResetEnv(gym.Wrapper):
    def __init__(self, env=None):
        """For environments where the user needs to press FIRE for the game to start."""
        super(FireResetEnv, self).__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3
    def step(self, action):
        return self.env.step(action)
    def reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs

class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env=None):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        super(EpisodicLifeEnv, self).__init__(env)
        self.lives = 0
        self.was_real_done = True
        self.was_real_reset = False
    def step(self, action):
        obs, reward, done, info, _ = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert we sometimes stay in the lives == 0 condition for a few frames,
            # so it is important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info
    def reset(self):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset()
            self.was_real_reset = True
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
            self.was_real_reset = False
        self.lives = self.env.unwrapped.ale.lives()
        return obs

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env=None, skip=4):
        """Return only every `skip`-th frame"""
        super(MaxAndSkipEnv, self).__init__(env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip
    def step(self, action):
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info
    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs

class NoopResetEnv(gym.Wrapper):
    def __init__(self, env=None, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        super(NoopResetEnv, self).__init__(env)
        self.noop_max = noop_max
        self.override_num_noops = None
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
    def step(self, action):
        return self.env.step(action)
    def reset(self):
        """Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset()
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = np.random.randint(1, self.noop_max + 1)
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(0)
            if done:
                obs = self.env.reset()
        return obs
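
For reference, the shapes this wrapper stack is meant to produce, going by the upstream baselines version of these wrappers (an assumption, not something verified here): the raw ALE Pong frame is (210, 160, 3) uint8, WarpFrame is supposed to output (84, 84, 1), and FrameStack(4) returns a LazyFrames object that converts to an (84, 84, 4) array. A minimal sanity check, assuming PongNoFrameskip-v4 is installed and the wrappers above are imported:

import gym
import numpy as np

env = make_env(gym.make("PongNoFrameskip-v4"))
obs = env.reset()
print(type(obs))                                                      # ideally a LazyFrames instance, not a tuple
print(np.asarray(obs[0] if isinstance(obs, tuple) else obs).shape)    # ideally (84, 84, 4)

If either print disagrees with those expectations, get_state in the main script below will fail.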
Next is the main script:
# %load fig4b_5ab
import gym
import torch
import numpy as np
import copy
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
from dqn_wrappers import *

def get_state(obs):
    state = np.array(obs)
    state = state.transpose((2, 0, 1))
    state = torch.from_numpy(state)
    return state.unsqueeze(0).float()
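
# Intended behavior (an assumption based on the wrappers above, not verified): for an
# (84, 84, 4) observation, get_state returns a float tensor of shape (1, 4, 84, 84),
# i.e. batch x channels x height x width.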

def select_action(state, policy, eps=0.0):
    sample = random.random()
    if sample > eps:
        with torch.no_grad():
            return policy(state.to(device)).max(1)[1].view(1, 1).item()
    else:
        return env.action_space.sample()

def TD0(n_episodes, gamma, alpha0, eta, B, phi0, eps=0.0, cut_factor=1):
    CIs_Q = np.zeros([n_episodes, 2])
    CIs_SE = np.zeros([n_episodes, 2])
    Values = np.zeros(n_episodes)
    Theta = np.zeros([B + 1, p])
    Thetabar = np.zeros([B + 1, p])
    cut = n_episodes // cut_factor
    j = 0
    for i in tqdm(range(n_episodes)):
        done = False
        state = get_state(env.reset())
        while not done:
            j += 1
            alpha_t = alpha0 * j**(-eta)
            W = np.concatenate([[1], np.random.exponential(size=B)])
            action = select_action(state, policy0, eps=eps)
            obs_, reward, done, info = env.step(action)
            state_ = get_state(obs_)
            phi = model(state.to(device)).detach().to('cpu').numpy().squeeze()
            phi_ = model(state_.to(device)).detach().to('cpu').numpy().squeeze()
            At = np.outer(phi, (phi - gamma*phi_))
            bt = reward*phi
            Theta += alpha_t*np.diag(W) @ (bt[np.newaxis, :] - Theta@At.T)
            if j > cut:
                Thetabar = ((j - cut - 1)*Thetabar + Theta)/(j - cut)
            else:
                Thetabar = Theta
            state = state_
        values = Thetabar @ phi0
        value = values[0]
        valuesW = values[1:]
        Q = value + np.quantile(value - valuesW, [0.025, 0.975], axis=0)
        SE = float(np.sqrt(np.cov(value - valuesW)))
        SE = value + 1.96*SE*np.array([-1, 1])
        Values[i] = value
        CIs_Q[i, :] = Q
        CIs_SE[i, :] = SE
    return {'value': Values, 'Q': CIs_Q, 'SE': CIs_SE}

if __name__ == "__main__":
    device = torch.device("cpu")
    policy0 = torch.load('dqn_pong_model', map_location=torch.device('cpu'))
    env = gym.make("PongNoFrameskip-v4")
    env.reset()
    env = make_env(env)

    nactions = env.action_space.n
    model = copy.deepcopy(policy0)
    model.head = torch.nn.Identity()
    gamma = 0.99
    p = 512
    n_episodes = 100
    alpha0 = 0.5
    eta = 2/3
    eps = 0.0

    state = get_state(env.reset())
    phi0 = model(state.to(device)).detach().to('cpu').numpy().squeeze()
    CIs = TD0(n_episodes, gamma, alpha0, eta, 200, phi0)
    values = CIs['value']
    CIs_Q = CIs['Q']
    CIs_SE = CIs['SE']
    # fig 5b
    start = 10
    plt.rcParams['figure.figsize'] = (16, 9)
    plt.rcParams['figure.facecolor'] = 'white'
    plt.plot(np.arange(n_episodes)[start:], values[start:], color='black', label='value estimate', linewidth=3)
    plt.plot(np.arange(n_episodes)[start:], CIs_Q[start:], color='blue', label='Q', linewidth=3)
    plt.plot(np.arange(n_episodes)[start:], CIs_SE[start:], color='red', label='SE', linewidth=3)
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    plt.legend(by_label.values(), by_label.keys(), fontsize='xx-large')
    plt.xlabel('Number of episodes', fontsize='xx-large')
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.savefig('atari_CI_example.png', bbox_inches='tight')
    plt.close()
    # fig 6a
    CIs_Q_widths = CIs_Q[:, 1] - CIs_Q[:, 0]
    CIs_SE_widths = CIs_SE[:, 1] - CIs_SE[:, 0]
    plt.rcParams['figure.figsize'] = (16, 9)
    plt.rcParams['figure.facecolor'] = 'white'
    plt.plot(np.arange(n_episodes), CIs_Q_widths, color='blue', label='Q', linewidth=3)
    plt.plot(np.arange(n_episodes), CIs_SE_widths, color='red', label='SE', linewidth=3)
    plt.legend(fontsize='xx-large')
    plt.xlabel('Number of episodes', fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.savefig('atari_CI_widths.png', bbox_inches='tight')
    plt.close()
    # fig 6b
    eps_seq = np.linspace(0, 1, 5)
    eps_dict = dict()
    for eps in eps_seq:
        eps_dict[eps] = TD0(n_episodes, gamma, alpha0, eta, 200, phi0, eps)
    SE = np.array([x['SE'] for x in eps_dict.values()])
    Q = np.array([x['Q'] for x in eps_dict.values()])
    values = np.array([x['value'] for x in eps_dict.values()])
    plt.rcParams['figure.figsize'] = (16, 9)
    plt.rcParams['figure.facecolor'] = 'white'
    plt.errorbar(eps_seq, values, yerr=Q.T, fmt='.k')
    plt.xlabel('epsilon', fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.savefig('atari_CI_bars.png', bbox_inches='tight')
    plt.close()
And here is the error message:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [2], in <cell line: 66>()
     85 eta = 2/3
     86 eps=0.0
---> 89 state = get_state(env.reset())
     91 phi0 = model(state.to(device)).detach().to('cpu').numpy().squeeze()
     93 CIs = TD0(n_episodes, gamma,alpha0,eta,200,phi0)

Input In [2], in get_state(obs)
     11 def get_state(obs):
     12     state = np.array(obs)
---> 13     state = state.transpose((2,0,1))
     14     state = torch.from_numpy(state)
     15     return state.unsqueeze(0).float()

ValueError: axes don't match array
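
The ValueError means transpose((2, 0, 1)) was given an array that is not 3-D, so np.array(obs) coming out of env.reset() is apparently not the expected (84, 84, 4) stack. Two things seem worth checking (both are guesses, not confirmed by the traceback alone): first, depending on the installed gym version, env.reset() returns either the observation alone (gym < 0.26) or an (obs, info) tuple, and env.step() returns 4 or 5 values, while the wrappers above mix the two conventions (reset()[0] in FrameStack, a 5-value unpack in EpisodicLifeEnv, 4-value unpacks elsewhere); second, if the incoming frame is single-channel, the COLOR_GRAY2BGR conversion in WarpFrame makes observation() return an (84, 84, 1, 3) array instead of (84, 84, 1), so the stacked observation becomes 4-D. A more defensive get_state, as a sketch only:

import numpy as np
import torch

def get_state(obs):
    # tolerate both a bare observation and the (obs, info) tuple returned by newer gym
    if isinstance(obs, tuple):
        obs = obs[0]
    state = np.array(obs)
    # fail loudly if the wrappers did not deliver an HxWxC frame stack
    assert state.ndim == 3, f"expected an HxWxC observation, got shape {state.shape}"
    state = state.transpose((2, 0, 1))
    return torch.from_numpy(state).unsqueeze(0).float()

Printing np.array(env.reset()).shape right after make_env should show which of the two cases applies here.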