BLOG
인공지능

강화학습: 파이썬으로 구현하는 SARSA, 살사 알고리즘


May 4, 2022, 11:35 p.m.



1. SARSA의 정의


살사는 시간차 제어를 사용하는 알고리즘입니다. 가치함수를 시간차 예측으로 업데이트 하면서 정책은 따로 존재하지 않고 현재 상태에서 가장 큰 가치를 가지는 행동을 하는 탐욕적 정책을 사용합니다.

정책 를 수식으로 나타내 볼까요?

즉 상태 s에서 가장 큰 큐함수를 가지는 행동 a로 행동하는 정책이며 argmax를 사용해서 같은 큐함수 값이 있을 경우 그 행동들을 모두 반환합니다. 그 중에서 확률적으로 행동하게 되겠죠?

그렇다면 우리는 각 상태에 대한 큐함수 테이블을 가지고 있어야 합니다. 그리고 그 큐함수를 업데이트 해나가야 하죠. 이전에 시간차 제어에서 알아본 시간차 제어 업데이트 공식을 큐함수를 통해서 표현해보면 아래와 같습니다.

식을 보면 알 수 있듯이, 현재 행동에 대한 큐함수를 업데이트 하려면 현재 큐함수 , 행동을 통해 얻는 보상 , 그리고 다음상태 큐함수 를 알아야합니다. 큐함수는 상태 s와 a를 알아야 알 수 있으므로 큐함수당 s, a 1개씩 총 5가지 정보를 알아야 하겠네요!

순서대로 읽어보면 S A R S A, 살사네요! 그래서 이 알고리즘의 이름이 살사입니다.

즉 이 5가지의 정보가 한 개의 샘플이 되어 큐함수를 업데이트 합니다. 에이전트는 상태 일때 앞서 말한 탐욕적 정책으로 행동 를 하고, 그에 따른 결과로 보상 를 얻습니다. 그리고 다음상태 로 이동하죠. 여기서 탐욕적 정책으로 다시 행동 를 합니다. 5가지 정보 샘플이 만들어 졌네요. 이 시점에서 큐함수 를 업데이트 합니다.

자, 여기서 중요한 점은 에이전트가 직접 행동하고 경험해본 큐함수만 업데이트 할 수 있다는 것입니다. 이대로 두면 잘못된 길로 빠져 돌아올 수 없게 됩니다. 이를 위해 에이전트는 '탐험'이라는 것을 해야합니다.

처음에는 탐욕적 정책보다는 무작위로 행동하는 탐험을 실행해서 다양한 상태를 경험해 보아야 합니다. 그래야 그 중에서 좋은 행동을 찾아낼 수 있는 것입니다. 의 확률 만큼 탐험을 한다고 해서 - 탐욕 정책이라고 부릅니다.

처음에는 값을 크게 해서 다양한 경험을 하게 합니다. 그러나 나중에 큐함수가 최적화가 되어도 계속 탐험을 하면 안되기 때문에 값을 줄여나가 나중에는 정책대로 행동하게 합니다.

정리해보자면 에이전트는 - 탐욕 정책으로 다양한 SARSA 샘플을 얻어냅니다. 그리고 이 샘플을 얻어낼때마다 큐함수를 업데이트 합니다.

그럼 이제 살사를 파이썬으로 구현해 볼까요?

2. 그리드 월드


파이썬으로 구현하기에 앞서 살사를 적용할 간단한 환경이 필요합니다. 바로 그리드월드 입니다. 많은 강화학습 자료들에서 보면 그리드 월드에 적용해보는 예제들이 있습니다. 이 포스트에서는 제가 직접 파이썬으로 작성한 그리드월드를 사용해 보도록 하겠습니다.

그리드월드 환경 코드는 강화학습: 파이썬으로 구현한 프롬프트 그리드월드 포스트를 참고해주세요.

그리드월드는 격자로 구성된 세상입니다. 에이전트는 격자 위를 움직이면서 목표를 향해 가야 합니다. 장애물도 있어서 피해 가야하죠. 장애물도 움직일 수 있습니다.

첫번째로 아주 간단한 환경을 구성해보겠습니다. 5X5 그리드 월드 환경에 (1, 1) , (0, 3), (3, 0), (4, 2), (2, 4)에 장애물이 정지해 있는 환경입니다. 도달해야하는 목표는 (4, 4)에 위치해 있습니다. 제가 만든 그리드월드를 사용해서 구성하면 아래와 같습니다. goal_included_in_state=False와 include_state=False를 통해 상태 정보에는 오로지 에이전트의 위치만 담기도록 합니다.

또한 현재 보상은 장애물에 닿았을 때 -1, 벽으로 가는 행동을 했을 때 -1, 목표에 도달했을 때 1, 그리고 매 스텝마다 -0.1입니다.

env = GridWorld(5, 5, state_method="absolute", goal_included_in_state=False)
env.add_obstacles(1, 1, 0, include_state=False)
env.add_obstacles(0, 3, 0, include_state=False)
env.add_obstacles(3, 0, 0, include_state=False)
env.add_obstacles(4, 2, 0, include_state=False)
env.add_obstacles(2, 4, 0, include_state=False)


