test.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. #!/usr/bin/env python3
  2. import itertools
  3. import gym
  4. import numpy
  5. _OBSERVATION_STEPS_NUM = 24
  6. _LEARNING_RATE = 0.1
  7. _DISCOUNT = 0.95
  8. def _main():
  9. env = gym.make("CartPole-v0")
  10. observation_min = numpy.array([-4.8, -4.8, -0.9, -4.0])
  11. assert (
  12. env.observation_space.shape == observation_min.shape
  13. ), env.observation_space.shape
  14. observation_max = observation_min * -1
  15. observation_step = (observation_max - observation_min) / _OBSERVATION_STEPS_NUM
  16. assert len(env.observation_space.shape) == 1
  17. q_table = numpy.random.uniform(
  18. low=-4,
  19. high=0,
  20. size=(
  21. [_OBSERVATION_STEPS_NUM] * (env.observation_space.shape[0])
  22. + [env.action_space.n]
  23. ),
  24. )
  25. print("q_table:", q_table.itemsize * q_table.size / 2 ** 20, "MiB")
  26. for episode_index in itertools.count():
  27. observation = env.reset()
  28. render = (episode_index % 400) == 0
  29. if render:
  30. env.render()
  31. observation_index = ((observation - observation_min) / observation_step).astype(
  32. int
  33. )
  34. for step_index in itertools.count():
  35. # action = env.action_space.sample()
  36. action = q_table[tuple(observation_index)].argmax()
  37. next_observation, reward, done, info = env.step(action)
  38. if render:
  39. env.render()
  40. assert (next_observation >= observation_min).all(), next_observation
  41. assert (next_observation <= observation_max).all(), next_observation
  42. next_observation_index = (
  43. (next_observation - observation_min) / observation_step
  44. ).astype(int)
  45. if done:
  46. q_table[tuple(observation_index)][action] = (
  47. -300 if step_index < 190 else reward
  48. )
  49. print(step_index + 1, "steps")
  50. break
  51. assert not info, info
  52. q_table[tuple(observation_index)][action] += _LEARNING_RATE * (
  53. reward
  54. + _DISCOUNT * q_table[tuple(next_observation_index)].max()
  55. - q_table[tuple(observation_index)][action]
  56. )
  57. observation_index = next_observation_index
  58. if __name__ == "__main__":
  59. _main()