diff --git a/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala b/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala index 5296968..5694371 100644 --- a/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala +++ b/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala @@ -2,23 +2,38 @@ import scala.collection.mutable -object MarkovChain extends MarkovChain() { -} - /** * Represents a simple markov chain - * TODO: must be extended to HMM + * TODO: must be extended to HMM (add emission probability) */ -class MarkovChain() { +class MarkovChain(startState: State) { val states: mutable.MutableList[State] = mutable.MutableList[State]() + var currentState: State = startState + + //TODO + def next(): State = ??? + + //def currentProbability(): Double = currentState.transitions.size.toDouble / totalTransitions().toDouble + def totalTransitions(): Int = { var total = 0 states.foreach(total += _.transitions.size) total } - def addTransition(fromState: State, toState: State): Unit = { - states.find(s => s == fromState).getOrElse(new State(fromState.label)) - } + /** + * Add a transition from "from" state to "to" state with a defined probability. + * @param from origin node + * @param to destination node + * @param probability transition probability + */ + def addTransition(from: State, to: State, probability: Double): Unit = + states + .find(_.label == from.label) + .map(s => + s.transitions += Transition(to, probability) + ) + + } diff --git a/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala b/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala index 5296968..5694371 100644 --- a/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala +++ b/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala @@ -2,23 +2,38 @@ import scala.collection.mutable -object MarkovChain extends MarkovChain() { -} - /** * Represents a simple markov chain - * TODO: must be extended to HMM + * TODO: must be extended to HMM (add emission probability) */ -class MarkovChain() { +class MarkovChain(startState: State) { val states: mutable.MutableList[State] = mutable.MutableList[State]() + var currentState: State = startState + + //TODO + def next(): State = ??? + + //def currentProbability(): Double = currentState.transitions.size.toDouble / totalTransitions().toDouble + def totalTransitions(): Int = { var total = 0 states.foreach(total += _.transitions.size) total } - def addTransition(fromState: State, toState: State): Unit = { - states.find(s => s == fromState).getOrElse(new State(fromState.label)) - } + /** + * Add a transition from "from" state to "to" state with a defined probability. + * @param from origin node + * @param to destination node + * @param probability transition probability + */ + def addTransition(from: State, to: State, probability: Double): Unit = + states + .find(_.label == from.label) + .map(s => + s.transitions += Transition(to, probability) + ) + + } diff --git a/src/main/scala/org/xcit/nback/markov/models/State.scala b/src/main/scala/org/xcit/nback/markov/models/State.scala index ab94bdb..fbbd844 100644 --- a/src/main/scala/org/xcit/nback/markov/models/State.scala +++ b/src/main/scala/org/xcit/nback/markov/models/State.scala @@ -6,10 +6,10 @@ * A Single state with a label in Markov chain * @param label A simple string label, representing the state and node */ -class State(label: String) { +class State(val label: String) { - val transitions: mutable.HashMap[State, Transition] = mutable.HashMap[State, Transition]() - - //TODO move to chain - def probability(): Double = transitions.size.toDouble / MarkovChain.totalTransitions().toDouble + /** + * Keep track of transitions from this state to other states. + */ + val transitions: mutable.MutableList[Transition] = mutable.MutableList[Transition]() } diff --git a/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala b/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala index 5296968..5694371 100644 --- a/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala +++ b/src/main/scala/org/xcit/nback/markov/models/MarkovChain.scala @@ -2,23 +2,38 @@ import scala.collection.mutable -object MarkovChain extends MarkovChain() { -} - /** * Represents a simple markov chain - * TODO: must be extended to HMM + * TODO: must be extended to HMM (add emission probability) */ -class MarkovChain() { +class MarkovChain(startState: State) { val states: mutable.MutableList[State] = mutable.MutableList[State]() + var currentState: State = startState + + //TODO + def next(): State = ??? + + //def currentProbability(): Double = currentState.transitions.size.toDouble / totalTransitions().toDouble + def totalTransitions(): Int = { var total = 0 states.foreach(total += _.transitions.size) total } - def addTransition(fromState: State, toState: State): Unit = { - states.find(s => s == fromState).getOrElse(new State(fromState.label)) - } + /** + * Add a transition from "from" state to "to" state with a defined probability. + * @param from origin node + * @param to destination node + * @param probability transition probability + */ + def addTransition(from: State, to: State, probability: Double): Unit = + states + .find(_.label == from.label) + .map(s => + s.transitions += Transition(to, probability) + ) + + } diff --git a/src/main/scala/org/xcit/nback/markov/models/State.scala b/src/main/scala/org/xcit/nback/markov/models/State.scala index ab94bdb..fbbd844 100644 --- a/src/main/scala/org/xcit/nback/markov/models/State.scala +++ b/src/main/scala/org/xcit/nback/markov/models/State.scala @@ -6,10 +6,10 @@ * A Single state with a label in Markov chain * @param label A simple string label, representing the state and node */ -class State(label: String) { +class State(val label: String) { - val transitions: mutable.HashMap[State, Transition] = mutable.HashMap[State, Transition]() - - //TODO move to chain - def probability(): Double = transitions.size.toDouble / MarkovChain.totalTransitions().toDouble + /** + * Keep track of transitions from this state to other states. + */ + val transitions: mutable.MutableList[Transition] = mutable.MutableList[Transition]() } diff --git a/src/main/scala/org/xcit/nback/markov/models/Transition.scala b/src/main/scala/org/xcit/nback/markov/models/Transition.scala index 1e4692f..bc20a9e 100644 --- a/src/main/scala/org/xcit/nback/markov/models/Transition.scala +++ b/src/main/scala/org/xcit/nback/markov/models/Transition.scala @@ -1,12 +1,8 @@ package org.xcit.nback.markov.models /** - * A Single transition in Markov chain graph. - * @param from The starting vertex + * A Single transition in Markov chain graph. It is stored in "from" node. * @param to the ending node of the edge * @param probability the probability of this transition from starting node to the ending node (0.0 <= p <= 1.0). */ -class Transition(from: State, to: State, probability: Double = 0.0) { - -} - +case class Transition(to: State, probability: Double = 0.0)