
check if "solved" (took 3302 episodes in one attempt)

Fabian Peter Hammerle 2 years ago
commit a0eff373f9
1 changed file with 22 additions and 3 deletions

test.py (+22 -3)

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+import collections
 import itertools
 
 import gym
@@ -9,6 +10,11 @@ _OBSERVATION_STEPS_NUM = 24
 _LEARNING_RATE = 0.1
 _DISCOUNT = 0.95
 
+# > CartPole-v0 defines "solving" as getting average reward of 195.0 over 100 consecutive trials.
+# https://gym.openai.com/envs/CartPole-v0/
+_SUCCESS_AVERAGE_REWARD = 195.0
+_SUCCESS_AVERAGE_WINDOW_SIZE = 100
+
 
 def _main():
     env = gym.make("CartPole-v0")
@@ -28,16 +34,17 @@ def _main():
         ),
     )
     print("q_table:", q_table.itemsize * q_table.size / 2 ** 20, "MiB")
+    last_step_counts = collections.deque()
+    last_step_counts_sum = 0
     for episode_index in itertools.count():
         observation = env.reset()
-        render = (episode_index % 400) == 0
+        render = (episode_index % 1000) == 0
         if render:
             env.render()
         observation_index = ((observation - observation_min) / observation_step).astype(
             int
         )
         for step_index in itertools.count():
-            # action = env.action_space.sample()
             action = q_table[tuple(observation_index)].argmax()
             next_observation, reward, done, info = env.step(action)
             if render:
@@ -51,7 +58,19 @@ def _main():
                 q_table[tuple(observation_index)][action] = (
                     -300 if step_index < 190 else reward
                 )
-                print(step_index + 1, "steps")
+                last_step_counts.append(step_index + 1)
+                last_step_counts_sum += step_index + 1
+                if len(last_step_counts) > _SUCCESS_AVERAGE_WINDOW_SIZE:
+                    last_step_counts_sum -= last_step_counts.popleft()
+                average_reward = last_step_counts_sum / _SUCCESS_AVERAGE_WINDOW_SIZE
+                print(
+                    f"episode #{episode_index}"
+                    f"\t{step_index+1} steps"
+                    f"\taverage of {average_reward:.1f} steps"
+                    f" over last {_SUCCESS_AVERAGE_WINDOW_SIZE} episodes"
+                )
+                if average_reward > _SUCCESS_AVERAGE_REWARD:
+                    return
                 break
             assert not info, info
             q_table[tuple(observation_index)][action] += _LEARNING_RATE * (