DQN算法实战

DQN算法实战知识预习本文不赘述 DQN 的原理 可先看 Q learning 以及 DQN 其他相关文章 参考比如 datawhale 李宏毅笔记 Q 学习 实战的意思就是撸出代码来 在简单了解原理的基础上 根据论文分析 Human levelcontrol

原理简介

DQN是Q-leanning算法的优化和延伸,Q-leaning中使用有限的Q表存储值的信息,而DQN中则用神经网络替代Q表存储信息,这样更适用于高维的情况,相关知识基础可参考datawhale李宏毅笔记-Q学习。

论文方面主要可以参考两篇,一篇就是2013年谷歌DeepMind团队的Playing Atari with Deep Reinforcement Learning,一篇是也是他们团队后来在Nature杂志上发表的Human-level control through deep reinforcement learning。后者在算法层面增加target q-net,也可以叫做Nature DQN。

Nature DQN使用了两个Q网络,一个当前Q网络?用来选择动作,更新模型参数,另一个目标Q网络?′用于计算目标Q值。目标Q网络的网络参数不需要迭代更新,而是每隔一段时间从当前Q网络?复制过来,即延时更新,这样可以减少目标Q值和当前的Q值相关性。

要注意的是,两个Q网络的结构是一模一样的。这样才可以复制网络参数。Nature DQN和Playing Atari with Deep Reinforcement Learning相比,除了用一个新的相同结构的目标Q网络来计算目标Q值以外,其余部分基本是完全相同的。细节也可参考强化学习(九)Deep Q-Learning进阶之Nature DQN。

代码实战

程序代码见github,这里简要说明一下思路。

强化学习基本接口

首先是强化学习训练的基本接口,即通用的训练模式:

for i_episode in range(MAX_EPISODES): state = env.reset() # reset环境状态 for i_step in range(MAX_STEPS): action = agent.choose_action(state) # 根据当前环境state选择action next_state, reward, done, _ = env.step(action) # 更新环境参数 agent.memory.push(state, action, reward, next_state, done) # 将state等这些transition存入memory agent.update() # 每步更新网络 state = next_state # 跳转到下一个状态 if done: break 

如上,首先需要循环多个episode训练,在每个episode中,首先需要重置环境,然后开始探索,为了避免无法收敛或者说无限循环的情况,每个episode加一个MAX_STEPS,接下来的流程如下:

  1. agent选择动作
  2. 环境根据agent的动作反馈出新的state和reward
  3. agent进行更新,如有memory就会将transition(包含state,reward,action等)存入memory中
  4. 跳转到下一个状态
    如果提前done了,就跳出for循环,进行下一个episode的训练。

两个Q网络

前面讲了Nature DQN中有两个Q网络,一个是policy_net,一个是延时更新的target_net,两个网络的结构是一模一样的,如下(见model.py):

import torch.nn as nn import torch.nn.functional as F class FCN(nn.Module): def __init__(self, n_states=4, n_actions=18): """ 初始化q网络,为全连接网络 n_states: 输入的feature即环境的state数目 n_actions: 输出的action总个数 """ super(FCN, self).__init__() self.fc1 = nn.Linear(n_states, 128) # 输入层 self.fc2 = nn.Linear(128, 128) # 隐藏层 self.fc3 = nn.Linear(128, n_actions) # 输出层 def forward(self, x): # 各层对应的激活函数 x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) return self.fc3(x) 

输入为state,输出为action,注意根据state和action的维度调整隐藏层的层数,这里设为128

agent.py中我们定义强化学习算法,包括choose_actionupdate两个主要函数,初始化中:

self.policy_net = FCN(n_states, n_actions).to(self.device) self.target_net = FCN(n_states, n_actions).to(self.device) # target_net的初始模型参数完全复制policy_net self.target_net.load_state_dict(self.policy_net.state_dict()) self.target_net.eval() # 不启用 BatchNormalization 和 Dropout # 可查parameters()与state_dict()的区别,前者require_grad=True 

可以看到policy_net跟target_net结构和初始参数一样,但在更新的时候target是每隔一段episode更新的,如下(见main.py):

# 更新target network,复制DQN中的所有weights and biases if i_episode % cfg.target_update == 0: agent.target_net.load_state_dict(agent.policy_net.state_dict()) 