env.show([0, 0], 0, 0)
state:[0, 0], reward:0, action:O
------------
|A # # O # |
|# O # # # |
|# # # # O |
|O # # # # |
|# # O # G |
------------

그리드 월드가 잘 생성되었네요.

3. SARSA 에이전트 만들기


그럼 이제 그리드월드를 누빌 SARSA 에이전트를 만들어 보겠습니다.

전체 코드는 아래와 같습니다.

import numpy as np
import random
from collections import defaultdict

class SARSA:
  def __init__(self, actions):
    self.ACTIONS = np.array(actions)
    self.step_size = 0.01
    self.discount_factor = 0.9
    self.epsilon = 0.1
    self.q_table = defaultdict(lambda:[0.0 for _ in range(len(actions))])

  def get_action(self, state):
    if np.random.rand() < self.epsilon:
        action = np.random.choice(self.ACTIONS)
    else:
        q_list = self.q_table[str(state)]
        action = self.ACTIONS[np.random.choice(np.argwhere(q_list == np.amax(q_list)).flatten().tolist())]
    return action

  def learn(self, state, action, reward, next_state, next_action):
    state, next_state = str(state), str(next_state)
    current_q = self.q_table[state][np.argwhere(self.ACTIONS==action)[0][0]]

    next_state_q = self.q_table[next_state][np.argwhere(self.ACTIONS==next_action)[0][0]]

    td = reward + self.discount_factor * next_state_q - current_q

    new_q = current_q + self.step_size * td

    self.q_table[state][np.argwhere(self.ACTIONS==action)[0][0]] = new_q

크게 3가지 부분으로 나눌 수 있습니다.

1) 첫번째로 초기화 부분입니다.

class SARSA:
  def __init__(self, actions):
    self.ACTIONS = np.array(actions)
    self.step_size = 0.01
    self.discount_factor = 0.9
    self.epsilon = 0.1
    self.q_table = defaultdict(lambda:[0.0 for _ in range(len(actions))])

클래스 선언시 매개변수로 클래스가 행동할 행동 리스트를 받아옵니다. 또한 스텝사이즈와 할인율, 탐험율을 정해줍니다. 마지막으로 큐함수를 담을 테이블을 만들어 주었는데요, defaultdict를 가져와 딕셔너리의 기본값을 정해줍니다. 기본적으로 0.0입니다. 즉 현재 모든 상태에서의 모든 행동의 큐함수는 0.0입니다.

2) 두번째로 행동을 가져오는 get_action(state)입니다.

  def get_action(self, state):
    if np.random.rand() < self.epsilon:
        action = np.random.choice(self.ACTIONS)
    else:
        q_list = self.q_table[str(state)]
        action = self.ACTIONS[np.random.choice(np.argwhere(q_list == np.amax(q_list)).flatten().tolist())]
    return action

- 탐욕 정책에 따라 0에서 1사이의 실수 난수를 뽑아서 그 값이 보다 작으면 행동중에서 무작위로 하나를 뽑아 행동하고 아닐 경우 큐함수 테이블에서 해당 상태의 큐함수 리스트 중에서 가장 큰 값에 따라 행동합니다. 값이 같은 큐함수가 있을 경우 그들 중에 무작위로 행동합니다. 여기서 딕셔너리의 키값으로 state를 문자열로 변환하여 넣었습니다.

3) 세번째로 학습을 하는 learn(state, action, reward, next_state, next_action) 입니다.

 def learn(self, state, action, reward, next_state, next_action):
    state, next_state = str(state), str(next_state)
    current_q = self.q_table[state][np.argwhere(self.ACTIONS==action)[0][0]]

    next_state_q = self.q_table[next_state][np.argwhere(self.ACTIONS==next_action)[0][0]]

    td = reward + self.discount_factor * next_state_q - current_q

    new_q = current_q + self.step_size * td

    self.q_table[state][np.argwhere(self.ACTIONS==action)[0][0]] = new_q

S A R S' A' 샘플을 얻어서 큐함수를 업데이트하는 과정입니다.

앞서 알아본 업데이트 공식에 맞추어서 현재의 행동 큐함수와 다음 행동 큐함수를 구하고 이를 바탕으로 현재의 행동 큐함수의 업데이트를 진행합니다.

간단하게 구현이 완료되었습니다.

4. SARSA 적용해보기


그럼 이제 만들어진 에이전트로 그리드 월드 환경에서 학습을 진행해 보도록 하겠습니다. 학습 코드 전체입니다.

agent = SARSA([1, 2, 3, 4])

episodes = []
scores = []

for E in range(1000):
  state = env.reset()
  done = False

  score = 0

  while not done:
    action = agent.get_action(state)
    next_state, reward, done = env.step(action, show=False)
    next_action = agent.get_action(next_state)

    agent.learn(state, action, reward, next_state, next_action)

    score += reward

    state = next_state

  print(f"episode {E} - score {score}")

  scores.append(score)
  episodes.append(E)

