結合強化學習DQN與遊戲開發:使用Python實現貪食蛇並訓練AI學習遊戲策略

Jiiiiing
Aiii
Published in
17 min readJun 15, 2023

嘶嘶嘶嘶嘶snake~𓆙𓆙𓆙

貪食蛇這款考驗反應力的經典街機遊戲想必大家都不陌生,如今它不只是個遊戲,更是一個應用強化學習技術的大平台!本文將使用Python設計和實現貪食蛇遊戲開始,再介紹Deep Q-Learning(DQN)算法的基本原理和實現過程,讓 AI 學習貪食蛇的遊戲規則並獲取高分!準備好開始設計和訓練你自己的AI貪食蛇了嗎?讓我們一起展開這個有趣的旅程吧!

𓆙 建立貪食蛇遊戲

𓆗 遊戲規則

先說明遊戲規則,等等依序介紹程式碼實現過程!

  1. 新遊戲開始,蛇頭每次從遊戲畫布中心出現。
  2. 蘋果隨機放置並確保不放在蛇身體上。
  3. 每次移動 1 格,只能轉「-90、0°、90° 」,也就是「左轉、直走和右轉」3種方向。
  4. 吃到蘋果得一分 (^o^)/🍎
  5. 當蛇碰到自己的身體或牆壁遊戲結束 (。•́︿•̀。)

確認好遊戲規定後,就可以一一實現了!我們使用 pygame 去建立遊戲視窗、遊戲畫布的呈現,以及調整遊戲畫面更新速率(讓蛇移動更快的感覺!),下方程式碼著重在遊戲規則!所以對於 pygame 的設置就不會多做介紹~

𓆗 程式碼 — 蛇頭和蛇身

  1. 使用 namedtuple 定義蛇頭的座標變數(x, y),並設定遊戲一開始蛇頭會出現在畫布正中心!
  2. 同時預設蛇的身體長度為 2(加頭長度為 3),會往畫布右邊前進,因此在蛇頭的左邊兩格都屬於蛇的身體。下方為將蛇頭和蛇身的位置儲存在 snake 變數裡的程式碼:
from collections import namedtuple
Point = namedtuple('Point', 'x, y') #創建點的座標

# 蛇頭初始位置
self.head = Point(self.w/2, self.h/2)
self.snake = [self.head,
Point(self.head.x-BLOCK_SIZE, self.head.y),
Point(self.head.x-(2*BLOCK_SIZE), self.head.y)]

𓆗 程式碼 — 隨機放置蘋果🍎

  1. 透過前一步驟定義好蛇的身體後,就可以決定蘋果🍎的位置!
  2. 根據畫布大小決定蘋果隨機生成的位置,若蘋果位置出現在蛇的身體上就會再生一次位置。
