diff --git a/py/gym_dqn.py b/py/gym_dqn.py new file mode 100644 index 0000000..eac09ae --- /dev/null +++ b/py/gym_dqn.py @@ -0,0 +1,41 @@ +# %% Keras-RL for gym cart pole + +import gym +from rl.policy import EpsGreedyQPolicy, LinearAnnealedPolicy +from rl.memory import SequentialMemory +from rl.agents.dqn import DQNAgent + +# 1. model +def build_model(state_size, num_actions): + input = Input(shape=(1,state_size)) + x = Flatten()(input) + x = Dense(16, activation='relu')(x) + x = Dense(16, activation='relu')(x) + x = Dense(16, activation='relu')(x) + output = Dense(num_actions, activation='linear')(x) + model = Model(inputs=input, outputs=output) + print(model.summary()) + return model + +model = build_model() + +# 2. memory +memory = SequentialMemory(limit=50000, window_length=1) + +# 3. policy (decay eps) + +policy = LinearAnnealedPolicy( + EpsGreedyQPolicy(), + attr='eps', + value_max=1., + value_min=.1, + value_test=.05, + nb_steps=10000) + +# 4. agent + +dqn = DQNAgent( + model=model, nb_actions=num_actions, memory=memory, nb_steps_warmup=10, + target_model_update=1e-2, policy=policy) + +dqn.compile(Adam(lr=1e-3), metrics=['mae']) \ No newline at end of file