强化学习SAC算法流程及与PPO对比

因为最近论文里面用到了这个算法,所以顺便也总结一下

SAC算法流程

主循环中进行动作选择与环境交互

  • Actor网络(策略网络)根据当前状态生成带随机性的动作
  • 执行动作,环境返回下一状态和奖励

update函数当中

目标Q值计算(Critic更新)

  • 双Q网络(Q1, Q2):分别计算当前状态-动作对的Q值()。

  • 目标Q值(Q_target)

    • 下一状态输入Actor(就是策略网络 )得到新动作及对应熵,也就是说 实际上代表的是在下一状态时选择的概率。

    • 目标Critic网络(非训练参数)计算下一动作的Q值,并加熵项:

  • 表示由当前Actor网络根据状态 生成下一个动作。

  • 是从两个目标Critic网络 输出的值中取最小值,目的是避免高估Q值。

    • 是熵调节系数, 表示动作的熵,用于衡量动作的随机性。 表示动作在当前策略下的对数概率(log probability),负号是为了将熵定义为正数。熵越大,表示动作的随机性越强(探索性越高)。

    • 其实我很困惑,为什么不是

      因为熵是 ,肯定为正数,而原式想表达的是减去熵,但是如果不是像我这样写的话,实际上会加上熵?除非是负数。 实际上就是加上熵,熵作为一个奖励,当一个动作的选择概率较小的时候,会适当增加选择它的概率,避免让选择过于确定,这就是熵的作用

  • 最小化Q_current与Q_target的MSE损失,更新critic参数。

Actor策略更新

  • 通过当前状态采样动作,计算更新后的critic和未更新的target_critic的Q值(取双Q最小值)并加熵项:
  • 使用优化器来最小化loss,因为取的是相反数所以实则是最大化该值,以提升策略性能与探索性。
  • 通过梯度上升优化 ,使动作 同时满足:
    • 高Q值(Critic认可的动作)
    • 高熵(鼓励探索,避免策略固化)

目标critic网络软更新

  • 目标Critic参数通过Polyak平均缓慢跟踪主Critic
  • 其中:
    • (例如 0.005),该参数用于控制更新的幅度。通过这种软更新的方式,将当前网络参数 与目标网络参数 进行融合,缓慢同步参数,目的是为了稳定训练过程,避免因参数更新过快导致训练不稳定。

参考代码

虽然这个代码没有考虑熵,以及没考虑两个Q的最小值,并且运行起来没啥结果,但是其他地方我都注释好了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np


# --------------------------
# 状态编码器(极坐标转换 + 特征拼接)
# --------------------------
def encode_state(vehicle_pos, vehicle_speed, nodes_info):
"""
输入:
vehicle_pos: 车辆当前位置 (笛卡尔坐标) [x, y]
vehicle_speed: 车辆速度向量 [vx, vy]
nodes_info: 边缘节点信息列表,每个元素为 [节点x, 节点y, 计算能力, 信任度]

输出:
state_tensor: 编码后的状态张量 (shape: [1, state_dim])
"""
# 将车辆与节点的相对位置转换为极坐标
polar_features = []
for node in nodes_info:
dx = node[0] - vehicle_pos[0]
dy = node[1] - vehicle_pos[1]
r = np.sqrt(dx ** 2 + dy ** 2) # 相对距离
theta = np.arctan2(dy, dx) # 相对角度(弧度)
polar_features.extend([r, theta])

# 速度投影(径向和切向分量)
speed_norm = np.linalg.norm(vehicle_speed)
if speed_norm > 0:
v_r = (vehicle_speed[0] * dx + vehicle_speed[1] * dy) / (r + 1e-5) # 径向速度
v_theta = (vehicle_speed[0] * dy - vehicle_speed[1] * dx) / (r + 1e-5) # 切向速度
else:
v_r, v_theta = 0.0, 0.0

# 假设LSTM隐藏状态(简化为随机向量)
lstm_hidden = np.random.randn(16)

# 拼接所有特征
state = np.concatenate([
polar_features,
[v_r, v_theta],
lstm_hidden,
[node[2] for node in nodes_info], # 节点计算能力
[node[3] for node in nodes_info] # 节点信任度
])

return torch.FloatTensor(state).unsqueeze(0)


# --------------------------
# Actor网络(策略网络)
# --------------------------
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Tanh() # 输出范围[-1,1],需映射到具体动作
)

def forward(self, state):
return self.net(state)


# --------------------------
# Critic网络(Q函数)
# --------------------------
class Critic(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super().__init__()
self.q_net = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)

def forward(self, state, action):
return self.q_net(torch.cat([state, action], dim=1))


