DEV Community

Mike Solomon

A Pure Implementation of the ST Monad

I've been part of a couple of discussions recently about the nature of ST and, more specifically, its relationship to purity and certain monadic laws. Folks, myself included, have been asking:

  • Is it possible to implement ST in a "pure" way? By "pure" here, I mean a hand-wavy notion of referential transparency: no dependencies on external libs, no low-level hacks, etc.
  • Does ST act like a transformer? Can it have an instance of MonadTrans?

The answer is yes and yes! In this article, I'll share a no-frills pure implementation of ST in one file with no external dependencies. I'll also provide a strategy to give it a valid instance of MonadTrans.

ST - the basics

ST is a monad that allows you to write code that feels like JavaScript circa 1995. Everything is a variable, all variables can be modified, and the assignment of a variable to another variable captures the original variable's value at the point of assignment.

var foo = 1; // assignment 
var bar = foo; // read + assignment
foo = 3; // write
console.log(bar); // 1
bar = 5; // write
console.log(bar) // 5

In Haskell and PureScript, the corresponding code would be:

foo <- new 1            -- assignment
bar <- read foo >>= new -- read + assignment
write 3 foo             -- write
read bar >>= logShow    -- 1
write 5 bar             -- write
read bar >>= logShow    -- 5

ST is nice for several reasons:

  • for those who have never programmed in an ML-family language like Haskell or PureScript, ST is a great way to translate imperative code into functional code with minimal modifications.
  • many algorithms are faster and easier to express using ST.
  • some types, like those representing event streams, need to create ad hoc buffers of values for situations like debouncing and underflow, and ST is a useful way to implement these features (this is how, for example, purescript-hyrule works).

ST is almost always implemented by writing its logic in a non-functional language that more closely resembles imperative code. For example, in PureScript, almost all of ST is implemented in JavaScript, and the compiler even has special rules to rewrite do blocks of ST code as imperative code. This leads to substantial performance gains when working with ST.
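For comparison, here is a sketch of the JavaScript snippet written against the ST that ships with PureScript (the purescript-st package, whose references live in Control.Monad.ST.Ref). The module name and barValues are made up for this example. One difference from the API used in this article: the library's write returns the written value rather than Unit, so its result is discarded explicitly.

```purescript
module BuiltinSTExample where

import Prelude

import Control.Monad.ST as ST
import Control.Monad.ST.Ref as STRef

-- Same steps as the JavaScript snippet: bar copies foo's value when it is
-- created, so the later write to foo does not change bar.
barValues :: { before :: Int, after :: Int }
barValues = ST.run do
  foo <- STRef.new 1                  -- var foo = 1
  bar <- STRef.read foo >>= STRef.new -- var bar = foo
  _ <- STRef.write 3 foo              -- foo = 3 (write returns the written value)
  before <- STRef.read bar            -- 1
  _ <- STRef.write 5 bar              -- bar = 5
  after <- STRef.read bar             -- 5
  pure { before, after }
```

Here ST.run plays the role of runST: its rank-2 argument keeps any STRef from escaping the computation.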

For understandable reasons, there's a common misconception that ST is a sort of imperative "back door" to functional languages. Even though ST often has an imperative implementation, it is possible to conceive of it in an entirely "pure" way. That's what this article is about!

How is this done?

The first thing to do is define an ST type.

newtype ST r i o a = ST ((Nat /\ i) -> (Nat /\ o /\ a))

Let's unpack what the type is saying:

  • r is a phantom type that will "lock" ST computations into a context.
  • i is the type of the cache before this computation: it records every type for which we have already created a reference. For example, in the JS and Haskell/PS examples above, we are using integers, so this would contain Int.
  • o is the type of the cache after this computation. For example, if the computation creates a new reference to a Boolean and we have no other references to Booleans in our computation, o will be i with a Boolean slot added to the cache of types.
  • a is the actual value in the context, like in any garden-variety monad.
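To make the i and o indices concrete, here is a sketch that assumes the full listing below is compiled as a module named ST (it has no export list, so everything is exported). The names Empty, newFlag, twoFlags, and flagResult are hypothetical, chosen for this illustration. Creating the first Boolean reference changes the cache type; creating a second one reuses the same slot.

```purescript
module CacheIndexExample where

import ST

-- Hypothetical synonym for the cache runST starts with: one Unit slot.
type Empty = List (Nat /\ Unit)

-- The first Boolean reference adds a List (Nat /\ Boolean) slot, so o differs from i.
newFlag :: forall r. ST r Empty (List (Nat /\ Boolean) /\ Empty) (Ref Boolean)
newFlag = new true

-- A second Boolean reference reuses that slot, so the final cache type is the same.
twoFlags :: forall r. ST r Empty (List (Nat /\ Boolean) /\ Empty) Boolean
twoFlags = do
  a <- new true
  _ <- new false
  read a

flagResult :: Boolean
flagResult = runST twoFlags
```

The do block works without Prelude because PureScript desugars do notation to whichever bind and discard are in scope, here the indexed ones from the ST module.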

The constructor itself looks a lot like the State monad from Haskell and PureScript's MTL libraries. Here's a simplified definition of State:

type State s a = s -> Tuple a s

The difference between State and ST is that, instead of an open s type, we restrict the variable in ST to a cache that is managed automatically by type classes. The ST cache will contain all of the references with their current value.

Then, we need some type classes to manage two operations:

  • creating a new reference (we'll call this STCons)
  • modifying a reference (we'll call this STModify, and we'll use it to define both read and write operations)

STCons

The definition of STCons is as follows:

class STCons (x :: Type) (i :: Type) (o :: Type) | x i -> o where
  stCons :: Nat -> x -> i -> o

It is very similar to the cons one would see for any associative list. There is a Nat for the index, a value of type x that we are storing, the input cache i, and the output cache o. At both the term level (the definition of stCons) and the type level (the functional dependency x i -> o), we define a function that inserts x into i to produce o. When we call new, this is how we will update our cached values.
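A sketch of stCons applied directly to cache values, again assuming the listing below is compiled as module ST. The names start, afterFirst, afterSecond, and the helper newestFlag are hypothetical. The first call falls through to the last instance in the chain (no Boolean slot yet), while the second matches the first instance and leaves the cache type unchanged.

```purescript
module ConsExample where

import ST

-- The cache that runST starts with: one Unit slot at index Z.
start :: List (Nat /\ Unit)
start = Cons (Z /\ Unit) Nil

-- No Boolean slot exists yet, so the last instance in the chain adds one.
afterFirst :: List (Nat /\ Boolean) /\ List (Nat /\ Unit)
afterFirst = stCons (Succ Z) true start

-- A Boolean slot exists now, so the first instance conses onto it
-- and the cache type is unchanged.
afterSecond :: List (Nat /\ Boolean) /\ List (Nat /\ Unit)
afterSecond = stCons (Succ (Succ Z)) false afterFirst

-- Hypothetical helper for this sketch: the newest Boolean in the cache.
newestFlag :: List (Nat /\ Boolean) /\ List (Nat /\ Unit) -> Boolean
newestFlag ((Cons (_ /\ b) _) /\ _) = b
newestFlag (Nil /\ _) = false
```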

STModify

The definition of STModify is as follows:

class STModify (x :: Type) (y :: Type) where
  stModify :: (x -> x) -> Nat -> x -> y -> x /\ y

We can only legally modify a value in the cache if its type has already been added via a call to stCons. So, instead of tracking an i and an o type, we have a single type y representing the cache with this value already present. If, when solving the class constraint, the compiler finds that the type x is not present in y, the program will not compile. This is how the correctness of the cache is enforced at the type level.

stModify can be used to implement both read and write. In the case of read, we use the identity function as our x -> x and keep only the returned value. In the case of write, we use a function that ignores its argument and returns the new value, stash the modified y in our ST, and return Unit.
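Here is a sketch of those two uses of stModify on a hand-built cache, assuming the listing below is compiled as module ST. The names cache, readFlag, written, and readBack are hypothetical. The third argument is the value stored in the Ref at creation time; in the listing's implementation it is only returned if the index is not found.

```purescript
module ModifyExample where

import ST

-- Hypothetical cache: one Boolean reference (index Succ Z, value true)
-- on top of the Unit slot that runST starts with.
cache :: List (Nat /\ Boolean) /\ List (Nat /\ Unit)
cache = Cons (Succ Z /\ true) Nil /\ Cons (Z /\ Unit) Nil

-- read: pass the identity function and keep only the value.
-- The fallback (false) is ignored because index Succ Z is present.
readFlag :: Boolean
readFlag = let v /\ _ = stModify (\z -> z) (Succ Z) false cache in v

-- write: pass a constant function and keep only the updated cache.
written :: List (Nat /\ Boolean) /\ List (Nat /\ Unit)
written = let _ /\ c = stModify (\_ -> false) (Succ Z) true cache in c

-- Reading the updated cache sees the new value.
readBack :: Boolean
readBack = let v /\ _ = stModify (\z -> z) (Succ Z) true written in v
```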

Show me the code!

Here's the whole thing. Everything before -- test is the implementation, and everything afterwards is a test that throws a lot of stuff at the type (several calls to new, reads and writes in different orders, etc.).

module ST where

-- A tiny prelude, so the module has no dependencies.
data Unit = Unit
data Tuple a b = Tuple a b
infixr 6 Tuple as /\
infixr 6 type Tuple as /\
data Nat = Succ Nat | Z
data List a = Cons a (List a) | Nil

-- ST threads a fresh-index counter and a typed cache, like State.
newtype ST :: forall k. k -> Type -> Type -> Type -> Type
newtype ST r i o a = ST ((Nat /\ i) -> (Nat /\ o /\ a))

-- A reference is its index in the cache plus the value it was created with.
newtype Ref a = Ref (Tuple Nat a)

class Functor f where
  map :: forall a b. (a -> b) -> f a -> f b

instance Functor List where
  map f (Cons a b) = Cons (f a) (map f b)
  map _ Nil = Nil

instance Functor (Tuple a) where
  map f (Tuple a b) = Tuple a (f b)

instance Functor (Function a) where
  map f a x = f (a x)

instance Functor (ST r i o) where
  map f (ST x) = ST ((\y -> map (map (map y))) f x)

-- Indexed classes: each step moves the cache type from x to y to z.
class Apply :: forall k. (k -> k -> Type -> Type) -> Constraint
class Apply f where
  apply :: forall a b x y z. f x y (a -> b) -> f y z a -> f x z b

instance Apply (ST r) where
  apply (ST f0) (ST f1) = ST \nr -> do
    let nat0 /\ r0 /\ fab = f0 nr
    let nat1 /\ r1 /\ a = f1 (nat0 /\ r0)
    nat1 /\ r1 /\ (fab a)

class Applicative :: forall k. (k -> k -> Type -> Type) -> k -> Constraint
class Applicative f i where
  pure :: forall a. a -> f i i a

instance Applicative (ST r) i where
  pure a = ST \(n /\ r) -> n /\ r /\ a

class Bind :: forall k. (k -> k -> Type -> Type) -> Constraint
class Bind m where
  bind :: forall a b x y z. m x y a -> (a -> m y z b) -> m x z b

instance Bind (ST r) where
  bind (ST ma) f = ST \nr -> do
    let nat0 /\ r0 /\ a = ma nr
    let ST f1 = f a
    f1 (nat0 /\ r0)

class Monad :: forall k. (k -> k -> Type -> Type) -> Constraint
class Bind m <= Monad m

class Discard a where
  discard :: forall f (x :: Type) (y :: Type) (z :: Type) b. Bind f => f x y a -> (a -> f y z b) -> f x z b

instance discardUnit :: Discard Unit where
  discard = bind

-- STCons: insert a value, reusing the slot for its type or adding a new one.
class STCons (x :: Type) (y :: Type) (z :: Type) | x y -> z where
  stCons :: Nat -> x -> y -> z

instance STCons a (List (Nat /\ a) /\ x) (List (Nat /\ a) /\ x) where
  stCons n x (l /\ r) = (Cons (n /\ x) l) /\ r
else instance
  STCons b x o =>
  STCons b (List (Nat /\ a) /\ x) (List (Nat /\ a) /\ o) where
  stCons n x (l /\ r) = l /\ stCons n x r
else instance STCons a (List (Nat /\ a)) (List (Nat /\ a)) where
  stCons n x l = Cons (n /\ x) l
else instance STCons a (List (Nat /\ b)) (List (Nat /\ a) /\ List (Nat /\ b)) where
  stCons n x r = Cons (n /\ x) Nil /\ r

-- STModify: apply a function at an index; only resolves if the type has a slot.
class STModify (x :: Type) (y :: Type) where
  stModify :: (x -> x) -> Nat -> x -> y -> x /\ y

instance STModify a (List (Nat /\ a) /\ x) where
  stModify f n x (l /\ y) = let a /\ b = stModify f n x l in a /\ (b /\ y)
else instance
  STModify b x =>
  STModify b (List (Nat /\ a) /\ x) where
  stModify f n x (y /\ l) = let a /\ b = stModify f n x l in a /\ (y /\ b)

instance STModify a (List (Nat /\ a)) where
  stModify f ix v l' = go l'
    where
    go (Cons (n /\ x) l) = comp n ix n x l
    go Nil = v /\ Nil
    comp n Z Z x l = let fx = f x in fx /\ (Cons (n /\ fx) l)
    comp n (Succ j) (Succ k) x l = comp n j k x l
    comp n (Succ _) Z x l = let a /\ b = go l in a /\ (Cons (n /\ x) b)
    comp n Z (Succ _) x l = let a /\ b = go l in a /\ (Cons (n /\ x) b)

-- The public API.
new :: forall a r i o. STCons a i o => a -> ST r i o (Ref a)
new a = ST \(n /\ r) -> do
  let ix = Succ n
  let newR = stCons ix a r
  ix /\ newR /\ (Ref (ix /\ a))

read :: forall r a i. STModify a i => Ref a -> ST r i i a
read (Ref (ix /\ v)) = ST \(n /\ r) ->
  n /\ r /\ (let a /\ _ = stModify (\z -> z) ix v r in a)

write :: forall r a i. STModify a i => a -> Ref a -> ST r i i Unit
write v (Ref (ix /\ o)) = ST \(n /\ r) ->
  n /\ (let _ /\ b = stModify (\_ -> v) ix o r in b) /\ Unit

modify :: forall r a i. STModify a i => (a -> a) -> Ref a -> ST r i i a
modify f (Ref (ix /\ v)) = ST \(n /\ r) -> do
  let a /\ b = stModify f ix v r
  n /\ b /\ a

modify_ :: forall r a i. STModify a i => (a -> a) -> Ref a -> ST r i i Unit
modify_ f r = do
  _ <- modify f r
  pure Unit

type STU :: forall k. k -> Type -> Type -> Type
type STU r o a = ST r (List (Nat /\ Unit)) o a

runST :: forall a o. (forall r. STU r o a) -> a
runST (ST f) = let _ /\ _ /\ c = f (Z /\ (Cons (Z /\ Unit) Nil)) in c

-- test
data Vehicle = Boat | Car | Train
data Composer = Bach | Beethoven | Brahms
data Color = Red | Green | Blue

colorVehicle :: forall r. STU r _ (Color /\ Vehicle)
colorVehicle = do
  _ <- new Green
  myColor <- new Blue
  write Red myColor
  color2 <- read myColor
  myVehicle <- new Car
  write Green myColor
  _ <- new Green
  write color2 myColor
  outColor <- read myColor
  outVehicle <- read myVehicle
  _ <- new Beethoven
  write Boat myVehicle
  pure (outColor /\ outVehicle)

myColorVehicle :: Color /\ Vehicle
myColorVehicle = runST colorVehicle
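The test above never calls modify or modify_. A small extra check, assuming the listing is compiled as module ST; toggle, toggled, and result are hypothetical names, and toggle is defined by hand because the module has no Prelude.

```purescript
module ModifyToggleExample where

import ST

-- Hypothetical helper: Boolean negation without Prelude.
toggle :: Boolean -> Boolean
toggle true = false
toggle false = true

-- modify_ discards the new value; modify returns it.
toggled :: forall r. STU r (List (Nat /\ Boolean) /\ List (Nat /\ Unit)) Boolean
toggled = do
  flag <- new false
  modify_ toggle flag -- now true
  modify_ toggle flag -- now false
  modify toggle flag  -- now true, and returned

result :: Boolean
result = runST toggled
```

Each modify reads the current value from the cache rather than the value captured in the Ref, which is why the three toggles compose.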

Because the module is self-contained and has no way to perform side effects, it can't print the result of the test on its own. So here's a bit of impure code we can use to check that the test code in the gist actually works:

module Main where

import Prelude

import Effect (Effect)
import Effect.Console (log)
import ST (myColorVehicle)
import Unsafe.Coerce (unsafeCoerce)

main :: Effect Unit
main = do
  log $ unsafeCoerce myColorVehicle
  pure unit

When we run this, we see in the console:

Tuple { value0: Red {}, value1: Car {} }

Sure enough, following the control flow of myColorVehicle, this is exactly what we'd expect!
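If you'd rather not rely on unsafeCoerce, a sketch like the following pattern matches on the result instead. It assumes the listing is compiled as module ST and imports it qualified so its Unit and Tuple don't clash with Prelude; the module name Inspect and the helpers showColor, showVehicle, and describe are invented for this example.

```purescript
module Inspect where

import Prelude

import Effect (Effect)
import Effect.Console (log)
import ST as ST

-- Hypothetical helpers: the pure module defines no Show instances.
showColor :: ST.Color -> String
showColor ST.Red = "Red"
showColor ST.Green = "Green"
showColor ST.Blue = "Blue"

showVehicle :: ST.Vehicle -> String
showVehicle ST.Boat = "Boat"
showVehicle ST.Car = "Car"
showVehicle ST.Train = "Train"

-- ST.Tuple is the module's own pair type, not Data.Tuple.
describe :: ST.Tuple ST.Color ST.Vehicle -> String
describe (ST.Tuple c v) = showColor c <> ", " <> showVehicle v

main :: Effect Unit
main = log (describe ST.myColorVehicle) -- prints "Red, Car"
```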

What about MonadTrans

An implementation of MonadTrans would require a new type STT:

newtype STT r i o m a = STT ((Nat /\ i) -> m (Nat /\ o /\ a))

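This type has the same shape as StateT, and an instance of MonadTrans can be defined in a similar manner: lift runs the base action and passes the counter and cache through unchanged, so the input and output cache types are the same. This allows STT to be used as a transformer in an MTL stack.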
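Here is a minimal sketch of that idea, under stated assumptions: it is a standalone Prelude-based module rather than an extension of the pure ST module, it redefines its own Nat, and the names bindSTT, liftSTT, runSTT, and example are hypothetical. The lift operation is written as a plain function instead of a MonadTrans instance, so it only illustrates the shape such an instance would take.

```purescript
module STTSketch where

import Prelude

import Data.Maybe (Maybe(..))
import Data.Tuple.Nested ((/\), type (/\))

-- A stand-in for the article's Nat counter (this sketch uses Prelude).
data Nat = Succ Nat | Z

-- Same shape as StateT: the (counter /\ cache) state is threaded through m.
newtype STT :: forall k. k -> Type -> Type -> (Type -> Type) -> Type -> Type
newtype STT r i o m a = STT ((Nat /\ i) -> m (Nat /\ o /\ a))

-- Indexed bind: the cache type moves from x to y to z, as with ST.
bindSTT :: forall r x y z m a b. Bind m => STT r x y m a -> (a -> STT r y z m b) -> STT r x z m b
bindSTT (STT ma) f = STT \s ->
  ma s >>= \(n /\ c /\ a) -> let STT g = f a in g (n /\ c)

-- lift: run the base action and pass the counter and cache through,
-- so the cache type is unchanged (i to i).
liftSTT :: forall r i m a. Functor m => m a -> STT r i i m a
liftSTT ma = STT \(n /\ c) -> map (\a -> n /\ c /\ a) ma

-- Run from a given counter and cache, keeping only the result.
runSTT :: forall r i o m a. Functor m => STT r i o m a -> (Nat /\ i) -> m a
runSTT (STT f) s = map (\(_ /\ _ /\ a) -> a) (f s)

-- Usage: two lifted Maybe actions chained with bindSTT.
example :: Maybe Int
example = runSTT (liftSTT (Just 3) `bindSTT` \x -> liftSTT (Just (x + 1))) (Z /\ unit)
```

Because lift cannot create references, its type fixes the input and output cache to the same index, which is the same constraint a MonadTrans instance for STT would have to satisfy.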
