123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- #!/usr/bin/env python3
- import itertools
- import gym
- import numpy
# Number of bins every observation dimension is split into when the
# continuous observation is mapped to a Q-table index.
_OBSERVATION_STEPS_NUM = 24
# Fraction of the temporal-difference error applied to a Q-value per update.
_LEARNING_RATE = 0.1
# Weight given to the best next-state Q-value in the update target.
_DISCOUNT = 0.95
def _discretize(observation, observation_min, observation_step, steps_num):
    """Map a continuous observation to a tuple of per-dimension bin indices.

    Args:
        observation: array of raw observation values.
        observation_min: array holding the lower bound of every dimension.
        observation_step: array holding the bin width of every dimension.
        steps_num: number of bins per dimension.

    Returns:
        A tuple of integers, each clamped to ``[0, steps_num - 1]``, that can
        be used directly to index the leading axes of the Q-table.
    """
    index = ((observation - observation_min) / observation_step).astype(int)
    # A value sitting exactly on the upper bound would otherwise land in bin
    # `steps_num`, one past the end of the table, and a value below the lower
    # bound would yield a negative index that silently wraps around.
    return tuple(numpy.clip(index, 0, steps_num - 1))


def _main():
    """Train a tabular Q-learning agent on CartPole-v0 until interrupted.

    The four-dimensional observation is discretised into
    ``_OBSERVATION_STEPS_NUM`` bins per dimension and used to index a Q-table
    with one trailing axis for the actions. Episodes run in an endless loop;
    every 400th episode is rendered and the length of each episode is printed.

    The code expects the gym API in which ``env.reset()`` returns only the
    observation and ``env.step()`` returns a 4-tuple.
    """
    env = gym.make("CartPole-v0")
    # Hand-picked symmetric bounds used for binning the observation.
    observation_min = numpy.array([-4.8, -4.8, -0.9, -4.0])
    assert (
        env.observation_space.shape == observation_min.shape
    ), env.observation_space.shape
    observation_max = -observation_min
    observation_step = (observation_max - observation_min) / _OBSERVATION_STEPS_NUM
    assert len(env.observation_space.shape) == 1
    # One axis per observation dimension plus a final axis over the actions,
    # initialised with random values in [-4, 0).
    q_table = numpy.random.uniform(
        low=-4,
        high=0,
        size=(
            [_OBSERVATION_STEPS_NUM] * (env.observation_space.shape[0])
            + [env.action_space.n]
        ),
    )
    print("q_table:", q_table.itemsize * q_table.size / 2 ** 20, "MiB")
    for episode_index in itertools.count():
        observation = env.reset()
        render = (episode_index % 400) == 0
        if render:
            env.render()
        state = _discretize(
            observation, observation_min, observation_step, _OBSERVATION_STEPS_NUM
        )
        for step_index in itertools.count():
            # Purely greedy policy: always take the action with the highest
            # current estimate; no random exploration is performed.
            action = q_table[state].argmax()
            next_observation, reward, done, info = env.step(action)
            if render:
                env.render()
            # Sanity checks that the hand-picked bounds really cover what the
            # environment produces (note: stripped when run with `-O`).
            assert (next_observation >= observation_min).all(), next_observation
            assert (next_observation <= observation_max).all(), next_observation
            next_state = _discretize(
                next_observation,
                observation_min,
                observation_step,
                _OBSERVATION_STEPS_NUM,
            )
            if done:
                # Terminal step: overwrite the estimate instead of blending.
                # Ending before step index 190 is penalised heavily
                # (presumably a failure rather than hitting the environment's
                # step limit -- TODO confirm against the env's max steps).
                q_table[state + (action,)] = -300 if step_index < 190 else reward
                print(step_index + 1, "steps")
                break
            assert not info, info
            # Q(s, a) += lr * (r + discount * max_a' Q(s', a') - Q(s, a))
            q_table[state + (action,)] += _LEARNING_RATE * (
                reward
                + _DISCOUNT * q_table[next_state].max()
                - q_table[state + (action,)]
            )
            state = next_state
# Run the training loop only when executed as a script, not on import.
if __name__ == "__main__":
    _main()
|