This article assumes some prior knowledge of Scala and monads.
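The State Monad is a functional programming concept that lets stateful computations be expressed in a purely functional way.
In pure functional code a function cannot mutate shared variables. Any state it depends on must be passed in as an argument, and the updated state must be returned alongside the result. Threading that state through every call by hand is tedious and error-prone. The State Monad addresses this by packaging the state-passing into a value that can be composed.
The key idea is to represent a stateful computation as a function that takes an initial state and returns a new state together with a result.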
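To make the problem concrete, here is a minimal sketch of threading state by hand, before any monad is involved. The counter-based `nextId` and `threeIds` functions are hypothetical examples, not part of the article's code. Each step takes the current counter and returns the updated counter plus a result, and the caller has to pass every intermediate state along correctly.

```scala
// Hypothetical stateful step: takes the current counter and returns
// (new counter, generated id). State comes first, result second, matching
// the State definition used later in this article.
def nextId(counter: Int): (Int, String) =
  (counter + 1, s"id-$counter")

// Threading the state by hand: every intermediate counter must be named
// and passed to the next call. Reusing s0 by mistake would silently
// produce a duplicate id.
def threeIds(s0: Int): (Int, List[String]) = {
  val (s1, a) = nextId(s0)
  val (s2, b) = nextId(s1)
  val (s3, c) = nextId(s2)
  (s3, List(a, b, c))
}

val result = threeIds(0) // (3, List("id-0", "id-1", "id-2"))
```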
Key Concepts of the State Monad
*State Representation:*
The State Monad represents a computation that carries some state.
It can be thought of as a function that takes an initial state and returns a value along with a new state.
*Type Signature:*
In Haskell notation the type is written State s a = s -> (a, s), where s is the type of the state and a is the type of the result. The Scala definition below uses the same idea but puts the state first in the tuple: S => (S, A).
Basic Operations:
get: Retrieves the current state.
put: Updates the state to a new value.
modify: Applies a function to the current state to produce a new state.
Composing State Computations: The State Monad allows for the composition of stateful computations using the flatMap (or bind) operation. This enables chaining multiple stateful operations together.
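The composition idea can be sketched on plain functions before introducing any class. The `Step` alias and the `bind` helper below are hypothetical, not part of the article's code. `bind` runs the first step, passes its result to `f` to choose the next step, and runs that step on the updated state, so no intermediate state has to be named by the caller.

```scala
// A stateful step as a plain function: state in, (new state, result) out.
type Step[S, A] = S => (S, A)

// Hypothetical bind: run `m`, feed its result to `f` to get the next step,
// then run that step on the state produced by `m`.
def bind[S, A, B](m: Step[S, A])(f: A => Step[S, B]): Step[S, B] =
  s => {
    val (s1, a) = m(s)
    f(a)(s1)
  }

// Two small steps over an Int counter.
val current: Step[Int, Int] = s => (s, s)         // read the counter
def setTo(n: Int): Step[Int, Unit] = _ => (n, ()) // overwrite the counter

// Read the counter, store double its value, then read it again.
val doubled: Step[Int, Int] =
  bind(current)(n => bind(setTo(n * 2))(_ => current))

val out = doubled(21) // (42, 42)
```

This `bind` is the same shape as the `flatMap` that the State class receives through its Monad instance later in the article.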
1. State Monad Definition
case class State[S, A](run: S => (S, A)) {
  // Execute the state computation and return the final state
  def exec(s: S): S = run(s)._1
  // Evaluate the state computation and return the result
  def eval(s: S): A = run(s)._2
}
Case Class: State is defined as a case class that takes two type parameters: S (the type of the state) and A (the type of the result).
run Field: run holds a function that takes an initial state of type S and returns a tuple of the new state (type S) and a result (type A).
exec Method: This method runs the computation on an initial state and returns only the final state, discarding the result.
eval Method: This method evaluates the state computation and returns the result of type A.
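A quick way to see how run, exec, and eval relate is to run one computation through all three. The `incr` computation below is a hypothetical example that increments an Int state and reports the old value. The State class is repeated so the snippet compiles on its own.

```scala
// The State definition from this section, repeated to keep the snippet self-contained.
case class State[S, A](run: S => (S, A)) {
  def exec(s: S): S = run(s)._1
  def eval(s: S): A = run(s)._2
}

// Hypothetical computation: bump the counter and describe its previous value.
val incr: State[Int, String] = State(n => (n + 1, s"was $n"))

val both = incr.run(5)        // (6, "was 5"): full (state, result) pair
val finalState = incr.exec(5) // 6: state only
val value = incr.eval(5)      // "was 5": result only
```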
2. Companion Object for State
object State {
  // Create a State that replaces the current state with a new state
  def put[S](ns: S): State[S, Unit] = State(s => (ns, ()))
  // Create a State that retrieves the current state
  def get[S]: State[S, S] = State(s => (s, s))
  // Create a State that modifies the current state using a function
  def modify[S](fx: S => S): State[S, Unit] = State(s => (fx(s), ()))
}
put Method: This method creates a State that replaces the current state with a new state ns. The result type is Unit because it doesn't return any meaningful value.
get Method: This method creates a State that retrieves the current state. It returns the current state as both the new state and the result.
modify Method: This method creates a State that applies the provided function fx to the current state, stores the outcome as the new state, and returns Unit as its result.
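Running each primitive directly on the same initial state shows what it does to the state and what it returns. This sketch repeats the article's case class and companion object. The only change is writing `_ =>` for the unused parameter in `put`.

```scala
case class State[S, A](run: S => (S, A))

object State {
  def put[S](ns: S): State[S, Unit] = State(_ => (ns, ()))
  def get[S]: State[S, S] = State(s => (s, s))
  def modify[S](fx: S => S): State[S, Unit] = State(s => (fx(s), ()))
}

// Each primitive run on the initial state 3:
val g = State.get[Int].run(3)           // (3, 3): state unchanged, result is the state
val p = State.put(10).run(3)            // (10, ()): old state ignored
val m = State.modify[Int](_ * 2).run(3) // (6, ()): state transformed in place
```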
3. Functor Type Class
trait Functor[F[_]] {
  def fmap[A, B](fa: F[A])(f: A => B): F[B]
}
Functor Trait: This trait defines a type class for functors, which are types that can be mapped over. It has a method fmap that takes a functor fa containing a value of type A and a function f that transforms A into B. It returns a functor containing a value of type B.
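Before applying the type class to State, it can help to see it on a simpler type. This sketch gives a hypothetical Functor instance for the standard `Option` type, which is not part of the article's code. `fmap` transforms the value inside and leaves the surrounding structure (present or absent) alone.

```scala
trait Functor[F[_]] {
  def fmap[A, B](fa: F[A])(f: A => B): F[B]
}

object Functor {
  // Hypothetical instance for Option: transform the value if present,
  // keep None as None.
  given Functor[Option] with {
    def fmap[A, B](fa: Option[A])(f: A => B): Option[B] = fa match {
      case Some(a) => Some(f(a))
      case None    => None
    }
  }
}

// Found via Functor's companion object, so no import is needed.
val F = summon[Functor[Option]]
val some = F.fmap(Some(20))(_ + 1)          // Some(21)
val none = F.fmap(Option.empty[Int])(_ + 1) // None
```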
4. Functor Instance for State
object Functor {
  // Functor instance for State using a type lambda
  given [S]: Functor[[A] =>> State[S, A]] with {
    override def fmap[A, B](a: State[S, A])(fx: A => B): State[S, B] = State { s =>
      val (newState, value) = a.run(s)
      (newState, fx(value))
    }
  }
}
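Functor Instance: This provides an instance of the Functor type class for State. Functor expects a type constructor with one parameter, so the type lambda [A] =>> State[S, A] fixes the state type S and leaves only the result type A free. The given keyword defines it as a context (implicit) instance, and because it lives in Functor's companion object the compiler can find it without an import.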
fmap Implementation: The fmap method is implemented to apply the function fx to the result of the state computation while preserving the state.
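The following sketch exercises the instance: fmap turns the result of `State.get` into a label while the state passes through unchanged. The definitions mirror the article's. `summon` is used here to obtain the instance, which the article's code does not do directly, and the label text is a hypothetical example.

```scala
case class State[S, A](run: S => (S, A))

object State {
  def get[S]: State[S, S] = State(s => (s, s))
}

trait Functor[F[_]] {
  def fmap[A, B](fa: F[A])(f: A => B): F[B]
}

object Functor {
  // Same shape as the article's instance: transform the result, pass the state through.
  given [S]: Functor[[A] =>> State[S, A]] with {
    override def fmap[A, B](a: State[S, A])(fx: A => B): State[S, B] = State { s =>
      val (newState, value) = a.run(s)
      (newState, fx(value))
    }
  }
}

// Resolved automatically because the given lives in Functor's companion object.
val F = summon[Functor[[A] =>> State[Int, A]]]

// Read the Int state and turn it into a label; the state itself is untouched.
val labelled = F.fmap(State.get[Int])(n => s"count=$n").run(4) // (4, "count=4")
```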
5. Monad Type Class
abstract class Monad[M[_]](using functor: Functor[M]) {
  def pure[A](a: A): M[A]
  def flatMap[A, B](a: M[A])(fx: A => M[B]): M[B]
}
Monad Class: This abstract class defines a type class for monads, which are types that support chaining operations. It requires a Functor instance for M, passed as a context (using) parameter.
pure Method: This method lifts a value of type A into the monad M.
flatMap Method: This method allows for chaining operations by taking a monadic value a and a function fx that returns a new monadic value.
6. Monad Instance for State
object StateInstance {
  given [S]: Monad[[A] =>> State[S, A]] with {
    override def pure[A](a: A): State[S, A] = State(s => (s, a))
    override def flatMap[A, B](a: State[S, A])(fx: A => State[S, B]): State[S, B] = State { s =>
      val (newState, value) = a.run(s)
      fx(value).run(newState)
    }
  }
}
Monad Instance: This provides an instance of the Monad type class for the State monad.
pure Implementation: The pure method creates a State that returns the value a while keeping the state unchanged.
flatMap Implementation: The flatMap method allows chaining stateful computations. It runs the initial state computation, retrieves the new state and value, and then applies the function fx to the value, running the resulting state computation with the new state.
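The sketch below chains get, put, and pure with the Monad instance's flatMap directly, without a for comprehension. It builds a hypothetical "post-increment" that reads the counter, stores counter + 1, and returns the old value. One assumption differs from the article: the given is named `stateMonad` (the article's is anonymous), so it can be referenced without relying on a compiler-generated name.

```scala
case class State[S, A](run: S => (S, A))

object State {
  def get[S]: State[S, S] = State(s => (s, s))
  def put[S](ns: S): State[S, Unit] = State(_ => (ns, ()))
}

trait Functor[F[_]] {
  def fmap[A, B](fa: F[A])(f: A => B): F[B]
}

object Functor {
  given [S]: Functor[[A] =>> State[S, A]] with {
    def fmap[A, B](fa: State[S, A])(f: A => B): State[S, B] = State { s =>
      val (next, a) = fa.run(s)
      (next, f(a))
    }
  }
}

abstract class Monad[M[_]](using functor: Functor[M]) {
  def pure[A](a: A): M[A]
  def flatMap[A, B](a: M[A])(fx: A => M[B]): M[B]
}

object StateInstance {
  // Named given (an assumption; the article's is anonymous).
  given stateMonad[S]: Monad[[A] =>> State[S, A]] with {
    def pure[A](a: A): State[S, A] = State(s => (s, a))
    def flatMap[A, B](a: State[S, A])(fx: A => State[S, B]): State[S, B] = State { s =>
      val (newState, value) = a.run(s)
      fx(value).run(newState)
    }
  }
}

val M = StateInstance.stateMonad[Int]

// Hypothetical post-increment: read, store old + 1, return the old value.
val postIncrement: State[Int, Int] =
  M.flatMap(State.get[Int]) { old =>
    M.flatMap(State.put(old + 1)) { _ => M.pure(old) }
  }

val r = postIncrement.run(7) // (8, 7)
```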
7. Extension Methods for State
extension [S, A](state: State[S, A]) {
  // Map a function over the result of a State
  infix def map[B](f: A => B)(using functor: Functor[[A] =>> State[S, A]]): State[S, B] = 
    functor.fmap(state)(f)
  infix def flatMap[B](f: A => State[S, B])(using monad: Monad[[A] =>> State[S, A]]): State[S, B] = 
    monad.flatMap(state)(f)  
}
Extension Methods: These methods add map and flatMap to the State class, allowing for more convenient usage of the State monad.
map Method: This method allows you to apply a function f to the result of a State computation.
flatMap Method: This method allows you to chain stateful computations using a function that returns a new State.
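With map and flatMap available as extension methods, Scala's for comprehension can be used on State. This self-contained sketch writes the same computation twice: once with for/yield, and once as the flatMap/map chain that the compiler rewrites it into. One assumption differs from the article: the Monad given is placed in Monad's companion object instead of a separate StateInstance object, so it is found without an import or a local given.

```scala
case class State[S, A](run: S => (S, A))

object State {
  def get[S]: State[S, S] = State(s => (s, s))
  def modify[S](fx: S => S): State[S, Unit] = State(s => (fx(s), ()))
}

trait Functor[F[_]] {
  def fmap[A, B](fa: F[A])(f: A => B): F[B]
}

object Functor {
  given [S]: Functor[[A] =>> State[S, A]] with {
    def fmap[A, B](fa: State[S, A])(f: A => B): State[S, B] = State { s =>
      val (next, a) = fa.run(s)
      (next, f(a))
    }
  }
}

abstract class Monad[M[_]](using functor: Functor[M]) {
  def pure[A](a: A): M[A]
  def flatMap[A, B](a: M[A])(fx: A => M[B]): M[B]
}

object Monad {
  // Assumption: placed in Monad's companion so implicit search finds it.
  given [S]: Monad[[A] =>> State[S, A]] with {
    def pure[A](a: A): State[S, A] = State(s => (s, a))
    def flatMap[A, B](a: State[S, A])(fx: A => State[S, B]): State[S, B] = State { s =>
      val (next, value) = a.run(s)
      fx(value).run(next)
    }
  }
}

extension [S, A](state: State[S, A]) {
  infix def map[B](f: A => B)(using functor: Functor[[X] =>> State[S, X]]): State[S, B] =
    functor.fmap(state)(f)
  infix def flatMap[B](f: A => State[S, B])(using monad: Monad[[X] =>> State[S, X]]): State[S, B] =
    monad.flatMap(state)(f)
}

// Written with a for comprehension...
val viaFor: State[Int, Int] = for {
  a <- State.get[Int]
  _ <- State.modify[Int](_ + 10)
  b <- State.get[Int]
} yield a + b

// ...and as the flatMap/map chain the compiler rewrites it into.
val viaCalls: State[Int, Int] =
  State.get[Int].flatMap(a =>
    State.modify[Int](_ + 10).flatMap(_ =>
      State.get[Int].map(b => a + b)))

val fromFor = viaFor.run(1)     // (11, 12)
val fromCalls = viaCalls.run(1) // (11, 12)
```

The last generator before yield becomes a map, and every earlier generator becomes a flatMap. This is why both extension methods are needed.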
8. Extension Method to Lift a Value into State
extension [A](a: A) {
  def pureState[S]: State[S, A] = State(s => (s, a))
}
pureState Method: This extension method allows you to lift a value a into a State monad, creating a stateful computation that returns a while keeping the state unchanged.
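A minimal check of pureState: the lifted value becomes the result, and whatever state is supplied comes back unchanged. The snippet repeats the State case class and writes the extension with its full Scala 3 parameter list so it compiles on its own.

```scala
case class State[S, A](run: S => (S, A))

// Wrap a plain value in a State that leaves the state untouched.
extension [A](a: A) {
  def pureState[S]: State[S, A] = State(s => (s, a))
}

val lifted = 42.pureState[String].run("unchanged") // ("unchanged", 42)
```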
9. Example Usage
object StateExample {
  import com.example.Functor.given
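  type StringState[A] = State[String, A]
  def main(args: Array[String]): Unit = {
    // StateInstance is not a companion of Monad or State, so its given is not
    // found automatically; bring it into scope explicitly for String state
    given Monad[StringState] = StateInstance.given_Monad_State[String]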
    val s = for {
      x <- 10.pureState[String] 
      os <- State.get[String]
      y <- 20.pureState[String]  
      _ <- State.put(os + " boom")
      z <- 30.pureState[String] 
      _ <- State.get[String].flatMap { ns => State.put(ns + " baam") }
    } yield { x + y + z }
    val (s3, r3) = s.run("hello")
    println(s"Example 3 - State: $s3, Result: $r3")
  }
}
StateExample Breakdown
The steps below walk through each line of the for comprehension in StateExample.
Step-by-Step Explanation
Step 1: x <- 10.pureState[String]
This line lifts the value 10 into the State monad. The state remains unchanged at this point.
The result of this operation is that x will hold the value 10.
Step 2: os <- State.get[String]
The get operation retrieves the current state.
At this point, the state is still the initial value "hello" (the argument passed to s.run("hello")).
The result of this operation is that os will hold the value "hello".
Step 3: y <- 20.pureState[String]
Similar to Step 1, this line lifts the value 20 into the State monad.
The state is still unchanged at this point.
The result of this operation is that y will hold the value 20.
Step 4: _ <- State.put(os + " boom")
The put operation replaces the current state with a new state. Here, it takes the value of os (which is "hello") and appends " boom" to it.
The new state becomes "hello boom".
put returns Unit, which carries no useful information, so the result is discarded with the underscore _. What matters is the change to the state.
Step 5: z <- 30.pureState[String]
This line lifts the value 30 into the State monad.
The state is still "hello boom" at this point.
The result of this operation is that z will hold the value 30.
Step 6: _ <- State.get[String].flatMap { ns => State.put(ns + " baam") }
First, State.get[String] retrieves the current state, which is now "hello boom".
The result of this get operation is that ns will hold the value "hello boom".
The flatMap method is then used to chain another operation. It takes the value of ns and appends " baam" to it, creating a new state of "hello boom baam".
The put operation replaces the current state with this new state.
Again, the result is Unit and is discarded with the underscore _.
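The get-then-put pattern in Step 6 is exactly what modify does in a single call. The sketch below checks that both forms produce the same state and result. To keep it short, flatMap is defined directly on the case class here, an assumption that differs from the article, which supplies flatMap through the Monad instance and an extension method.

```scala
case class State[S, A](run: S => (S, A)) {
  // Assumption: flatMap as a class method, for brevity.
  def flatMap[B](f: A => State[S, B]): State[S, B] = State { s =>
    val (next, a) = run(s)
    f(a).run(next)
  }
}

object State {
  def get[S]: State[S, S] = State(s => (s, s))
  def put[S](ns: S): State[S, Unit] = State(_ => (ns, ()))
  def modify[S](fx: S => S): State[S, Unit] = State(s => (fx(s), ()))
}

// Step 6 as written in the example...
val viaGetPut = State.get[String].flatMap(ns => State.put(ns + " baam"))
// ...and the same update expressed with modify.
val viaModify = State.modify[String](_ + " baam")

val r1 = viaGetPut.run("hello boom") // ("hello boom baam", ())
val r2 = viaModify.run("hello boom") // ("hello boom baam", ())
```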
Final Result Calculation
After all these steps, the final result of the for comprehension is calculated:
yield { x + y + z }
Here, x is 10, y is 20, and z is 30.
The final result is 10 + 20 + 30 = 60.
Summary of State Changes
Initial State: "hello"
After Step 2 (get): State remains "hello", os is "hello".
After Step 4 (put): State changes to "hello boom".
After Step 6 (get and put): State changes to "hello boom baam".
Final Output
When the state computation is run with the initial state "hello":
val (s3, r3) = s.run("hello")
s3 will hold the final state, which is "hello boom baam".
r3 will hold the result of the computation, which is 60.
The output will be:
Example 3 - State: hello boom baam, Result: 60
The full code can be found in this gist:
https://gist.github.com/depareddy/a0bfb88b9cdfa627395cefcbe0563a5e