#!/usr/bin/env python3
"""Tabular Q-learning agent for gym's CartPole-v0.

Each continuous observation is discretized into a fixed grid of buckets.
A Q-table indexed by ``(bucket indices..., action)`` is updated after every
step, and actions are chosen greedily from it.
"""

import itertools

import gym
import numpy

_OBSERVATION_STEPS_NUM = 24
_LEARNING_RATE = 0.1
_DISCOUNT = 0.95
# Render one episode out of this many.
_RENDER_EVERY_EPISODES = 400
# Episodes that terminate before this step index are treated as failures.
_MIN_SUCCESSFUL_STEP_INDEX = 190
# Q-value assigned to the final state-action pair of a failed episode.
_FAILURE_Q_VALUE = -300


def _discretize(observation, observation_min, observation_step,
                steps_num=_OBSERVATION_STEPS_NUM):
    """Map a continuous observation to a tuple of Q-table bucket indices.

    Each component ``x`` is bucketed as ``floor((x - min) / step)``. The
    result is clipped to ``[0, steps_num - 1]``. Without clipping:

    - a component exactly equal to the upper bound would yield index
      ``steps_num`` (an IndexError);
    - a component below the lower bound would yield a negative index, which
      numpy silently wraps to the far end of the axis.

    Returns a tuple so it can be used directly as a multi-axis index.
    """
    buckets = numpy.floor(
        (numpy.asarray(observation) - observation_min) / observation_step
    ).astype(int)
    return tuple(numpy.clip(buckets, 0, steps_num - 1))


def _main():
    """Train a greedy tabular Q-learning agent on CartPole-v0 forever.

    Assumes the pre-0.26 gym API: ``reset()`` returns the observation and
    ``step()`` returns ``(observation, reward, done, info)``.
    Training only stops via an exception (e.g. KeyboardInterrupt or a failed
    sanity assertion). The environment is closed in either case.
    """
    env = gym.make("CartPole-v0")
    try:
        # Hand-picked, symmetric bounds of the discretization grid.
        observation_min = numpy.array([-4.8, -4.8, -0.9, -4.0])
        assert (
            env.observation_space.shape == observation_min.shape
        ), env.observation_space.shape
        observation_max = -observation_min
        observation_step = (observation_max - observation_min) / _OBSERVATION_STEPS_NUM
        assert len(env.observation_space.shape) == 1

        # One axis per observation component, plus a final axis over actions,
        # randomly initialized in [-4, 0).
        q_table = numpy.random.uniform(
            low=-4,
            high=0,
            size=(
                [_OBSERVATION_STEPS_NUM] * env.observation_space.shape[0]
                + [env.action_space.n]
            ),
        )
        print("q_table:", q_table.nbytes / 2 ** 20, "MiB")

        for episode_index in itertools.count():
            observation = env.reset()
            render = episode_index % _RENDER_EVERY_EPISODES == 0
            if render:
                env.render()
            # Sanity-check that the grid covers the observation, as is done
            # for every subsequent step below.
            assert (observation >= observation_min).all(), observation
            assert (observation <= observation_max).all(), observation
            state = _discretize(observation, observation_min, observation_step)

            for step_index in itertools.count():
                action = q_table[state].argmax()
                next_observation, reward, done, info = env.step(action)
                if render:
                    env.render()
                assert (next_observation >= observation_min).all(), next_observation
                assert (next_observation <= observation_max).all(), next_observation

                if done:
                    # Terminal transition: no bootstrapping from the next state.
                    # Early termination is penalized. Longer episodes are
                    # presumably cut off by the environment's step limit, and
                    # they just record the last reward.
                    q_table[state + (action,)] = (
                        _FAILURE_Q_VALUE
                        if step_index < _MIN_SUCCESSFUL_STEP_INDEX
                        else reward
                    )
                    print(step_index + 1, "steps")
                    break

                assert not info, info
                next_state = _discretize(
                    next_observation, observation_min, observation_step
                )
                # Standard one-step Q-learning update.
                q_table[state + (action,)] += _LEARNING_RATE * (
                    reward
                    + _DISCOUNT * q_table[next_state].max()
                    - q_table[state + (action,)]
                )
                state = next_state
    finally:
        # Release whatever the environment holds (e.g. a render window).
        env.close()


if __name__ == "__main__":
    _main()