百度发布强化学习框架PARL,开源NeurIPS 2018夺冠的训练代码
2019年01月18日 由 浅浅 发表
855001
0
百度发布了PARL,它是一个基于PaddlePaddle的灵活高效的强化学习框架。同时,百度也开源了在NeurIPS 2018假肢挑战赛夺冠的解决方案的代码。
PARL特征
- 可重复:算法可以稳定地再现许多有影响力的强化学习算法的结果
- 规模大:能够支持数千个CPU和多GPU的高性能并行训练
- 可重复使用:存储库中提供的算法可以通过定义前向网络直接适应新任务,并且将自动构建训练机制
- 可扩展:通过在框架中继承抽象类来快速构建新算法
PARL旨在构建用于训练算法的智能体以执行复杂任务。
模型
Model被抽象以构建前馈式网络,该网络定义了策略网络或批评网络的给定状态作为输入。
算法
Algorithm描述了更新参数的机制,Model并且通常包含至少一个模型。
智能体
Agent是环境和算法之间的数据桥梁。它负责与外部的数据I / O,并在进入训练过程之前描述数据预处理。
示例
以下是使用DQN算法为atari游戏构建智能体的示例。
import parl
from parl.algorithms import DQN, DDQN
class AtariModel(parl.Model):
"""AtariModel
This class defines the forward part for an algorithm,
its input is state observed on environment.
"""
def __init__(self, img_shape, action_dim):
# define your layers
self.cnn1 = layers.conv_2d(num_filters=32, filter_size=5,
stride=[1, 1], padding=[2, 2], act='relu')
...
self.fc1 = layers.fc(action_dim)
def value(self, img):
# define how to estimate the Q value based on the image of atari games.
img = img / 255.0
l = self.cnn1(img)
...
Q = self.fc1(l)
return Q
"""
three steps to build an agent
1. define a forward model which is critic_model is this example
2. a. to build a DQN algorithm, just pass the critic_model to `DQN`
b. to build a DDQN algorithm, just replace DQN in following line with DDQN
3. define the I/O part in AtariAgent so that it could update the algorithm based on the interactive data
"""
model = AtariModel(img_shape=(32, 32), action_dim=4)
algorithm = DQN(model)
agent = AtariAgent(algorithm)
安装:
- Python 2.7或3.5+
- PaddlePaddle> = 1.2.1(我们尝试使存储库始终与最新版本的PaddlePaddle兼容)
NeurIPS 2018解决方案
它由三部分组成。第一部分是最终提交的模型,一个可以跟随随机目标速度的合理控制器。第二部分用于课程学习,在低速步行中学习自然而有效的步态。最后一部分在随机速度环境中学习的智能体进行第二轮评估。
开源:
github.com/PaddlePaddle/PARL/tree/develop/examples/NeurIPS2018-AI-for-Prosthetics-Challenge
PARL:
github.com/PaddlePaddle/PARL