@@ -1,23 +1,75 @@
 #!/usr/bin/env python3
-import time
+import itertools

 import gym
+import numpy
+
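+# Q-learning hyperparameters: number of grid cells per observation
+# dimension, learning rate, and reward discount factor.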
+_OBSERVATION_STEPS_NUM = 64
+_LEARNING_RATE = 0.1
+_DISCOUNT = 0.9


 def _main():
     env = gym.make("CartPole-v0")
-    env.reset()
-    env.render()
-    while True:
-        action = env.action_space.sample()
-        observation, reward, done, info = env.step(action)
-        env.render()
-        assert not info
-        print(action, observation, reward)
-        if done:
-            break
-        time.sleep(21)
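+    # Finite bounds for the observation components (cart position, cart
+    # velocity, pole angle, pole angular velocity); gym reports effectively
+    # unbounded velocities, so usable grid limits are hand-picked here.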
+    observation_min = numpy.array([-4.8, -4.0, -0.8, -3.6])
+    assert (
+        env.observation_space.shape == observation_min.shape
+    ), env.observation_space.shape
+    observation_max = observation_min * -1
+    observation_step = (observation_max - observation_min) / _OBSERVATION_STEPS_NUM
+    assert len(env.observation_space.shape) == 1
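+    # One axis of _OBSERVATION_STEPS_NUM cells per observation dimension,
+    # plus a final axis over the discrete actions.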
+    q_table = numpy.random.uniform(
+        low=0,
+        high=4,
+        size=(
+            [_OBSERVATION_STEPS_NUM] * (env.observation_space.shape[0])
+            + [env.action_space.n]
+        ),
+    )
+    print("q_table:", q_table.itemsize * q_table.size / 2 ** 20, "MiB")
+    for episode_index in itertools.count():
+        observation = env.reset()
+        render = (episode_index % 500) == 0
+        if render:
+            env.render()
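+        # Map the continuous observation onto integer grid cell indices.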
+        observation_index = ((observation - observation_min) / observation_step).astype(
+            int
+        )
+        for step_index in itertools.count():
+            # action = env.action_space.sample()
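+            # Greedy policy: pick the highest-valued action for this grid cell.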
+            action = q_table[tuple(observation_index)].argmax()
+            next_observation, reward, done, info = env.step(action)
+            if render:
+                env.render()
+            assert (next_observation >= observation_min).all(), next_observation
+            assert (next_observation <= observation_max).all(), next_observation
+            next_observation_index = (
+                (next_observation - observation_min) / observation_step
+            ).astype(int)
+            if done:
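+                # Terminal step: no successor state, so store the raw reward.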
+                q_table[tuple(observation_index)][action] = reward
+                print(step_index + 1, "steps")
+                break
+            assert not info, info
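+            # Q-learning update:
+            # Q(s, a) += lr * (reward + discount * max(Q(s', a')) - Q(s, a))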
+            q_table[tuple(observation_index)][action] += _LEARNING_RATE * (
+                reward
+                + _DISCOUNT * q_table[tuple(next_observation_index)].max()
+                - q_table[tuple(observation_index)][action]
+            )
+            observation_index = next_observation_index


 if __name__ == "__main__":