最美情侣中文字幕电影,在线麻豆精品传媒,在线网站高清黄,久久黄色视频

歡迎光臨散文網(wǎng) 會(huì)員登陸 & 注冊

給newbing debug看的

2023-03-06 17:12 作者:shieldermash  | 我要投稿

下面是dqn_wrappers的代碼:

"""

Adapted from OpenAI Baselines

https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py

"""

from collections import deque

import numpy as np

import gym

import copy

import CV2

CV2.ocl.setUseOpenCL(False)


def make_env(env, stack_frames=True, episodic_life=True, clip_rewards=False, scale=False):

? ? if episodic_life:

? ? ? ? env = EpisodicLifeEnv(env)


? ? env = NoopResetEnv(env, noop_max=30)

? ? env = MaxAndSkipEnv(env, skip=4)

? ? if 'FIRE' in env.unwrapped.get_action_meanings():

? ? ? ? env = FireResetEnv(env)


? ? env = WarpFrame(env)

? ? if stack_frames:

? ? ? ? env = FrameStack(env, 4)

? ? if clip_rewards:

? ? ? ? env = ClipRewardEnv(env)

? ? return env


class RewardScaler(gym.RewardWrapper):


? ? def reward(self, reward):

? ? ? ? return reward * 0.1



class ClipRewardEnv(gym.RewardWrapper):

? ? def __init__(self, env):

? ? ? ? gym.RewardWrapper.__init__(self, env)


? ? def reward(self, reward):

? ? ? ? """Bin reward to {+1, 0, -1} by its sign."""

? ? ? ? return np.sign(reward)



class LazyFrames(object):

? ? def __init__(self, frames):

? ? ? ? """This object ensures that common frames between the observations are only stored once.

? ? ? ? It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay

? ? ? ? buffers.

? ? ? ? This object should only be converted to numpy array before being passed to the model.

? ? ? ? You'd not believe how complex the previous solution was."""

? ? ? ? self._frames = frames

? ? ? ? self._out = None


? ? def _force(self):

? ? ? ? if self._out is None:

? ? ? ? ? ? self._out = np.concatenate(self._frames, axis=2)

? ? ? ? ? ? self._frames = None

? ? ? ? return self._out


? ? def __array__(self, dtype=None):

? ? ? ? out = self._force()

? ? ? ? if dtype is not None:

? ? ? ? ? ? out = out.astype(dtype)

? ? ? ? return out


? ? def __len__(self):

? ? ? ? return len(self._force())


? ? def __getitem__(self, i):

? ? ? ? return self._force()[i]


class FrameStack(gym.Wrapper):

? ? def __init__(self, env, k):

? ? ? ? """Stack k last frames.

? ? ? ? Returns lazy array, which is much more memory efficient.

? ? ? ? See Also

? ? ? ? --------

? ? ? ? baselines.common.atari_wrappers.LazyFrames

? ? ? ? """

? ? ? ? gym.Wrapper.__init__(self, env)

? ? ? ? self.k = k

? ? ? ? self.frames = deque([], maxlen=k)

? ? ? ? shp = env.observation_space.shape

? ? ? ? self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)


? ? def reset(self):

? ? ? ? ob = self.env.reset()[0]

? ? ? ? for _ in range(self.k):

? ? ? ? ? ? self.frames.append(ob)

? ? ? ? return self._get_ob()


? ? def step(self, action):

? ? ? ? ob, reward, done, info = self.env.step(action)

? ? ? ? self.frames.append(ob)

? ? ? ? return self._get_ob(), reward, done, info


? ? def _get_ob(self):

? ? ? ? assert len(self.frames) == self.k

? ? ? ? return LazyFrames(list(self.frames))



class WarpFrame(gym.ObservationWrapper):

? ? def __init__(self, env):

? ? ? ? """Warp frames to 84x84 as done in the Nature paper and later work."""

? ? ? ? gym.ObservationWrapper.__init__(self, env)

? ? ? ? self.width = 84

? ? ? ? self.height = 84

? ? ? ? self.observation_space = gym.spaces.Box(low=0, high=255,

? ? ? ? ? ? shape=(self.height, self.width, 1), dtype=np.uint8)


? ? def observation(self, frame):

? ? ? ? #print(frame.shape)

? ? ? ? frame = frame.astype(np.uint8)

? ? ? ? frame = CV2.cvtColor(frame, CV2.COLOR_GRAY2BGR)

? ? ? ? #frame = CV2.cvtColor(frame, CV2.COLOR_RGB2GRAY)

? ? ? ? frame = CV2.resize(frame, (self.width, self.height), interpolation=CV2.INTER_AREA)

