
first (failing) q-learning attempt

Fabian Peter Hammerle 2 years ago
parent
commit
41f9203d7c
1 changed file with 52 additions and 12 deletions
test.py  +52 −12

@@ -1,24 +1,64 @@
 #!/usr/bin/env python3
 
-import time
+import itertools
 
 import gym
+import numpy
+
+_OBSERVATION_STEPS_NUM = 64
+_LEARNING_RATE = 0.1
+_DISCOUNT = 0.9
 
 
 def _main():
     env = gym.make("CartPole-v0")
-    env.reset()
-    env.render()
-    while True:
-        action = env.action_space.sample()
-        observation, reward, done, info = env.step(action)
-        env.render()
-        assert not info
-        print(action, observation, reward)
-        if done:
-            break
-    time.sleep(21)
+    observation_min = numpy.array([-4.8, -4.0, -0.8, -3.6])
+    assert (
+        env.observation_space.shape == observation_min.shape
+    ), env.observation_space.shape
+    observation_max = observation_min * -1
+    observation_step = (observation_max - observation_min) / _OBSERVATION_STEPS_NUM
+    assert len(env.observation_space.shape) == 1
+    q_table = numpy.random.uniform(
+        low=0,
+        high=4,
+        size=(
+            [_OBSERVATION_STEPS_NUM] * (env.observation_space.shape[0])
+            + [env.action_space.n]
+        ),
+    )
+    print("q_table:", q_table.itemsize * q_table.size / 2 ** 20, "MiB")
+    for episode_index in itertools.count():
+        observation = env.reset()
+        render = (episode_index % 500) == 0
+        if render:
+            env.render()
+        observation_index = ((observation - observation_min) / observation_step).astype(
+            int
+        )
+        for step_index in itertools.count():
+            # action = env.action_space.sample()
+            action = q_table[tuple(observation_index)].argmax()
+            next_observation, reward, done, info = env.step(action)
+            if render:
+                env.render()
+            assert (next_observation >= observation_min).all(), next_observation
+            assert (next_observation <= observation_max).all(), next_observation
+            next_observation_index = (
+                (next_observation - observation_min) / observation_step
+            ).astype(int)
+            if done:
+                q_table[tuple(observation_index)][action] = reward
+                print(step_index + 1, "steps")
+                break
+            assert not info, info
+            q_table[tuple(observation_index)][action] += _LEARNING_RATE * (
+                reward
+                + _DISCOUNT * q_table[tuple(next_observation_index)].max()
+                - q_table[tuple(observation_index)][action]
+            )
+            observation_index = next_observation_index
 
 
 if __name__ == "__main__":
     _main()
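
The core of the added code is the standard tabular Q-learning update, Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)), applied to bucketed observations. The following standalone sketch isolates those two pieces; the helper names (discretize, td_update) are hypothetical, while the hyperparameter values match the commit:

import numpy

_LEARNING_RATE = 0.1  # alpha: step size of the temporal-difference update
_DISCOUNT = 0.9  # gamma: weight given to estimated future rewards


def discretize(observation, observation_min, observation_step):
    # Map a continuous observation vector onto integer bucket indices,
    # mirroring the ((observation - observation_min) / observation_step)
    # .astype(int) conversion in the diff above.
    return tuple(((observation - observation_min) / observation_step).astype(int))


def td_update(q_table, state, action, reward, next_state):
    # One temporal-difference step:
    #   Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
    q_table[state][action] += _LEARNING_RATE * (
        reward
        + _DISCOUNT * q_table[next_state].max()
        - q_table[state][action]
    )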
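
One plausible contributor to the failure (an assumption, not a diagnosis confirmed by the commit): the action is always chosen greedily via q_table[...].argmax(), so the agent never explores beyond its random initial estimates; the commented-out env.action_space.sample() line suggests random actions were dropped entirely. A common remedy is an epsilon-greedy policy, sketched here with an assumed fixed exploration rate:

import random

_EPSILON = 0.1  # assumed exploration rate, not taken from the commit


def epsilon_greedy_action(q_table, state, action_space):
    # With probability epsilon take a uniformly random action (explore),
    # otherwise take the greedy action from the Q-table (exploit).
    if random.random() < _EPSILON:
        return action_space.sample()
    return int(q_table[state].argmax())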