|
@@ -5,14 +5,14 @@ import itertools
|
|
import gym
|
|
import gym
|
|
import numpy
|
|
import numpy
|
|
|
|
|
|
-_OBSERVATION_STEPS_NUM = 64
|
|
|
|
|
|
+_OBSERVATION_STEPS_NUM = 24
|
|
_LEARNING_RATE = 0.1
|
|
_LEARNING_RATE = 0.1
|
|
-_DISCOUNT = 0.9
|
|
|
|
|
|
+_DISCOUNT = 0.95
|
|
|
|
|
|
|
|
|
|
def _main():
|
|
def _main():
|
|
env = gym.make("CartPole-v0")
|
|
env = gym.make("CartPole-v0")
|
|
- observation_min = numpy.array([-4.8, -4.0, -0.8, -3.6])
|
|
|
|
|
|
+ observation_min = numpy.array([-4.8, -4.8, -0.9, -4.0])
|
|
assert (
|
|
assert (
|
|
env.observation_space.shape == observation_min.shape
|
|
env.observation_space.shape == observation_min.shape
|
|
), env.observation_space.shape
|
|
), env.observation_space.shape
|
|
@@ -20,8 +20,8 @@ def _main():
|
|
observation_step = (observation_max - observation_min) / _OBSERVATION_STEPS_NUM
|
|
observation_step = (observation_max - observation_min) / _OBSERVATION_STEPS_NUM
|
|
assert len(env.observation_space.shape) == 1
|
|
assert len(env.observation_space.shape) == 1
|
|
q_table = numpy.random.uniform(
|
|
q_table = numpy.random.uniform(
|
|
- low=0,
|
|
|
|
- high=4,
|
|
|
|
|
|
+ low=-4,
|
|
|
|
+ high=0,
|
|
size=(
|
|
size=(
|
|
[_OBSERVATION_STEPS_NUM] * (env.observation_space.shape[0])
|
|
[_OBSERVATION_STEPS_NUM] * (env.observation_space.shape[0])
|
|
+ [env.action_space.n]
|
|
+ [env.action_space.n]
|
|
@@ -30,7 +30,7 @@ def _main():
|
|
print("q_table:", q_table.itemsize * q_table.size / 2 ** 20, "MiB")
|
|
print("q_table:", q_table.itemsize * q_table.size / 2 ** 20, "MiB")
|
|
for episode_index in itertools.count():
|
|
for episode_index in itertools.count():
|
|
observation = env.reset()
|
|
observation = env.reset()
|
|
- render = (episode_index % 500) == 0
|
|
|
|
|
|
+ render = (episode_index % 400) == 0
|
|
if render:
|
|
if render:
|
|
env.render()
|
|
env.render()
|
|
observation_index = ((observation - observation_min) / observation_step).astype(
|
|
observation_index = ((observation - observation_min) / observation_step).astype(
|
|
@@ -48,7 +48,9 @@ def _main():
|
|
(next_observation - observation_min) / observation_step
|
|
(next_observation - observation_min) / observation_step
|
|
).astype(int)
|
|
).astype(int)
|
|
if done:
|
|
if done:
|
|
- q_table[tuple(observation_index)][action] = reward
|
|
|
|
|
|
+ q_table[tuple(observation_index)][action] = (
|
|
|
|
+ -300 if step_index < 190 else reward
|
|
|
|
+ )
|
|
print(step_index + 1, "steps")
|
|
print(step_index + 1, "steps")
|
|
break
|
|
break
|
|
assert not info, info
|
|
assert not info, info
|