? ? ? ? return frame[:, :, None]



class FireResetEnv(gym.Wrapper):

? ? def __init__(self, env=None):

? ? ? ? """For environments where the user need to press FIRE for the game to start."""

? ? ? ? super(FireResetEnv, self).__init__(env)

? ? ? ? assert env.unwrapped.get_action_meanings()[1] == 'FIRE'

? ? ? ? assert len(env.unwrapped.get_action_meanings()) >= 3


? ? def step(self, action):

? ? ? ? return self.env.step(action)


? ? def reset(self):

? ? ? ? self.env.reset()

? ? ? ? obs, _, done, _ = self.env.step(1)

? ? ? ? if done:

? ? ? ? ? ? self.env.reset()

? ? ? ? obs, _, done, _ = self.env.step(2)

? ? ? ? if done:

? ? ? ? ? ? self.env.reset()

? ? ? ? return obs



class EpisodicLifeEnv(gym.Wrapper):

? ? def __init__(self, env=None):

? ? ? ? """Make end-of-life == end-of-episode, but only reset on true game over.

? ? ? ? Done by DeepMind for the DQN and co. since it helps value estimation.

? ? ? ? """

? ? ? ? super(EpisodicLifeEnv, self).__init__(env)

? ? ? ? self.lives = 0

? ? ? ? self.was_real_done = True

? ? ? ? self.was_real_reset = False


? ? def step(self, action):

? ? ? ? obs, reward, done, info,_ = self.env.step(action)

? ? ? ? self.was_real_done = done

? ? ? ? # check current lives, make loss of life terminal,

? ? ? ? # then update lives to handle bonus lives

? ? ? ? lives = self.env.unwrapped.ale.lives()

? ? ? ? if lives < self.lives and lives > 0:

? ? ? ? ? ? # for Qbert somtimes we stay in lives == 0 condtion for a few frames

? ? ? ? ? ? # so its important to keep lives > 0, so that we only reset once

? ? ? ? ? ? # the environment advertises done.

? ? ? ? ? ? done = True

? ? ? ? self.lives = lives

? ? ? ? return obs, reward, done, info


? ? def reset(self):

? ? ? ? """Reset only when lives are exhausted.

? ? ? ? This way all states are still reachable even though lives are episodic,

? ? ? ? and the learner need not know about any of this behind-the-scenes.

? ? ? ? """

? ? ? ? if self.was_real_done:

? ? ? ? ? ? obs = self.env.reset()

? ? ? ? ? ? self.was_real_reset = True

? ? ? ? else:

? ? ? ? ? ? # no-op step to advance from terminal/lost life state

? ? ? ? ? ? obs, _, _, _ = self.env.step(0)

? ? ? ? ? ? self.was_real_reset = False

? ? ? ? self.lives = self.env.unwrapped.ale.lives()

? ? ? ? return obs



class MaxAndSkipEnv(gym.Wrapper):

? ? def __init__(self, env=None, skip=4):

? ? ? ? """Return only every `skip`-th frame"""

? ? ? ? super(MaxAndSkipEnv, self).__init__(env)

? ? ? ? # most recent raw observations (for max pooling across time steps)

? ? ? ? self._obs_buffer = deque(maxlen=2)

? ? ? ? self._skip = skip


? ? def step(self, action):

? ? ? ? total_reward = 0.0

? ? ? ? done = None

? ? ? ? for _ in range(self._skip):

? ? ? ? ? ? obs, reward, done, info = self.env.step(action)

? ? ? ? ? ? self._obs_buffer.append(obs)

? ? ? ? ? ? total_reward += reward

? ? ? ? ? ? if done:

? ? ? ? ? ? ? ? break


? ? ? ? max_frame = np.max(np.stack(self._obs_buffer), axis=0)


? ? ? ? return max_frame, total_reward, done, info


? ? def reset(self):

? ? ? ? """Clear past frame buffer and init. to first obs. from inner env."""

? ? ? ? self._obs_buffer.clear()

? ? ? ? obs = self.env.reset()

? ? ? ? self._obs_buffer.append(obs)

? ? ? ? return obs


class NoopResetEnv(gym.Wrapper):

? ? def __init__(self, env=None, noop_max=30):

? ? ? ? """Sample initial states by taking random number of no-ops on reset.

? ? ? ? No-op is assumed to be action 0.

? ? ? ? """

? ? ? ? super(NoopResetEnv, self).__init__(env)

? ? ? ? self.noop_max = noop_max

? ? ? ? self.override_num_noops = None

? ? ? ? assert env.unwrapped.get_action_meanings()[0] == 'NOOP'


