Newer
Older
notebooks / acme_env.ipynb
Morteza Ansarinia on 17 Jun 2022 4 KB pyscript and Acme Env example
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install dm_env and Acme (with its jax, tf and envs extras) into this kernel.\n",
    "# NOTE(review): versions are unpinned; pin them (pkg==x.y.z) for reproducible runs.\n",
    "%pip install dm-env \"dm-acme[jax,tf,envs]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dm_env\n",
    "from dm_env import specs\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "_ACTIONS = (-1, 0, 1)  # Left, no-op, right.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "class Catch(dm_env.Environment):\n",
    "  \"\"\"A Catch environment built on the `dm_env.Environment` class.\n",
    "\n",
    "  The agent must move a paddle to intercept falling balls. Falling balls only\n",
    "  move downwards on the column they are in.\n",
    "\n",
    "  The observation is an array of shape (rows, columns), with binary values:\n",
    "  zero if a space is empty; 1 if it contains the paddle or a ball.\n",
    "\n",
    "  The actions are discrete, and by default there are three available. Action\n",
    "  `i` shifts the paddle by `_ACTIONS[i]` columns: 0 = move left, 1 = stay,\n",
    "  2 = move right.\n",
    "\n",
    "  The episode terminates when the ball reaches the bottom row (the paddle's\n",
    "  row), with reward +1 if the paddle is under the ball and -1 otherwise.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, rows: int = 10, columns: int = 5, seed: int = 1):\n",
    "    \"\"\"Initializes a new Catch environment.\n",
    "\n",
    "    Args:\n",
    "      rows: number of rows. Must be at least 2 so the ball can fall at least\n",
    "        one row before reaching the paddle's (bottom) row.\n",
    "      columns: number of columns. Must be at least 1.\n",
    "      seed: random seed for the RNG that picks each episode's ball column.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: if `rows` < 2 or `columns` < 1.\n",
    "    \"\"\"\n",
    "    # With rows == 1 the first step moves the ball past the last row and\n",
    "    # `_observation` raises IndexError; with rows == 0 or columns == 0 even\n",
    "    # `reset` fails. Reject such boards up front with a clear message.\n",
    "    if rows < 2:\n",
    "      raise ValueError(f\"rows must be >= 2, got {rows}.\")\n",
    "    if columns < 1:\n",
    "      raise ValueError(f\"columns must be >= 1, got {columns}.\")\n",
    "    self._rows = rows\n",
    "    self._columns = columns\n",
    "    self._rng = np.random.RandomState(seed)\n",
    "    # Scratch buffer reused by `_observation`; callers only ever get copies.\n",
    "    self._board = np.zeros((rows, columns), dtype=np.float32)\n",
    "    self._ball_x = None\n",
    "    self._ball_y = None\n",
    "    self._paddle_x = None\n",
    "    self._paddle_y = self._rows - 1  # The paddle lives on the bottom row.\n",
    "    # Makes the first `step` call start an episode via `reset`.\n",
    "    self._reset_next_step = True\n",
    "\n",
    "  def reset(self) -> dm_env.TimeStep:\n",
    "    \"\"\"Returns the first `TimeStep` of a new episode.\n",
    "\n",
    "    The ball starts in a random column of the top row; the paddle starts at\n",
    "    column `columns // 2` of the bottom row.\n",
    "    \"\"\"\n",
    "    self._reset_next_step = False\n",
    "    self._ball_x = self._rng.randint(self._columns)\n",
    "    self._ball_y = 0\n",
    "    self._paddle_x = self._columns // 2\n",
    "    return dm_env.restart(self._observation())\n",
    "\n",
    "  def step(self, action: int) -> dm_env.TimeStep:\n",
    "    \"\"\"Updates the environment according to the action.\n",
    "\n",
    "    If the previous episode has ended (or `reset` was never called), the\n",
    "    action is ignored and the first `TimeStep` of a new episode is returned.\n",
    "\n",
    "    Args:\n",
    "      action: index into `_ACTIONS`, in `[0, len(_ACTIONS))`.\n",
    "\n",
    "    Returns:\n",
    "      A termination `TimeStep` with reward +1 or -1 when the ball reaches the\n",
    "      paddle's row, otherwise a transition `TimeStep` with reward 0.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: if `action` is outside `[0, len(_ACTIONS))`.\n",
    "    \"\"\"\n",
    "    if self._reset_next_step:\n",
    "      return self.reset()\n",
    "\n",
    "    # Validate explicitly: negative indices would otherwise silently wrap\n",
    "    # around in `_ACTIONS` (e.g. -1 would mean \"move right\").\n",
    "    if not 0 <= action < len(_ACTIONS):\n",
    "      raise ValueError(\n",
    "          f\"action must be in [0, {len(_ACTIONS)}), got {action}.\")\n",
    "\n",
    "    # Move the paddle, keeping it on the board.\n",
    "    dx = _ACTIONS[action]\n",
    "    self._paddle_x = np.clip(self._paddle_x + dx, 0, self._columns - 1)\n",
    "\n",
    "    # Drop the ball by one row.\n",
    "    self._ball_y += 1\n",
    "\n",
    "    # The episode ends once the ball reaches the paddle's (bottom) row.\n",
    "    if self._ball_y == self._paddle_y:\n",
    "      reward = 1. if self._paddle_x == self._ball_x else -1.\n",
    "      self._reset_next_step = True\n",
    "      return dm_env.termination(reward=reward, observation=self._observation())\n",
    "    return dm_env.transition(reward=0., observation=self._observation())\n",
    "\n",
    "  def observation_spec(self) -> specs.BoundedArray:\n",
    "    \"\"\"Returns the observation spec: a (rows, columns) float32 array in [0, 1].\"\"\"\n",
    "    return specs.BoundedArray(\n",
    "        shape=self._board.shape,\n",
    "        dtype=self._board.dtype,\n",
    "        name=\"board\",\n",
    "        minimum=0,\n",
    "        maximum=1,\n",
    "    )\n",
    "\n",
    "  def action_spec(self) -> specs.DiscreteArray:\n",
    "    \"\"\"Returns the action spec: one discrete value per entry of `_ACTIONS`.\"\"\"\n",
    "    return specs.DiscreteArray(\n",
    "        dtype=int, num_values=len(_ACTIONS), name=\"action\")\n",
    "\n",
    "  def _observation(self) -> np.ndarray:\n",
    "    \"\"\"Renders the board: 1. at the ball and paddle cells, 0. elsewhere.\n",
    "\n",
    "    Returns a copy so callers cannot mutate the internal `_board` buffer.\n",
    "    \"\"\"\n",
    "    self._board.fill(0.)\n",
    "    self._board[self._ball_y, self._ball_x] = 1.\n",
    "    self._board[self._paddle_y, self._paddle_x] = 1.\n",
    "    return self._board.copy()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "bbf697b8262efa0945ca8c26e05a7230a0c66880a0fc993cfa3876d84998e14e"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 ('dm_env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}