Fabian Peter Hammerle 2 anni fa
parent
commit
21c9f0b827
1 ha cambiato i file con 9 aggiunte e 7 eliminazioni
  1. 9 7
      test.py

+ 9 - 7
test.py

@@ -5,14 +5,14 @@ import itertools
 import gym
 import numpy
 
-_OBSERVATION_STEPS_NUM = 64
+_OBSERVATION_STEPS_NUM = 24
 _LEARNING_RATE = 0.1
-_DISCOUNT = 0.9
+_DISCOUNT = 0.95
 
 
 def _main():
     env = gym.make("CartPole-v0")
-    observation_min = numpy.array([-4.8, -4.0, -0.8, -3.6])
+    observation_min = numpy.array([-4.8, -4.8, -0.9, -4.0])
     assert (
         env.observation_space.shape == observation_min.shape
     ), env.observation_space.shape
@@ -20,8 +20,8 @@ def _main():
     observation_step = (observation_max - observation_min) / _OBSERVATION_STEPS_NUM
     assert len(env.observation_space.shape) == 1
     q_table = numpy.random.uniform(
-        low=0,
-        high=4,
+        low=-4,
+        high=0,
         size=(
             [_OBSERVATION_STEPS_NUM] * (env.observation_space.shape[0])
             + [env.action_space.n]
@@ -30,7 +30,7 @@ def _main():
     print("q_table:", q_table.itemsize * q_table.size / 2 ** 20, "MiB")
     for episode_index in itertools.count():
         observation = env.reset()
-        render = (episode_index % 500) == 0
+        render = (episode_index % 400) == 0
         if render:
             env.render()
         observation_index = ((observation - observation_min) / observation_step).astype(
@@ -48,7 +48,9 @@ def _main():
                 (next_observation - observation_min) / observation_step
             ).astype(int)
             if done:
-                q_table[tuple(observation_index)][action] = reward
+                q_table[tuple(observation_index)][action] = (
+                    -300 if step_index < 190 else reward
+                )
                 print(step_index + 1, "steps")
                 break
             assert not info, info