? ? def step(self, action):

? ? ? ? return self.env.step(action)


? ? def reset(self):

? ? ? ? """ Do no-op action for a number of steps in [1, noop_max]."""

? ? ? ? self.env.reset()

? ? ? ? if self.override_num_noops is not None:

? ? ? ? ? ? noops = self.override_num_noops

? ? ? ? else:

? ? ? ? ? ? noops = np.random.randint(1, self.noop_max + 1)

? ? ? ? assert noops > 0

? ? ? ? obs = None

? ? ? ? for _ in range(noops):

? ? ? ? ? ? obs, _, done, _ = self.env.step(0)

? ? ? ? ? ? if done:

? ? ? ? ? ? ? ? obs = self.env.reset()

? ? ? ? return obs

然后是正文

# %load fig4b_5ab

import gym

import torch

import numpy as np

import copy

import random

from tqdm import tqdm

import matplotlib.pyplot as plt

from dqn_wrappers import *


def get_state(obs):

? ? state = np.array(obs)

? ? state = state.transpose((2,0,1))

? ? state = torch.from_numpy(state)

? ? return state.unsqueeze(0).float()


def select_action(state, policy, eps=0.0):

? ? sample = random.random()

? ? if sample > eps:

? ? ? ? with torch.no_grad():

? ? ? ? ? ? return policy(state.to(device)).max(1)[1].view(1,1).item()

? ? else:

? ? ? ? return env.action_space.sample()


def TD0(n_episodes, gamma, alpha0, eta, B, phi0, eps=0.0, cut_factor = 1):

? ? CIs_Q = np.zeros([n_episodes,2])

? ? CIs_SE = np.zeros([n_episodes,2])

? ? Values = np.zeros(n_episodes)

? ? Theta = np.zeros([B+1,p])

? ? Thetabar = np.zeros([B+1,p])

? ? cut = n_episodes // cut_factor

? ? j=0


? ? for i in tqdm(range(n_episodes)):

? ? ? ? done = False

? ? ? ? state = get_state(env.reset())

? ? ? ? while not done:

? ? ? ? ? ? j += 1

? ? ? ? ? ? alpha_t = alpha0 * j**(-eta)

? ? ? ? ? ? W = np.concatenate([[1], np.random.exponential(size=B)])

? ? ? ? ? ? action = select_action(state, policy0, eps=eps)

? ? ? ? ? ? obs_, reward, done, info = env.step(action)

? ? ? ? ? ? state_ = get_state(obs_)

? ? ? ? ? ? phi = model(state.to(device)).detach().to('cpu').numpy().squeeze()

? ? ? ? ? ? phi_ = model(state_.to(device)).detach().to('cpu').numpy().squeeze()

? ? ? ? ? ? At = np.outer(phi, (phi - gamma*phi_))

? ? ? ? ? ? bt = reward*phi

? ? ? ? ? ? Theta += alpha_t*np.diag(W) @ (bt[np.newaxis,:] - Theta@At.T)

? ? ? ? ? ? if j > cut:

? ? ? ? ? ? ? ? Thetabar = ((j-cut-1)*Thetabar + Theta)/(j-cut)

? ? ? ? ? ? else:

? ? ? ? ? ? ? ? Thetabar = Theta

? ? ? ? ? ? state = state_

? ? ? ? values = Thetabar @ phi0

? ? ? ? value = values[0]

? ? ? ? valuesW = values[1:]

? ? ? ? Q = value + np.quantile(value - valuesW, [0.025, 0.975], axis=0)

? ? ? ? SE = float(np.sqrt(np.cov(value - valuesW)))

? ? ? ? SE = value + 1.96*SE*np.array([-1,1])

? ? ? ? Values[i] = value

? ? ? ? CIs_Q[i,:] = Q

? ? ? ? CIs_SE[i,:] = SE


? ? return {'value': Values, 'Q': CIs_Q, 'SE': CIs_SE}


if __name__ == "__main__":

? ? device = torch.device("cpu")


? ? policy0 = torch.load('dqn_pong_model', map_location=torch.device('cpu'))


? ? env = gym.make("PongNoFrameskip-v4")

? ? env.reset()

? ? env = make_env(env)


??

? ? nactions = env.action_space.n


? ? model = copy.deepcopy(policy0)

? ? model.head = torch.nn.Identity()


? ? gamma = 0.99

? ? p = 512

? ? n_episodes = 100

? ? alpha0 = 0.5

? ? eta = 2/3

? ? eps=0.0

? ??


? ? state = get_state(env.reset())


? ? phi0 = model(state.to(device)).detach().to('cpu').numpy().squeeze()


