模块设计 🏗
每个算法以 Model📦
, Algorithm👣
, Agent🤖
, Trainer🔁
四个类为主组成,并以组合的方式进行交互。
Model📦
:定义单个或多个前向网络;输入是环境状态,输出是网络的原始输出。
Algorithm👣
:定义 Model📦
的更新算法和 Model📦
输出的后处理(argmax
, ...)。
Agent🤖
:定义 Algorithm👣
与环境交互的接口和训练数据的预处理。
Trainer🔁
:定义 Agent🤖
的整体训练流程和辅助训练的工具(Buffer
, ...)。
调用 Trainer.__call__
函数将得到一个生成器📽,该生成器保存了训练流程和所有相关数据。生成器每步返回一个 log_data
训练日志📒,持续调用该生成器即可完成训练并得到所有 log_data
。
Logger📊
部分使用 Tensorboard 和 Weights & Biases 记录训练日志。对 Trainer.__call__
函数进行装饰,具体实现见核心代码。
abstractions.py |
---|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102 | class Model(nn.Module):
def __init__(self, config: dict[str, Any]) -> None:
pass
def value(self, x: torch.Tensor, a: Optional[torch.Tensor] = None) -> tuple[Any]:
# 返回 单个 或 多个 critic 的输出值
pass
def action(self, x: torch.Tensor) -> tuple[Any]:
# 返回 动作 | 动作概率分布
pass
class Algorithm:
def __init__(self, config: dict[str, Any]) -> None:
self.model = Model(config)
# 1. 初始化 model, target_model
# 2. 初始化 optimizer
pass
def predict(self, obs: torch.Tensor) -> tuple[Any]:
# 返回 动作 | 动作概率分布 | Q函数的预估值
pass
def learn(self, data: BufferSamples) -> dict[str, Any]:
# 根据训练数据(观测量和输入的reward),定义损失函数,用于更新 Model 中的参数。
# 1. 计算目标
# 2. 计算损失
# 3. 优化模型
# 4. 返回训练信息
pass
def sync_target(self) -> None:
# 同步 model 和 target_model
pass
class Agent:
def __init__(self, config: dict[str, Any]) -> None:
self.alg = Algorithm(config)
# 1. 初始化 Algorithm
# 2. 初始化 运行步数变量
pass
def predict(self, obs: np.ndarray) -> np.ndarray:
# 1. obs 预处理 to_tensor & to_device
# 2. Algorithm.predict 得到 act
# 3. act 后处理 to_numpy & to_cpu
# 4. 返回评估使用的 act
pass
def sample(self, obs: np.ndarray) -> np.ndarray:
# 1. obs 预处理 to_tensor & to_device
# 2. Algorithm.predict 得到 act
# 3. act 后处理 to_numpy & to_cpu
# 4. 返回训练使用的 act
pass
def learn(self, data: BufferSamples) -> dict[str, Any]:
# 数据预处理
# 调用 Algorithm.learn
# 返回 Algorithm.learn 的返回值
pass
class Trainer:
@dataclasses.dataclass
class Config:
exp_name: Optional[str] = None
seed: int = 1
# ...
def __init__(self, config: Config = Config()) -> None:
self.agent = Agent(config)
# 1. 初始化参数
# 2. 初始化训练和评估环境
# 3. 初始化 Buffer
# 4. 初始化 Agent
pass
def __call__(self) -> Generator[dict[str, Any], None, None]:
# 1. 规定训练流程
# 2. 返回一个生成器,生成器每步返回一个 log_data 字典
pass
def _run_collect(self) -> dict[str, Any]:
# 1. 采样一步,并加入到 Buffer 中
# 2. 返回 log_data 字典
pass
def _run_train(self) -> dict[str, Any]:
# 1. 从 Buffer 取出一组训练数据
# 2. 训练单步
# 3. 返回 log_data 字典
pass
if __name__ == "__main__":
trainer = Trainer()
for log_data in trainer():
print(log_data)
|
最后更新:
2023-03-01