diff --git a/acme_env.ipynb b/acme_env.ipynb new file mode 100644 index 0000000..af08874 --- /dev/null +++ b/acme_env.ipynb @@ -0,0 +1,145 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install dm-env \"dm-acme[jax,tf,envs]\"" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import dm_env\n", + "from dm_env import specs\n", + "import numpy as np\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "_ACTIONS = (-1, 0, 1) # Left, no-op, right.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "class Catch(dm_env.Environment):\n", + " \"\"\"A Catch environment built on the `dm_env.Environment` class.\n", + " The agent must move a paddle to intercept falling balls. Falling balls only\n", + " move downwards on the column they are in.\n", + " The observation is an array shape (rows, columns), with binary values:\n", + " zero if a space is empty; 1 if it contains the paddle or a ball.\n", + " The actions are discrete, and by default there are three available:\n", + " stay, move left, and move right.\n", + " The episode terminates when the ball reaches the bottom of the screen.\n", + " \"\"\"\n", + "\n", + " def __init__(self, rows: int = 10, columns: int = 5, seed: int = 1):\n", + " \"\"\"Initializes a new Catch environment.\n", + " Args:\n", + " rows: number of rows.\n", + " columns: number of columns.\n", + " seed: random seed for the RNG.\n", + " \"\"\"\n", + " self._rows = rows\n", + " self._columns = columns\n", + " self._rng = np.random.RandomState(seed)\n", + " self._board = np.zeros((rows, columns), dtype=np.float32)\n", + " self._ball_x = None\n", + " self._ball_y = None\n", + " self._paddle_x = None\n", + " self._paddle_y = self._rows - 1\n", + " self._reset_next_step = True\n", + "\n", + " def reset(self) -> dm_env.TimeStep:\n", + " \"\"\"Returns the first `TimeStep` of a new episode.\"\"\"\n", + " self._reset_next_step = False\n", + " self._ball_x = self._rng.randint(self._columns)\n", + " self._ball_y = 0\n", + " self._paddle_x = self._columns // 2\n", + " return dm_env.restart(self._observation())\n", + "\n", + " def step(self, action: int) -> dm_env.TimeStep:\n", + " \"\"\"Updates the environment according to the action.\"\"\"\n", + " if self._reset_next_step:\n", + " return self.reset()\n", + "\n", + " # Move the paddle.\n", + " dx = _ACTIONS[action]\n", + " self._paddle_x = np.clip(self._paddle_x + dx, 0, self._columns - 1)\n", + "\n", + " # Drop the ball.\n", + " self._ball_y += 1\n", + "\n", + " # Check for termination.\n", + " if self._ball_y == self._paddle_y:\n", + " reward = 1. if self._paddle_x == self._ball_x else -1.\n", + " self._reset_next_step = True\n", + " return dm_env.termination(reward=reward, observation=self._observation())\n", + " else:\n", + " return dm_env.transition(reward=0., observation=self._observation())\n", + "\n", + " def observation_spec(self) -> specs.BoundedArray:\n", + " \"\"\"Returns the observation spec.\"\"\"\n", + " return specs.BoundedArray(\n", + " shape=self._board.shape,\n", + " dtype=self._board.dtype,\n", + " name=\"board\",\n", + " minimum=0,\n", + " maximum=1,\n", + " )\n", + "\n", + " def action_spec(self) -> specs.DiscreteArray:\n", + " \"\"\"Returns the action spec.\"\"\"\n", + " return specs.DiscreteArray(\n", + " dtype=int, num_values=len(_ACTIONS), name=\"action\")\n", + "\n", + " def _observation(self) -> np.ndarray:\n", + " self._board.fill(0.)\n", + " self._board[self._ball_y, self._ball_x] = 1.\n", + " self._board[self._paddle_y, self._paddle_x] = 1.\n", + " return self._board.copy()" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "bbf697b8262efa0945ca8c26e05a7230a0c66880a0fc993cfa3876d84998e14e" + }, + "kernelspec": { + "display_name": "Python 3.10.4 ('dm_env')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/acme_env.ipynb b/acme_env.ipynb new file mode 100644 index 0000000..af08874 --- /dev/null +++ b/acme_env.ipynb @@ -0,0 +1,145 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install dm-env \"dm-acme[jax,tf,envs]\"" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import dm_env\n", + "from dm_env import specs\n", + "import numpy as np\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "_ACTIONS = (-1, 0, 1) # Left, no-op, right.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "class Catch(dm_env.Environment):\n", + " \"\"\"A Catch environment built on the `dm_env.Environment` class.\n", + " The agent must move a paddle to intercept falling balls. Falling balls only\n", + " move downwards on the column they are in.\n", + " The observation is an array shape (rows, columns), with binary values:\n", + " zero if a space is empty; 1 if it contains the paddle or a ball.\n", + " The actions are discrete, and by default there are three available:\n", + " stay, move left, and move right.\n", + " The episode terminates when the ball reaches the bottom of the screen.\n", + " \"\"\"\n", + "\n", + " def __init__(self, rows: int = 10, columns: int = 5, seed: int = 1):\n", + " \"\"\"Initializes a new Catch environment.\n", + " Args:\n", + " rows: number of rows.\n", + " columns: number of columns.\n", + " seed: random seed for the RNG.\n", + " \"\"\"\n", + " self._rows = rows\n", + " self._columns = columns\n", + " self._rng = np.random.RandomState(seed)\n", + " self._board = np.zeros((rows, columns), dtype=np.float32)\n", + " self._ball_x = None\n", + " self._ball_y = None\n", + " self._paddle_x = None\n", + " self._paddle_y = self._rows - 1\n", + " self._reset_next_step = True\n", + "\n", + " def reset(self) -> dm_env.TimeStep:\n", + " \"\"\"Returns the first `TimeStep` of a new episode.\"\"\"\n", + " self._reset_next_step = False\n", + " self._ball_x = self._rng.randint(self._columns)\n", + " self._ball_y = 0\n", + " self._paddle_x = self._columns // 2\n", + " return dm_env.restart(self._observation())\n", + "\n", + " def step(self, action: int) -> dm_env.TimeStep:\n", + " \"\"\"Updates the environment according to the action.\"\"\"\n", + " if self._reset_next_step:\n", + " return self.reset()\n", + "\n", + " # Move the paddle.\n", + " dx = _ACTIONS[action]\n", + " self._paddle_x = np.clip(self._paddle_x + dx, 0, self._columns - 1)\n", + "\n", + " # Drop the ball.\n", + " self._ball_y += 1\n", + "\n", + " # Check for termination.\n", + " if self._ball_y == self._paddle_y:\n", + " reward = 1. if self._paddle_x == self._ball_x else -1.\n", + " self._reset_next_step = True\n", + " return dm_env.termination(reward=reward, observation=self._observation())\n", + " else:\n", + " return dm_env.transition(reward=0., observation=self._observation())\n", + "\n", + " def observation_spec(self) -> specs.BoundedArray:\n", + " \"\"\"Returns the observation spec.\"\"\"\n", + " return specs.BoundedArray(\n", + " shape=self._board.shape,\n", + " dtype=self._board.dtype,\n", + " name=\"board\",\n", + " minimum=0,\n", + " maximum=1,\n", + " )\n", + "\n", + " def action_spec(self) -> specs.DiscreteArray:\n", + " \"\"\"Returns the action spec.\"\"\"\n", + " return specs.DiscreteArray(\n", + " dtype=int, num_values=len(_ACTIONS), name=\"action\")\n", + "\n", + " def _observation(self) -> np.ndarray:\n", + " self._board.fill(0.)\n", + " self._board[self._ball_y, self._ball_x] = 1.\n", + " self._board[self._paddle_y, self._paddle_x] = 1.\n", + " return self._board.copy()" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "bbf697b8262efa0945ca8c26e05a7230a0c66880a0fc993cfa3876d84998e14e" + }, + "kernelspec": { + "display_name": "Python 3.10.4 ('dm_env')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyscript.html b/pyscript.html new file mode 100644 index 0000000..161b151 --- /dev/null +++ b/pyscript.html @@ -0,0 +1,22 @@ + +
+ + +