먼저 에이전트를 생성합니다. 상, 하, 좌, 우 행동만 할 것이기 때문에 [1, 2, 3, 4] 리스트를 넘겨 주었습니다. 그다음 에피소드당 점수를 저장해서 나중에 그래프로 볼 수 있게 빈 리스트를 만들었습니다.

총 1000번의 에피소드를 진행하도록 for 문을 구성하였습니다. for 문이 시작되면 환경을 리셋하고 상태를 받아오며 완료 여부를 저장할 변수 done과 점수를 저장할 변수 score를 선언합니다.

done이 True가 될 때 까지, 즉 에피소드가 종료될때 까지 while문을 반복합니다.

while 문 안에서 에이전트는 현재 상태를 기반으로 행동을 가져오고 이를 토대로 .step()을 통해 한 스텝 진행합니다. 이때 학습 동안에는 화면을 출력하지 않기 위해 show=False로 지정합니다. 리턴값으로 다음상태, 보상, 완료 여부를 받습니다.

에이전트는 다음 상태 정보로 다음 행동을 가져오고 S A R S' A' 샘플이 모두 모였으므로 .learn()을 통해 학습을 진행합니다. 그리고 score에 받은 보상을 누적합니다.

마지막으로 다음 상태를 현재상태로 저장합니다.

에피소드가 끝날때 까지 while문을 반복하다가 끝나면 이전에 생성했던 scores리스트에 점수를 저장합니다.

학습을 진행하면 아래와 같이 출력될 것입니다.

episode 0 - score -30.800000000000047
episode 1 - score -15.299999999999988
episode 2 - score -11.399999999999984
episode 3 - score -8.499999999999993
episode 4 - score -3.9999999999999987
episode 5 - score -5.899999999999995
episode 6 - score -13.099999999999982
episode 7 - score -18.500000000000032
episode 8 - score -6.499999999999995
episode 9 - score 0.20000000000000007
episode 10 - score -15.999999999999995
episode 11 - score -4.299999999999999
episode 12 - score -9.299999999999988
episode 13 - score -1.8000000000000012
episode 14 - score -7.999999999999995
episode 15 - score -17.099999999999994
episode 16 - score -8.599999999999989
episode 17 - score 0.20000000000000007
episode 18 - score -13.799999999999985
episode 19 - score -2.000000000000001
episode 20 - score -4.099999999999998
episode 21 - score -5.099999999999994
episode 22 - score -9.699999999999992

에피소드가 진행될 수록 보상이 조금씩 줄어드는게 보이시나요? 그래프를 그려 보겠습니다.

from matplotlib import pyplot as plt

plt.plot(episodes, scores)

에피소드가 진행될 수록 점수가 높아지며 수렴하는 모습을 볼 수 있습니다.

마지막으로 1000번의 에피소드동안 학습된 에이전트가 그리드월드를 움직이는 모습을 보겠습니다.

state = env.reset()
done = False

score = 0

while not done:
  action = agent.get_action(state)
  next_state, reward, done = env.step(action, show=True)
  next_action = agent.get_action(next_state)
  score += reward

  state = next_state

실행하면 아래 화면과 같이 에이전트가 움직이는 모습을 볼 수 있습니다. 아마 장애물을 잘 피해서 목표까지 도달할 것입니다.

state:[1, 0], reward:-0.1, action:>
------------
|# A # O # |
|# O # # # |
|# # # # O |
|O # # # # |
|# # O # G |
------------

state:[1, 0], reward:-1.1, action:^
------------
|# A # O # |
|# O # # # |
|# # # # O |
|O # # # # |
|# # O # G |
------------

state:[2, 0], reward:-0.1, action:>
------------
|# # A O # |
|# O # # # |
|# # # # O |
|O # # # # |
|# # O # G |
------------

state:[2, 1], reward:-0.1, action:v
------------
|# # # O # |
|# O A # # |
|# # # # O |
|O # # # # |
|# # O # G |
------------

state:[2, 2], reward:-0.1, action:v
------------
|# # # O # |
|# O # # # |
|# # A # O |
|O # # # # |
|# # O # G |
------------

state:[2, 3], reward:-0.1, action:v
------------
|# # # O # |
|# O # # # |
|# # # # O |
|O # A # # |
|# # O # G |
------------

state:[3, 3], reward:-0.1, action:>
------------
|# # # O # |
|# O # # # |
|# # # # O |
|O # # A # |
|# # O # G |
------------

state:[4, 3], reward:-0.1, action:>
------------
|# # # O # |
|# O # # # |
|# # # # O |
|O # # # A |
|# # O # G |
------------

state:[4, 4], reward:0.9, action:v
------------
|# # # O # |
|# O # # # |
|# # # # O |
|O # # # # |
|# # O # A |
------------

벽으로 가는 행동을 한번 하긴 했지만 잘 작동하네요!

이상으로 살사 알고리즘을 알아보고 파이썬으로 구현해보았습니다.

살사 코드 원본은 저의 깃허브 Python_RL_Agents SARSA.py 에 있습니다.

강화학습 큐함수 시간차예측 살사 SARSA



Search