test.py

#!/usr/bin/env python3
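"""Tabular Q-learning agent for OpenAI Gym's CartPole-v0.

The continuous observation is discretized into a fixed grid of bins, and a
Q-table over (binned observation, action) is updated with the one-step
Q-learning rule until the rolling 100-episode average reaches the "solved"
threshold. The script uses the classic Gym API (env.reset() returning only the
observation, env.step() returning a 4-tuple), so it assumes a pre-0.26 gym.
"""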
import collections
import itertools

import gym
import numpy

# Number of discretization bins per observation dimension.
_OBSERVATION_STEPS_NUM = 24
_LEARNING_RATE = 0.1
_DISCOUNT = 0.95
# > CartPole-v0 defines "solving" as getting average reward of 195.0 over
# > 100 consecutive trials.
# https://gym.openai.com/envs/CartPole-v0/
_SUCCESS_AVERAGE_REWARD = 195.0
_SUCCESS_AVERAGE_WINDOW_SIZE = 100


def _main():
    env = gym.make("CartPole-v0")
    # Hand-picked bounds for (cart position, cart velocity, pole angle, pole
    # angular velocity); the velocity bounds are tighter than the env's
    # unbounded observation space, so the range asserts below can fail if an
    # episode ever exceeds them.
    observation_min = numpy.array([-4.8, -4.8, -0.9, -4.0])
    assert (
        env.observation_space.shape == observation_min.shape
    ), env.observation_space.shape
    observation_max = observation_min * -1
    observation_step = (observation_max - observation_min) / _OBSERVATION_STEPS_NUM
    assert len(env.observation_space.shape) == 1
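    # Q-table with one axis per (binned) observation dimension plus a final
    # axis for the action; starting values are drawn uniformly from [-4, 0).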
    q_table = numpy.random.uniform(
        low=-4,
        high=0,
        size=(
            [_OBSERVATION_STEPS_NUM] * (env.observation_space.shape[0])
            + [env.action_space.n]
        ),
    )
    print("q_table:", q_table.itemsize * q_table.size / 2 ** 20, "MiB")
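    # Rolling window of the most recent episode lengths, used to compute the
    # average reward over the last _SUCCESS_AVERAGE_WINDOW_SIZE episodes.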
    last_step_counts = collections.deque()
    last_step_counts_sum = 0
    for episode_index in itertools.count():
        observation = env.reset()
        # Render only every 1000th episode.
        render = (episode_index % 1000) == 0
        if render:
            env.render()
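        # Map the continuous observation onto integer bin indices into q_table.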
        observation_index = ((observation - observation_min) / observation_step).astype(
            int
        )
        for step_index in itertools.count():
            # Greedy action from the current Q estimates (no epsilon-greedy
            # exploration).
            action = q_table[tuple(observation_index)].argmax()
            next_observation, reward, done, info = env.step(action)
            if render:
                env.render()
            assert (next_observation >= observation_min).all(), next_observation
            assert (next_observation <= observation_max).all(), next_observation
            next_observation_index = (
                (next_observation - observation_min) / observation_step
            ).astype(int)
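            # On termination, overwrite Q(s, a): penalize episodes that end
            # before ~190 steps, otherwise keep the final reward. Then update
            # the rolling average and report progress.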
            if done:
                q_table[tuple(observation_index)][action] = (
                    -300 if step_index < 190 else reward
                )
                last_step_counts.append(step_index + 1)
                last_step_counts_sum += step_index + 1
                if len(last_step_counts) > _SUCCESS_AVERAGE_WINDOW_SIZE:
                    last_step_counts_sum -= last_step_counts.popleft()
                average_reward = last_step_counts_sum / _SUCCESS_AVERAGE_WINDOW_SIZE
                print(
                    f"episode #{episode_index}"
                    f"\t{step_index+1} steps"
                    f"\taverage of {average_reward:.1f} steps"
                    f" over last {_SUCCESS_AVERAGE_WINDOW_SIZE} episodes"
                )
                if average_reward > _SUCCESS_AVERAGE_REWARD:
                    return
                break
            assert not info, info
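            # One-step Q-learning update: move Q(s, a) toward
            # reward + _DISCOUNT * max_a' Q(s', a').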
            q_table[tuple(observation_index)][action] += _LEARNING_RATE * (
                reward
                + _DISCOUNT * q_table[tuple(next_observation_index)].max()
                - q_table[tuple(observation_index)][action]
            )
            observation_index = next_observation_index


if __name__ == "__main__":
    _main()