基于强化学习的贪吃蛇游戏(五)——智能体结构优化

在上一篇文章中,我们使用DQN算法实现了一个基于深度神经网络的强化学习智能体。虽然通过神经网络的函数近似突破了Q-learning算法的局限性,但当前的实现仍然存在一些改进空间。本文将从状态建模和网络结构两个角度对智能体进行优化,以进一步提升智能体的学习效果。

log😄😅=💧\log_{😄}😅=💧

混合状态建模

在之前的实现中,我们仅采用了一个3×12×123 \times 12 \times 12的特征矩阵来表示游戏状态。虽然这种方式保留了完整的空间信息,但对于智能体来说可能并不是最理想的状态表示方式。

在本次优化中,我们将游戏状态建模为特征矩阵和人工特征的混合表示。(地图大小更改为12×2412 \times 24):

  1. 特征矩阵表示:使用一个12×2412 \times 24的矩阵对游戏画面进行编码
  • 食物位置对应的元素设为1
  • 蛇头位置对应的元素设为0.5
  • 蛇身位置对应的元素设为-0.5
  • 其他位置元素设为0
  1. 人工特征表示:总结游戏的关键信息为抽象特征向量
  • 食物相对于蛇头的方位特征:使用(-1,0,1)表示(上/中/下,左/中/右)
  • 蛇头四个方向的障碍物特征:使用(0/1)表示(无/有)障碍物
优化后的状态表示

图1 优化后的状态表示

这种混合状态表示的优势在于:

  • 特征矩阵保留了完整的游戏空间信息
  • 人工特征提供了更直接的决策依据
  • 不同形式的特征互为补充,提供了更丰富的状态表达

网络结构优化

为了更好地处理混合状态输入,我们设计了一个双输入的神经网络结构,如图1所示:

优化后的网络结构

图2 优化后的网络结构

  1. 特征矩阵处理分支:
  • 输入层:1×12×241 \times 12 \times 24的特征矩阵
  • 卷积层:使用16个3×33 \times 3的卷积核提取空间特征
  • 全连接层:将卷积特征压缩至10维向量
  1. 人工特征处理分支:
  • 输入层:6维的人工特征向量
  • 全连接层:直接与卷积特征分支的输出拼接
  1. 合并层:
  • 将两个分支的特征向量拼接为16维向量
  • 通过全连接层映射到4维输出,对应四个动作的Q值

网络结构的具体实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Sequential(
nn.ConstantPad2d(1, -1),
nn.Conv2d(1, 16, kernel_size=3), # 16x12x24
nn.ReLU(),
)
self.conv_fc = nn.Sequential(
nn.Flatten(),
nn.Linear(16*12*24, 10),
nn.ReLU(),
)
self.fc = nn.Sequential(
nn.Linear(16, 4),
)

def forward(self, x):
input_tensor, feature_matrix = x
input_tensor = torch.tensor(input_tensor)
feature_matrix = torch.tensor(feature_matrix)

feat = self.conv(feature_matrix)
feat = self.conv_fc(feat)
x = torch.cat((feat, input_tensor), dim=1)
return self.fc(x)

状态获取优化

对应新的状态表示方式,我们需要修改状态获取函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def get_state(self, snake, food, done=False):
if done:
feature_matrix = np.zeros((1, 1, 12, 24), dtype=np.float32)
input_tensor = np.zeros((1, 6), dtype=np.float32)
return input_tensor, feature_matrix

# 构建特征矩阵
snake_head = snake.body[0]
snake_body = snake.body[1:]
feature_matrix = np.zeros((12, 24), dtype=np.float32)

# 设置食物、蛇头和蛇身的位置
feature_matrix[food.rect.top // UNIT, food.rect.left // UNIT] = 1
feature_matrix[snake_head.top // UNIT, snake_head.left // UNIT] = 0.5
for body in snake_body:
feature_matrix[body.top // UNIT, body.left // UNIT] = -0.5

feature_matrix = feature_matrix.reshape(1, 1, 12, 24)

# 构建人工特征向量
# 食物相对位置
vertical_position = np.sign(food.rect.top - snake.body[0].top) + 1
horizontal_position = np.sign(food.rect.left - snake.body[0].left) + 1

# 障碍物检测
state_surround = [0] * 4
for i, direction in enumerate(DIRECTIONS):
left, top = snake.body[0].topleft + np.array(DIRECTIONS[direction]) * UNIT
# 检测边界
if left < 0 or left > SCREEN_X or top < 0 or top > SCREEN_Y:
exist = True
# 检测身体
elif (left, top) in [body.topleft for body in snake.body[1:]]:
exist = True
else:
exist = False
state_surround[i] = exist

input_tensor = np.array([vertical_position, horizontal_position, *state_surround])
input_tensor = input_tensor.reshape(1, 6)

return input_tensor, feature_matrix

训练效果

优化后的模型训练效果如图2所示:

优化后的训练曲线

图3 优化后的训练曲线

从训练曲线可以看出:

  1. 学习速度显著提升:在前500回合内模型就表现出明显的学习效果
  2. 最终性能提高:平均得分从原来的2-8分提升到15-17分左右

这种性能的提升主要得益于:

  • 混合状态表示提供了更丰富的环境信息
  • 双输入网络结构能更好地处理不同形式的特征
  • 人工特征为智能体提供了更直接的决策依据

小结

本文通过优化状态表示和网络结构,显著提升了DQN智能体在贪吃蛇游戏中的表现。混合状态建模和双输入网络的设计思路也为类似的强化学习任务提供了有益的参考。这些优化不仅提高了智能体的学习效率和最终性能,也增强了模型的稳定性,展现了深度强化学习在复杂环境中的应用潜力。


基于强化学习的贪吃蛇游戏(五)——智能体结构优化
http://dufolk.github.io/2024/12/22/snake-4/
作者
Dufolk
发布于
2024年12月22日
许可协议