{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Catch: a minimal `dm_env` environment\n",
    "\n",
    "A ball falls straight down a random column of a grid; the agent moves a\n",
    "paddle along the bottom row to intercept it. The class below implements\n",
    "the game against the `dm_env.Environment` interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install dm-env \"dm-acme[jax,tf,envs]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dm_env\n",
    "from dm_env import specs\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ACTIONS = (-1, 0, 1)  # Left, no-op, right."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Catch(dm_env.Environment):\n",
    "  \"\"\"A Catch environment built on the `dm_env.Environment` class.\n",
    "\n",
    "  The agent must move a paddle to intercept falling balls. Falling balls\n",
    "  only move downwards on the column they are in.\n",
    "\n",
    "  The observation is an array of shape (rows, columns) with binary values:\n",
    "  zero if a space is empty; 1 if it contains the paddle or a ball.\n",
    "\n",
    "  The actions are discrete, and by default there are three available:\n",
    "  stay, move left, and move right.\n",
    "\n",
    "  The episode terminates when the ball reaches the bottom of the screen.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, rows: int = 10, columns: int = 5, seed: int = 1):\n",
    "    \"\"\"Initializes a new Catch environment.\n",
    "\n",
    "    Args:\n",
    "      rows: number of rows; must be >= 2 (at least one row for the ball to\n",
    "        fall through, plus the bottom row holding the paddle).\n",
    "      columns: number of columns; must be >= 1.\n",
    "      seed: random seed for the RNG.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: if the board is too small to play on.\n",
    "    \"\"\"\n",
    "    if rows < 2:\n",
    "      raise ValueError(f\"rows must be >= 2, got {rows}.\")\n",
    "    if columns < 1:\n",
    "      raise ValueError(f\"columns must be >= 1, got {columns}.\")\n",
    "    self._rows = rows\n",
    "    self._columns = columns\n",
    "    self._rng = np.random.RandomState(seed)\n",
    "    self._board = np.zeros((rows, columns), dtype=np.float32)\n",
    "    self._ball_x = None\n",
    "    self._ball_y = None\n",
    "    self._paddle_x = None\n",
    "    self._paddle_y = self._rows - 1  # The paddle lives on the bottom row.\n",
    "    self._reset_next_step = True\n",
    "\n",
    "  def reset(self) -> dm_env.TimeStep:\n",
    "    \"\"\"Returns the first `TimeStep` of a new episode.\"\"\"\n",
    "    self._reset_next_step = False\n",
    "    # The ball drops from a random column; the paddle starts centered.\n",
    "    self._ball_x = self._rng.randint(self._columns)\n",
    "    self._ball_y = 0\n",
    "    self._paddle_x = self._columns // 2\n",
    "    return dm_env.restart(self._observation())\n",
    "\n",
    "  def step(self, action: int) -> dm_env.TimeStep:\n",
    "    \"\"\"Updates the environment according to the action.\"\"\"\n",
    "    if self._reset_next_step:\n",
    "      return self.reset()\n",
    "\n",
    "    # Move the paddle.\n",
    "    dx = _ACTIONS[action]\n",
    "    self._paddle_x = np.clip(self._paddle_x + dx, 0, self._columns - 1)\n",
    "\n",
    "    # Drop the ball.\n",
    "    self._ball_y += 1\n",
    "\n",
    "    # Check for termination.\n",
    "    if self._ball_y == self._paddle_y:\n",
    "      reward = 1. if self._paddle_x == self._ball_x else -1.\n",
    "      self._reset_next_step = True\n",
    "      return dm_env.termination(reward=reward, observation=self._observation())\n",
    "    else:\n",
    "      return dm_env.transition(reward=0., observation=self._observation())\n",
    "\n",
    "  def observation_spec(self) -> specs.BoundedArray:\n",
    "    \"\"\"Returns the observation spec.\"\"\"\n",
    "    return specs.BoundedArray(\n",
    "        shape=self._board.shape,\n",
    "        dtype=self._board.dtype,\n",
    "        name=\"board\",\n",
    "        minimum=0,\n",
    "        maximum=1,\n",
    "    )\n",
    "\n",
    "  def action_spec(self) -> specs.DiscreteArray:\n",
    "    \"\"\"Returns the action spec.\"\"\"\n",
    "    return specs.DiscreteArray(\n",
    "        dtype=int, num_values=len(_ACTIONS), name=\"action\")\n",
    "\n",
    "  def _observation(self) -> np.ndarray:\n",
    "    \"\"\"Renders the ball and paddle onto the board and returns a copy.\"\"\"\n",
    "    # Redraw from scratch each call so stale positions never linger.\n",
    "    self._board.fill(0.)\n",
    "    self._board[self._ball_y, self._ball_x] = 1.\n",
    "    self._board[self._paddle_y, self._paddle_x] = 1.\n",
    "    return self._board.copy()  # Copy so callers cannot mutate our state."
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "bbf697b8262efa0945ca8c26e05a7230a0c66880a0fc993cfa3876d84998e14e"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 ('dm_env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}