{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install dm-env \"dm-acme[jax,tf,envs]\""
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import dm_env\n",
"from dm_env import specs\n",
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"\n",
"_ACTIONS = (-1, 0, 1) # Left, no-op, right.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"class Catch(dm_env.Environment):\n",
"    \"\"\"A Catch environment built on the `dm_env.Environment` class.\n",
"\n",
"    The agent must move a paddle to intercept falling balls. Falling balls only\n",
"    move downwards on the column they are in.\n",
"\n",
"    The observation is an array shape (rows, columns), with binary values:\n",
"    zero if a space is empty; 1 if it contains the paddle or a ball.\n",
"\n",
"    The actions are discrete, and by default there are three available:\n",
"    stay, move left, and move right.\n",
"\n",
"    The episode terminates when the ball reaches the bottom of the screen.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, rows: int = 10, columns: int = 5, seed: int = 1):\n",
"        \"\"\"Initializes a new Catch environment.\n",
"\n",
"        Args:\n",
"          rows: number of rows.\n",
"          columns: number of columns.\n",
"          seed: random seed for the RNG.\n",
"        \"\"\"\n",
"        self._rows = rows\n",
"        self._columns = columns\n",
"        self._rng = np.random.RandomState(seed)\n",
"        self._board = np.zeros((rows, columns), dtype=np.float32)\n",
"        # Ball/paddle positions are unset until `reset()` is called.\n",
"        self._ball_x = None\n",
"        self._ball_y = None\n",
"        self._paddle_x = None\n",
"        # The paddle always sits on the bottom row.\n",
"        self._paddle_y = self._rows - 1\n",
"        # Forces `reset()` before the first real `step()`.\n",
"        self._reset_next_step = True\n",
"\n",
"    def reset(self) -> dm_env.TimeStep:\n",
"        \"\"\"Returns the first `TimeStep` of a new episode.\"\"\"\n",
"        self._reset_next_step = False\n",
"        # The ball starts on the top row, in a uniformly random column.\n",
"        self._ball_x = self._rng.randint(self._columns)\n",
"        self._ball_y = 0\n",
"        self._paddle_x = self._columns // 2\n",
"        return dm_env.restart(self._observation())\n",
"\n",
"    def step(self, action: int) -> dm_env.TimeStep:\n",
"        \"\"\"Updates the environment according to the action.\n",
"\n",
"        Args:\n",
"          action: an integer in [0, len(_ACTIONS)), indexing into `_ACTIONS`.\n",
"\n",
"        Returns:\n",
"          A termination `TimeStep` (reward +1 on a catch, -1 on a miss) once\n",
"          the ball reaches the paddle row; a transition `TimeStep` with\n",
"          reward 0 otherwise.\n",
"\n",
"        Raises:\n",
"          ValueError: if `action` is outside the range declared by\n",
"            `action_spec()`.\n",
"        \"\"\"\n",
"        if self._reset_next_step:\n",
"            return self.reset()\n",
"\n",
"        # Validate against the action spec: a negative index would otherwise\n",
"        # silently wrap (e.g. action -1 selects _ACTIONS[-1] == +1, \"right\").\n",
"        if not 0 <= action < len(_ACTIONS):\n",
"            raise ValueError(\n",
"                f\"Invalid action: {action}; expected an integer in \"\n",
"                f\"[0, {len(_ACTIONS)}).\")\n",
"\n",
"        # Move the paddle, clipping it to stay on the board.\n",
"        dx = _ACTIONS[action]\n",
"        self._paddle_x = np.clip(self._paddle_x + dx, 0, self._columns - 1)\n",
"\n",
"        # Drop the ball.\n",
"        self._ball_y += 1\n",
"\n",
"        # Check for termination.\n",
"        if self._ball_y == self._paddle_y:\n",
"            reward = 1. if self._paddle_x == self._ball_x else -1.\n",
"            self._reset_next_step = True\n",
"            return dm_env.termination(reward=reward, observation=self._observation())\n",
"        else:\n",
"            return dm_env.transition(reward=0., observation=self._observation())\n",
"\n",
"    def observation_spec(self) -> specs.BoundedArray:\n",
"        \"\"\"Returns the observation spec.\"\"\"\n",
"        return specs.BoundedArray(\n",
"            shape=self._board.shape,\n",
"            dtype=self._board.dtype,\n",
"            name=\"board\",\n",
"            minimum=0,\n",
"            maximum=1,\n",
"        )\n",
"\n",
"    def action_spec(self) -> specs.DiscreteArray:\n",
"        \"\"\"Returns the action spec.\"\"\"\n",
"        return specs.DiscreteArray(\n",
"            dtype=int, num_values=len(_ACTIONS), name=\"action\")\n",
"\n",
"    def _observation(self) -> np.ndarray:\n",
"        # Redraw the board from scratch so stale marks never persist, then\n",
"        # return a copy so callers cannot mutate internal state via the\n",
"        # observation they receive.\n",
"        self._board.fill(0.)\n",
"        self._board[self._ball_y, self._ball_x] = 1.\n",
"        self._board[self._paddle_y, self._paddle_x] = 1.\n",
"        return self._board.copy()"
]
}
],
"metadata": {
"interpreter": {
"hash": "bbf697b8262efa0945ca8c26e05a7230a0c66880a0fc993cfa3876d84998e14e"
},
"kernelspec": {
"display_name": "Python 3.10.4 ('dm_env')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}