def _place_food(self): # 隨機放置食物
x = random.randint(0, (self.w - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
y = random.randint(0, (self.h - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
self.food = Point(x, y)
if self.food in self.snake: #確保生成的食物不在蛇的身體上
self._place_food()

𓆗 程式碼— 蛇頭轉動的方向

  1. 蛇頭轉動的方向:先定義「左轉、直走和右轉」的向量:
    [0, 0, 1] → left turn
    [1, 0, 0] → straight
    [0, 1, 0] → right turn
  2. 先按照「順時針」定義出四個方向:右、下、左、上,設置四個方向的目的為:蛇頭可以根據目前行走方向、需轉動的方向來決定蛇頭需移動的 x 或 y 座標。
  3. 舉例:蛇頭目前往右移動(畫布的右),此時需蛇要往下移動,等於蛇頭是往右轉(順時針);或蛇頭目前朝下移動,此時蛇需要往右移動,等於蛇頭往左轉(逆時針)。

因此定義 idx 擷取現在移動的方向,並根據「直走、左轉、右轉」等決定方向是否不變、逆時針或順時針。而若順時針,則根據原始 clock wise 定義將 idx+1,且使用 4 的餘數是因為能確保讓 idx 在 0~3 之間移動。

clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP] # 定義順時針的四個方向序列,對應到遊戲中的四個方向。
idx = clock_wise.index(self.direction) # 現在目前的方向

#決定蛇頭的下一步移動方向,並移動蛇頭到新的位置
if np.array_equal(action, [1, 0, 0]): # straight,等於沒有轉動
new_dir = clock_wise[idx] # 所以 idx 還是原本的方向
elif np.array_equal(action, [0, 1, 0]): # 右轉
next_idx = (idx + 1) % 4 # 確保 idx 會在 0,1,2,3 之間跑動
new_dir = clock_wise[next_idx] # right turn(順時針) r -> d -> l -> u
else: # [0, 0, 1] 是左轉
next_idx = (idx - 1) % 4
new_dir = clock_wise[next_idx] # left turn(逆時針) r -> u -> l -> d

self.direction = new_dir

x = self.head.x
y = self.head.y
if self.direction == Direction.RIGHT:
x += BLOCK_SIZE
elif self.direction == Direction.LEFT:
x -= BLOCK_SIZE
elif self.direction == Direction.DOWN:
y += BLOCK_SIZE
elif self.direction == Direction.UP:
y -= BLOCK_SIZE

self.head = Point(x, y)

𓆗 程式碼 — 遊戲過程

  1. 當蛇開始移動,我們可以透過前述步驟得知目前蛇的位置以及蛇頭要轉的方向,推算出蛇頭應該移動的下一個位置。但是,我們同時須考慮到蛇會不會吃到蘋果,如果吃到蛇身會變長,反之則不變。
  2. 在記錄蛇的位置時,順序是由頭到尾,因此可以透過判定新的蛇頭是否與蘋果位置重疊來決定是否刪除蛇尾的位置,且同時判斷是否需要生成新的蘋果。
self._move(action)
self.snake.insert(0, self.head) #當蛇頭向某個方向移動時,將新的蛇頭位置插入到列表的開頭,表示蛇頭已經移動到新位置。

if self.head == self.food:
self._place_food()
else:
self.snake.pop() #模擬蛇在未吃到食物時的移動,使蛇能在不改變長度的情況下向前移動一格

𓆗 程式碼 — 碰撞!

  1. 貪食蛇遊戲最重要的是避免蛇頭碰撞到畫布邊界和自身的身體。所以我們可以藉由判定蛇頭是否超出邊界、或是否與自己身體位置重疊,寫出判斷蛇是否碰撞的函數。
def is_collision(self, pt = None):
if pt is None:
pt = self.head # 蛇頭
if pt.x > self.w - BLOCK_SIZE or pt.x < 0 or pt.y > self.h - BLOCK_SIZE or pt.y < 0: #確保蛇的整個頭部都沒有超出遊戲邊界
return True
if pt in self.snake[1:]: #蛇頭與身體發生碰撞
return True

return False

綜合以上遊戲規則以及程式碼,我們就可以透過 pygame 完成簡單的貪食蛇介面,接下來我們就要來開始使用 DQN 來讓我們的蛇變得越來越聰明!

𓆙 Q-Learning 概述

Q-Learning 的目標就是通過學習一個 Q-value function,讓 Agent 學習到在不同狀態下選擇最佳動作的策略,從而在強化學習任務中取得良好的性能。

𓆗 Q-value function

Q-value function是屬於一種動作價值函數(State-Action Value Function),用於衡量在給定狀態和動作下的期望累積回報。Q(s, a)的意涵為:計算在 state s 時,選擇 action a 後,再根據 policy 走的話能獲得的 total expected reward。因此透過計算 Q-value function,Agent 可以比較不同動作的 Q 值,選擇具有最大 Q 值的動作作為其策略的一部分。

而我們要如何估計 State-Action Value Function 呢?以下有兩種方法:

  1. Monte-Carlo(MC)based approach:在 Monte Carlo 方法中,Agent 會經歷完一整個回合或任務後,才根據收集到的所有的數據(包括 state、action和 reward)去更新 Q-value。
  2. Temporal-difference(TD)approach:TD 方法不需等到遊戲結束,只需當前的狀態、動作和所獲得的獎勵就可以去更新 Q-value。

下圖為採用 Bellman equation 和 TD 方法更新 Q-value 的過程:

𓆗 Epsilon Greedy

若我們每是只選擇 Q-value 最大的 action 動作這樣並不好,因為一定要進行過某個 state 和 action 才可以進行估計。所以為了讓 Agent 也可以進行其他動作,我們會設定一個介於 0~1 的值稱作 Epsilon(ε)。在訓練過程中,會有 ε 的機率採用隨機動作;(1-ε)的機率會選擇具有最大的 Q-value 動作進行。在實作中,Epsilon 會隨著時間的增長遞減。

𓆗 Q-Learning

傳統的 Q-Learning 是一種基於表格的方法,它建立一個表格(Q table)來儲存狀態-動作對的價值。這個表格根據 Agent 與環境的交互來學習最優策略。然後 Agent 根據一個策略(例如ε-greedy策略)來選擇動作,並在新的狀態中重複該過程。綜合以上 Q-Learning 算法的步驟如下:

  1. 初始化 Q-value function,可以使用隨機值或其他方式進行初始化。
  2. Agent 與環境互動,根據當前的狀態選擇一個動作並收集環境的反饋,包括新的狀態和即時獎勵。
  3. 根據當前的 state s_t、action a_t 和 觀測到的 reward r,以及下一個 state s_t+1,使用 TD 算法更新 Q-value function。
  4. 再使用學習到的 Q-value ,並根據當前狀態選擇最佳的策略。

𓆙 Deep Q-Learning

DQN 是將深度神經網絡應用於 Q-Learning 的延伸,在 DQN 中模型的輸入是狀態(state),而模型的輸出是每個可能動作(action)的 Q 值(Q-value)。

以下是此次建立貪食蛇使用的 DQN 背景介紹,此次省略 Target Network 的機制,Target Network 的目的是在提高 Q-Learning 算法的穩定性和收斂性。此次只透過損失函數和反向傳播去更新神經網路的參數仍可以得到良好的結果,以下是 DQN 和傳統 Q-Learning 的主要差異:

  1. DQN 初始化一個深度神經網絡(Deep Neural Network, DNN)近似 Q 值函數。DQN 的輸入是狀態 state 並輸出每個可能動作 action 的 Q 值。
  2. Replay Buffer 裏紀錄著每一筆資料的狀態state、動作action、和 reward,我們可以設定要保存的筆數,可以是不同的策略(policy)的結果。當Replay Buffer 儲存滿了就會將最舊的資料丟棄。訓練的時候就會從 Buffer 中取出資料來更新 Q-value function,此時就是屬於 off-policy 的做法。
  3. 損失函數和反向傳播:使用目標 Q 值和預測的 Q 值計算均方誤差(Mean Squared Error,MSE)當作損失函數。並通過反向傳播算法來更新神經網絡的參數,以最小化損失函數。
  4. 重複迭代:重複執行步驟2和步驟3,通過不斷從環境中獲得的信息來更新神經網絡的參數。

總結來說,Deep Q-Learning和Q-Learning的主要區別在於近似Q值函數的方式。Q-Learning使用一個表格來存儲Q值,而Deep Q-Learning使用一個深度神經網絡來學習Q值函數的近似。Deep Q-Learning通常適用於狀態空間較大且較複雜的問題,而Q-Learning則適用於狀態空間較小且較簡單的問題。

𓆙 貪食蛇 + Deep Q-Learning!

初步理解 DQN 之後,我們要怎麼應用在 DQN 上呢?所以我們有以下的目標需要一一處理:

  1. 定義輸入的 state、輸出的 action 格式
  2. 定義執行某個 action 所獲得的 reward
  3. 設定 Epsilon Greedy 的權重、決定執行的動作
  4. 設定 Replay Buffer(memory)

𓆗 定義輸入的 state、輸出的 action 格式

  1. 輸入的 states 包含11種狀態,分別為:危險方向(3種)、目前方向(4種)、以及蘋果的方向(4種)。值分別為 0、1的二進制變量。
    - 危險方向【 danger straight, danger right, danger left 】
    - 目前方向【direction left, direction right, direction up, direction down】,每次只有一項會是 1,其餘是 0。
    - 蘋果的方向 【food left, food right, food up, food down】
  2. 輸出的格式包含 3 種 action 的 values,分別為:左轉、直走和右轉。其中三種方向都以長度為 3 的向量呈現。
    - [1, 0, 0] → straight
    - [0, 1, 0] → right turn
    - [0, 0, 1] → left turn
    但神經網路的輸出不一定會是 0 和 1,所以會根據輸出向量裡最大的值為1,其餘皆為 0。例如輸出的向量為[0.4, 3, 6.2],根據最大元素的位置對應為[0, 0, 1],此時就會向左轉。
  3. 程式碼:先定義出蛇頭四個方向的位置(不管目前方向為何,以畫布的上下左右為準),並判斷目前的方向。再使用前述的碰撞函數判斷可能會遇到的碰撞,最後再判斷食物相對於蛇頭的位置並把所有 states 儲存成向量。
def get_state(self, game): 
head = game.snake[0]
#先定義好蛇頭下一步可能的位置(4個方向)
point_l = Point(head.x - 20, head.y)
point_r = Point(head.x + 20, head.y)
point_u = Point(head.x, head.y - 20)
point_d = Point(head.x, head.y + 20)

#判斷目前行進的方向,只會有一個變數的值會是1
dir_l = game.direction == Direction.LEFT
dir_r = game.direction == Direction.RIGHT
dir_u = game.direction == Direction.UP
dir_d = game.direction == Direction.DOWN

#建立狀態向量
state = [
## 直行會遇到的碰撞
(dir_r and game.is_collision(point_r)) or #蛇朝右方向前進,直走時蛇頭會和(畫布)右邊的相鄰位置碰撞
(dir_l and game.is_collision(point_l)) or
(dir_u and game.is_collision(point_u)) or
(dir_d and game.is_collision(point_d)),

## 右轉(順時針)會遇到的碰撞
(dir_u and game.is_collision(point_r)) or #蛇朝上方前進,向右轉時蛇頭會和(畫布)右邊的相鄰位置碰撞
(dir_d and game.is_collision(point_l)) or
(dir_l and game.is_collision(point_u)) or
(dir_r and game.is_collision(point_d)),

## 左轉(逆時針)會遇到的碰撞
(dir_d and game.is_collision(point_r)) or #蛇朝下方前進,向左轉時蛇頭會和(畫布)右邊的相鄰位置碰撞
(dir_u and game.is_collision(point_l)) or
(dir_r and game.is_collision(point_u)) or
(dir_l and game.is_collision(point_d)),

dir_l,
dir_r,
dir_u,
dir_d,

game.food.x < game.head.x, # 食物在左邊
game.food.x > game.head.x,
game.food.y < game.head.y, # 食物在上面
game.food.y > game.head.y
]

return np.array(state, dtype=int)

定義蛇頭可能出現的位置:

輸入的 state 情境範例:

𓆗 定義執行某個 action 所獲得的 reward

  1. 蛇吃到蘋果🍎 reward「加10分」。
  2. 蛇撞到牆壁、蛇身,reward「減10分」。
  3. 蛇行動超過100步但遊戲沒結束(不吃蘋果),reward「減10分」。會加上這個條件是因為我們希望蛇不只單純學習到避開相撞,還能學習到去吃蘋果。
  4. 其他動作 reward 不變。

𓆗 設定 Epsilon Greedy 的權重、決定執行的動作

我們採用 Epsilon Greedy 讓 Agent 隨機選擇執行動作。我們設定的方式是 Epsilon 減去玩遊戲的局數,初始設定為「80 減掉遊玩的局數」,因此玩越多局、epsilon 就越小。再設定一個隨機的值介於 0~200 之間,當此值小於 epsilon 就會採取隨機的動作。雖然 epsilon 越小,執行動作就會越不隨機,但透過設定一個介於 0~200 之間隨機的值仍可以讓 Agent 執行不死板。

def get_action(self, state):
# random moves: tradeoff exploration / exploitation
self.epsilon = 80 - self.n_games # n_games 是遊戲的局數,所以玩越多遊戲,epsilon就會越小
final_move = [0, 0, 0]
if random.randint(0, 200) < self.epsilon: # 如果epsilon越小、範圍越小,move就會越不隨機、死板
move = random.randint(0, 2)
final_move[move] = 1
else:
# 所以當遊戲局數越多,就會減少隨機步驟而依賴我們之前model裡面的權重
state0 = torch.tensor(state, dtype = torch.float)
prediction = self.model(state0)
move = torch.argmax(prediction).item()
final_move[move] = 1

return final_move

𓆗 設定 Replay Buffer(memory)

更新 Q-value function 的方式可以分為 short memory 和 long memory。short memory 儲存每次行動的state、action、reward等表現來更新下一步驟的 Q-value function。long memory 為從 Replay Buffer 裏面隨機抽出 batch size 大小的資料去更新 Q-value function,而 long memory 是當一局遊戲結束才會進行訓練。

𓆙 Training!看 DQN 的成效!

從下方的demo影片可見隨著訓練次數上升,遊戲的平均得分越來越高,代表上述簡略的DQN架構也可以獲得良好的結果!但我們也可以看到訓練越多次、平均得分的上升速度越慢,因此未來想要更進一步提高算法的穩定性和收斂性,可以考慮實現 Target Network 機制並將其應用到程式碼中,或者使用更複雜的神經網路去訓練 AI!

影片有縮時20倍!

--

--