可以调整cfg.target_update,注意该变量不要调得太大,否则会收敛很慢,我们最后保存的模型也是这个target_net,如下(见agent.py):

def save_model(self,path): torch.save(self.target_net.state_dict(), path) 

Replay Memory

然后就是Replay Memory了,如下(见memory.py):

import random import numpy as np class ReplayBuffer: def __init__(self, capacity): self.capacity = capacity self.buffer = [] self.position = 0 def push(self, state, action, reward, next_state, done): if len(self.buffer) < self.capacity: self.buffer.append(None) self.buffer[self.position] = (state, action, reward, next_state, done) self.position = (self.position + 1) % self.capacity def sample(self, batch_size): batch = random.sample(self.buffer, batch_size) state, action, reward, next_state, done = zip(*batch) return state, action, reward, next_state, done def __len__(self): return len(self.buffer) 

其实比较简单,主要包括push和sample两个步骤,push是将transitions放到memory中,sample是从memory随机抽取一些transition。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/218963.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月17日 下午11:06
下一篇 2026年3月17日 下午11:06


相关推荐

  • 怎么同时运行两个tomcat?

    怎么同时运行两个tomcat?转载至:http://ask.zol.com.cn/x/4522378.html这几天由于在搞那个jenkins的自动部署项目所以要使用到两个tomcat(因为一个tomcat不能同时开着两个项目),一个作为jenkins服务器,一个作为项目部署服务器,所以找了一些资料看看一台电脑怎么运行两个tomcat。第一步:先下载两个tomcat(不同版本的也行,笔者用的是一个tomcat7,一个…

    2022年6月15日
    39
  • 【转】VS2013 产品密钥 – 所有版本[通俗易懂]

    【转】VS2013 产品密钥 – 所有版本[通俗易懂]VS2013产品密钥–所有版本VisualStudioUltimate2013KEY(密钥):BWG7X-J98B3-W34RT-33B3R-JVYW9VisualStudioPremium2013KEY(密钥):FBJVC-3CMTX-D8DVP-RTQCT-92494VisualStudioProfessional2013KEY(密钥):XD…

    2022年5月19日
    48
  • oracle11g安装后在哪打开_oracle数据库11g安装教程

    oracle11g安装后在哪打开_oracle数据库11g安装教程oracle11gRAC之srvctl命令:srvctl版本查询:[grid@dbserver1:/home/grid]$srvctl-Vsrvctlversion:11.2.0.1.0srvctl常用命令分类:*增加oracleasm/database/listener注册信息eg:srvctladdasm-lLISTENER-p+crs/dbserve-cluster/…

    2025年10月28日
    5
  • 运算放大电路在音频放大电路中的应用研究与实现「建议收藏」

    运算放大电路在音频放大电路中的应用研究与实现「建议收藏」1、导言放大电路是构成各种功能模拟电路的基础电路,也是对模拟信号最基本的处理。音频信号可以分解成若干频率的正玄波之和,其频率分为在20Hz~20KHz。不当的放大电路会造成音频信号的失真,亦会带来干扰和噪声。所有电子信息系统组成的原则都应包含:1、满足功能和性能要求,2、尽量简单,3、电磁兼容,4、调试应用简单。因此本文就来研究在不会增大电路复杂度的前提下,如何实现音频信号放大的同时对信号进…

    2022年5月29日
    37
  • php 开启opcode,php 开启 opcode 测试

    php 开启opcode,php 开启 opcode 测试php 开启 opcode 测试 合理使用 实验环境系统信息 Linuxlocalho localdomain3 10 0 514 10 2 el7 x86 64 1SMPFriMar30 04 05UTC2017x86 64×86 64×86 64GNU Linux 内存 512MCPU1 核 PHP 版本 PHP7 0 21 amp ZendOPcach

    2025年7月3日
    6
  • arduino中Keypad 库函数介绍

    arduino中Keypad 库函数介绍原文:https://playground.arduino.cc/Code/Keypad/Creation构造函数:Keypad(makeKeymap(userKeymap),row[],col[],rows,cols)constbyterows=4;//fourrowsconstbytecols=3;//threecolumnscharkeys[rows][cols]={{‘1′,’2′,’3’},{‘4′,’5′,’6’},{‘

    2022年6月7日
    38

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号