? ? CIs = TD0(n_episodes, gamma,alpha0,eta,200,phi0)


? ? values = CIs['value']

? ? CIs_Q = CIs['Q']

? ? CIs_SE = CIs['SE']


? ? # fig 5b

? ? start=10

? ? plt.rcParams['figure.figsize'] = (16,9)

? ? plt.rcParams['figure.facecolor'] = 'white'

? ? plt.plot(np.arange(n_episodes)[start:], values[start:], color='black', label='value estimate', linewidth=3)

? ? plt.plot(np.arange(n_episodes)[start:], CIs_Q[start:], color = 'blue', label='Q', linewidth=3)

? ? plt.plot(np.arange(n_episodes)[start:], CIs_SE[start:], color = 'red', label='SE', linewidth=3)

? ? handles, labels = plt.gca().get_legend_handles_labels()

? ? by_label = dict(zip(labels, handles))

? ? plt.legend(by_label.values(), by_label.keys(), fontsize='xx-large')

? ? plt.xlabel('Number of episodes', fontsize='xx-large')

? ? plt.xticks(fontsize=20)

? ? plt.yticks(fontsize=20)

? ? plt.savefig('atari_CI_example.png', bbox_inches='tight')

? ? plt.close()


? ? # fig 6a

? ? CIs_Q_widths = CIs_Q[:,1] - CIs_Q[:,0]

? ? CIs_SE_widths = CIs_SE[:,1] - CIs_SE[:,0]


? ? plt.rcParams['figure.figsize'] = (16,9)

? ? plt.rcParams['figure.facecolor'] = 'white'

? ? plt.plot(np.arange(n_episodes), CIs_Q_widths, color = 'blue', label='Q', linewidth=3)

? ? plt.plot(np.arange(n_episodes), CIs_SE_widths, color = 'red', label='SE', linewidth=3)

? ? plt.legend(fontsize='xx-large')

? ? plt.xlabel('Number of episodes', fontsize=20)

? ? plt.xticks(fontsize=20)

? ? plt.yticks(fontsize=20)

? ? plt.savefig('atari_CI_widths.png', bbox_inches='tight')

? ? plt.close()


? ? # fig 6b

? ? eps_seq = np.linspace(0,1,5)

? ? eps_dict = dict()


? ? for eps in eps_seq:

? ? ? ? eps_dict[eps] = TD0(n_episodes, gamma,alpha0,eta,200,phi0,eps)


? ? SE = np.array([x['SE'] for x in eps_dict.values()])

? ? Q = np.array([x['Q'] for x in eps_dict.values()])

? ? values = np.array([x['value'] for x in eps_dict.values()])


? ? plt.rcParams['figure.figsize'] = (16,9)

? ? plt.rcParams['figure.facecolor'] = 'white'

? ? plt.errorbar(eps_seq, values, yerr=Q.T, fmt='.k')

? ? plt.xlabel('epsilon', fontsize=20)

? ? plt.xticks(fontsize=20)

? ? plt.yticks(fontsize=20)

? ? plt.savefig('atari_CI_bars.png', bbox_inches='tight')

? ? plt.close()

然后是報(bào)錯(cuò)信息:

---------------------------------------------------------------------------ValueError ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?Traceback (most recent call last) Input In [2], in <cell line: 66>() ? ? 85 eta = 2/3 ? ? 86 eps=0.0---> 89 state = get_state(env.reset()) ? ? 91 phi0 = model(state.to(device)).detach().to('cpu').numpy().squeeze() ? ? 93 CIs = TD0(n_episodes, gamma,alpha0,eta,200,phi0) Input In [2], in get_state(obs) ? ? 11 def get_state(obs): ? ? 12 ? ? state = np.array(obs)---> 13 ? ? state = state.transpose((2,0,1)) ? ? 14 ? ? state = torch.from_numpy(state) ? ? 15 ? ? return state.unsqueeze(0).float()ValueError: axes don't match array

給newbing debug看的的評論 (共 條)

分享到微博請遵守國家法律
安图县| 共和县| 信丰县| 锦屏县| 双鸭山市| 武义县| 永城市| 大厂| 同江市| 儋州市| 土默特右旗| 定南县| 定州市| 永仁县| 城市| 米泉市| 永和县| 石嘴山市| 延长县| 马边| 平定县| 福建省| 布尔津县| 越西县| 攀枝花市| 喀喇| 娄烦县| 盘山县| 四子王旗| 邢台市| 砀山县| 梧州市| 道真| 唐河县| 东方市| 宁波市| 安顺市| 江川县| 石首市| 麻栗坡县| 乌什县|