# --------------------------
# ST-SAC主类
# --------------------------
class ST_SAC:
def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, alpha=0.2, tau=0.005):
# 主网络
self.actor = Actor(state_dim, action_dim)
self.critic = Critic(state_dim, action_dim)
self.critic_target = Critic(state_dim, action_dim) # 新增目标网络
self.critic_target.load_state_dict(self.critic.state_dict()) # 同步初始化

# 优化器
self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr)
self.critic_optim = optim.Adam(self.critic.parameters(), lr=lr)

# 超参数
self.gamma = gamma
self.alpha_base = alpha
self.tau = tau # 目标网络软更新系数

def select_action(self, state, nodes_info, trust_threshold=0.5):
""" 选择动作(含安全掩码) """
with torch.no_grad():
action = self.actor(state)

# 生成动作掩码(信任度 >= 阈值)
valid_nodes = [i for i, node in enumerate(nodes_info) if node[3] >= trust_threshold]
mask = torch.zeros_like(action)
for i in valid_nodes:
mask[:, i] = 1.0 # 假设每个动作维度对应一个节点

# 应用掩码并随机探索
masked_action = action * mask + (1 - mask) * torch.randn_like(action)
return masked_action.squeeze(0).numpy()

def update(self, state, action, reward, next_state, done, vehicle_speed):
# 动态熵系数
v_r = abs(state[0, -len(nodes_info) * 2 + 1])
alpha = self.alpha_base * (1 + torch.sigmoid(torch.tensor(v_r)))

# ----------------- 1. Critic 更新 -----------------
with torch.no_grad():
next_action = self.actor(next_state)
# 使用目标网络计算目标Q值
target_q = self.critic_target(next_state, next_action)
# 当前动作的即时奖励 + 对未来奖励的预测
target_q = reward + (1 - done) * self.gamma * target_q

current_q = self.critic(state, action)
# 目标:让当前Q值逼近目标Q值
critic_loss = nn.MSELoss()(current_q, target_q)

self.critic_optim.zero_grad()
critic_loss.backward()
# 更新critic网络,让网络参数往目标Q值的方向更新
self.critic_optim.step()

# ----------------- 2. Actor 更新 -----------------
# 切断Critic到Actor的梯度传播
pred_action = self.actor(state)
# actor网络的参数更新需要依赖critic网络的输出,critic是更新过的,用来给actor提供指导
# detach()的作用是剥离出一个相同值但不包含梯度的Variable,不参与计算图的构建
q_value = self.critic(state, pred_action).detach() # 关键:detach(),detach()后的梯度不会传播到actor
actor_loss = -q_value.mean() + alpha * (pred_action ** 2).mean()

self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()

# ----------------- 3. 目标网络软更新 -----------------
# 把critic的value更新到target_critic
for t_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
t_param.data.copy_(self.tau * param.data + (1 - self.tau) * t_param.data)


# --------------------------
# 示例使用
# --------------------------
if __name__ == "__main__":
# 假设场景:1辆车,3个边缘节点
vehicle_pos = [0.0, 0.0]
vehicle_speed = [1.0, 0.0] # 沿x轴移动
nodes_info = [
[10.0, 0.0, 5.0, 0.8], # 节点1:正前方,高信任
[0.0, 5.0, 3.0, 0.3], # 节点2:左侧,低信任
[-5.0, 0.0, 4.0, 0.6] # 节点3:后方,中等信任
]

# 初始化ST-SAC
state_dim = len(encode_state(vehicle_pos, vehicle_speed, nodes_info).squeeze(0))
action_dim = len(nodes_info) # 每个动作维度对应一个节点的卸载选择
agent = ST_SAC(state_dim, action_dim)

# 示例训练步骤
state = encode_state(vehicle_pos, vehicle_speed, nodes_info)
for _ in range(100):
# 用到了actor
action = agent.select_action(state, nodes_info)
next_vehicle_pos = [vehicle_pos[0] + 0.1, vehicle_pos[1]] # 模拟移动
next_state = encode_state(next_vehicle_pos, vehicle_speed, nodes_info)
reward = -np.abs(action).mean() # 示例奖励:鼓励集中卸载
done = False

agent.update(state, torch.FloatTensor(action).unsqueeze(0),
torch.FloatTensor([reward]), next_state, done, vehicle_speed)
state = next_state

和PPO算法进行对比

PPO只有一个actor和一个critic,所以他不是靠在线网络和目标网络之间取最小值来约束变化范围的,而是使用近端约束裁切来显式约束步长,从而提高稳定性。PPO可以没有熵,通常依赖基本的随机性或额外的探索策略