DEV Community

Byron Salty


Updating state through multiple epochs in Elixir

Elixir's immutability makes some things more difficult, though it's a trade-off I'm willing to make so that values never change in unpredictable ways.

But what if you want to update a value over and over again, as when training an ML model over a number of epochs?
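Reassigning a variable inside a loop body does not work here. Each call to the anonymous function only rebinds a local variable, and the outer value never changes. Here is a minimal sketch of that pitfall, using a plain number as a stand-in for a model:

```elixir
model = 0

# This looks like it increments `model` three times, but each call binds a
# new, local `model` that is discarded when the anonymous function returns.
# (The compiler warns that this inner variable is unused.)
Enum.each(1..3, fn _ ->
  model = model + 1
end)

IO.inspect(model)
# → 0
```

The new value has to come out of the loop as its return value. That is exactly what a reduce-style accumulator provides.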

After each epoch, we want to update our model and run a validation step to calculate our loss.

At first, I only knew how to do this with Enum.reduce, using the model as the accumulated value.

But I believe most people would find some variation of a for loop more readable. Luckily, Elixir comprehensions (the for construct) accept a reduce: option, which again lets us treat the accumulated value as the state.

For completeness, why not see what the code looks like with recursion as well? I actually think this might be the most readable form if I weren't using anonymous functions. An anonymous function can't call itself by name, so it needs a bit of extra work: it has to be passed to itself as an argument. (Thanks to Lucas Perez for his article on how to make recursive anonymous functions.)
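One way to hide that extra argument from callers is to wrap the self-passing function inside an outer function. The following is a sketch of that idea. The placeholder train_epoch and validate_epoch functions match the post's own, and the inner name loop is purely illustrative.

```elixir
# Hypothetical placeholders standing in for real training and validation.
train_epoch = fn model -> model + 1 end
validate_epoch = fn model -> IO.inspect(model) end

train_model_recur = fn model, epochs ->
  # `loop` receives itself as its last argument, because an anonymous
  # function cannot refer to itself by name.
  loop = fn
    current, 0, _loop ->
      current

    current, remaining, loop when remaining > 0 ->
      new_model = train_epoch.(current)
      validate_epoch.(new_model)
      loop.(new_model, remaining - 1, loop)
  end

  # Start the recursion; callers only pass a model and an epoch count.
  loop.(model, epochs, loop)
end

train_model_recur.(0, 3)
# prints 1, 2 and 3, then returns 3
```

The guard on the second clause means a negative epoch count raises a FunctionClauseError instead of recursing forever.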

If I were putting this into a module, I'd definitely reach for the recursive version first and the for version second.
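Inside a module, named functions can call themselves directly, so the recursive version no longer needs the extra function argument. Here is a minimal sketch, assuming a hypothetical Trainer module with private placeholder training and validation functions:

```elixir
defmodule Trainer do
  # Hypothetical module; train_epoch/1 and validate_epoch/1 are placeholders
  # for real training and validation steps.

  # No epochs left: the model is finished.
  def train(model, 0), do: model

  # Train one epoch, validate, then recurse with one fewer epoch remaining.
  def train(model, epochs) when epochs > 0 do
    new_model = train_epoch(model)
    validate_epoch(new_model)
    train(new_model, epochs - 1)
  end

  defp train_epoch(model), do: model + 1
  defp validate_epoch(model), do: IO.inspect(model)
end

Trainer.train(0, 3)
# prints 1, 2 and 3, then returns 3
```

Because the recursive call is the last expression in train/2, it is a tail call, so the BEAM runs it without growing the stack, however many epochs you train for.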

# Placeholders so that the code executes; real training and validation
# steps would be much more complex.
train_epoch = fn model -> model + 1 end
validate_epoch = fn model -> IO.inspect(model) end

# The stepped range `1..epochs//1` (Elixir 1.12+) is empty when epochs is 0,
# whereas a bare `1..0` counts down and would run two epochs.
train_model_for = fn model, epochs ->
  for _ <- 1..epochs//1, reduce: model do
    model ->
      new_model = train_epoch.(model)
      validate_epoch.(new_model)
      new_model
  end
end

train_model_enum = fn model, epochs ->
  Enum.reduce(1..epochs//1, model, fn _, model ->
    new_model = train_epoch.(model)
    validate_epoch.(new_model)
    new_model
  end)
end

# Anonymous functions cannot call themselves by name, so the function is
# passed in as `func` and recurses through that argument.
train_model_recur = fn
  model, 0, _ -> model
  model, epochs, func when epochs > 0 ->
    new_model = train_epoch.(model)
    validate_epoch.(new_model)
    func.(new_model, epochs - 1, func)
end


train_model_for.(0, 10)

train_model_enum.(0, 10)

train_model_recur.(0, 10, train_model_recur)
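The post notes that each epoch's validation should calculate a loss, but the placeholders above discard it. The following sketch is an assumption, not the author's code. It keeps the losses by making the accumulator a tuple of the model and a list of losses. The helper train_with_history and the loss formula 1 / model are hypothetical placeholders, and the stepped range 1..epochs//1 needs Elixir 1.12 or later.

```elixir
# Hypothetical placeholders: the "model" is just a number and the
# "loss" is 1 / model, so it shrinks as training proceeds.
train_epoch = fn model -> model + 1 end
validate_epoch = fn model -> 1 / model end

# Hypothetical helper: returns {final_model, losses}, one loss per epoch.
# The accumulator holds the current model and the losses so far, newest
# first, because prepending to a list is cheap.
train_with_history = fn model, epochs ->
  {final_model, losses} =
    Enum.reduce(1..epochs//1, {model, []}, fn _epoch, {current, history} ->
      new_model = train_epoch.(current)
      loss = validate_epoch.(new_model)
      {new_model, [loss | history]}
    end)

  # Reverse once at the end so the losses are in epoch order.
  {final_model, Enum.reverse(losses)}
end

{model, losses} = train_with_history.(0, 4)
IO.inspect(model)
# → 4
IO.inspect(losses)
# → [1.0, 0.5, 0.3333333333333333, 0.25]
```

The same tuple accumulator works unchanged with the for version's reduce: option.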