DEV Community

Neural Nets in C# vs F#

Matt Eland
Matt is committed to helping people achieve greater things. After over three decades of coding, Matt put away his mechanical keyboard and made teaching his primary job as he looks to help others grow.

This article compares C# and F# implementations of a simple neural network library I wrote for a side project.

What is a Neural Net?

A neural net is essentially a calculator that takes one or more numerical inputs and computes one or more numerical outputs.

Neural networks can fulfill a wide range of functions, but where they excel is finding an optimal output given various inputs.

The inputs are held in an input layer, which connects to another layer: either the output layer or a hidden layer. Each layer, including the input layer, consists of one or more neurons. Each neuron is connected to every neuron in the next layer, and each connection is given a positive or negative weight indicating how important it is.

Individual neurons compute their values by applying a summarizing function to all inputs from the prior layer, with each input multiplied by the weight of its connection. In this library, that function is an average of the weighted inputs. The result then feeds into the next layer as an input, and so on until values arrive in the output layer.
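
As a worked example of that calculation, here is a minimal C# sketch of a single neuron's value. It assumes the summarizing function is the average of the weighted inputs, which is what the C# Neuron's Evaluate method does with its accumulated sum. `NeuronMath` and `WeightedAverage` are hypothetical names for illustration, not part of the library.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper illustrating one neuron's calculation:
// multiply each input by its connection weight, then average the results.
public static class NeuronMath
{
    public static decimal WeightedAverage(IReadOnlyList<decimal> inputs, IReadOnlyList<decimal> weights)
    {
        if (inputs.Count != weights.Count)
            throw new ArgumentException("Each input needs exactly one weight.", nameof(weights));
        if (inputs.Count == 0)
            throw new ArgumentException("At least one input is required.", nameof(inputs));

        decimal sum = 0m;
        for (int i = 0; i < inputs.Count; i++)
        {
            sum += inputs[i] * weights[i]; // the weight scales how much this input matters
        }

        return sum / inputs.Count; // average, mirroring the library's _sum / _numInputs
    }
}
```

For inputs 1 and 2 with weights 0.5 and -1, the weighted values are 0.5 and -2, so the neuron's value is (0.5 + -2) / 2 = -0.75.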

The output layer is what the caller of the neural net will use to evaluate the result of the network. This could represent anything from whether or not to buy a piece of stock, to the attractiveness of a move in a game, to what the dominant color in an image is or how happy a face appears.

Their interconnected structure gives neural nets the flexibility to represent innovative solutions to problems, but it also makes them hard to interpret just by reading them.

Neural networks are typically trained either through backpropagation or by some other mechanism, such as a genetic algorithm, but both are beyond the scope of this article.

.NET Implementation

Neuron

A Neuron summarizes the values it receives from its inputs and stores the result.

C# Neuron

The C# implementation has a lot of boilerplate code for maintaining fields and properties and for connecting to other nodes and layers. The core evaluation logic lives in the Evaluate method and is fairly minimal, but it relies on the connections established by the supporting methods.

    /// <summary>
    /// Represents a Neuron in a layer of a Neural Network.
    /// </summary>
    public class Neuron
    {
        /// <summary>
        /// Gets or sets the value of the Neuron.
        /// </summary>
        public decimal Value { get; set; }

        /// <summary>
        /// Creates a new instance of a <see cref="Neuron"/>.
        /// </summary>
        public Neuron()
        {
            OutgoingConnections = new List<NeuronConnection>();
        }

        /// <summary>
        /// Connects this Neuron to the <paramref name="nextNeuron" />
        /// </summary>
        /// <param name="nextNeuron">The Neuron to connect to.</param>
        internal void ConnectTo([NotNull] Neuron nextNeuron)
        {
            if (nextNeuron == null) throw new ArgumentNullException(nameof(nextNeuron));

            OutgoingConnections.Add(new NeuronConnection(nextNeuron));
        }

        private decimal _sum;
        private int _numInputs;

        /// <summary>
        /// Evaluates the values from the incoming connections, averages them by the count of connections,
        /// and calculates the Neuron's Value, which is then passed on to any outgoing connections.
        /// </summary>
        internal void Evaluate()
        {
            if (_numInputs > 0)
            {
                Value = _sum / _numInputs;
                _sum = 0;
            }

            OutgoingConnections.Each(c => c.Fire(Value));
        }

        /// <summary>
        /// The list of outgoing Neuron connections
        /// </summary>
        [NotNull, ItemNotNull]
        public IList<NeuronConnection> OutgoingConnections { get; }

        /// <summary>
        /// Receives a value from a connection.
        /// </summary>
        /// <param name="value">The value to receive</param>
        internal void Receive(decimal value) => _sum += value;

        /// <summary>
        /// Registers an incoming connection from another neuron.
        /// </summary>
        /// <param name="neuronConnection">The connection</param>
        internal void RegisterIncomingConnection([NotNull] NeuronConnection neuronConnection)
        {
            if (neuronConnection == null) throw new ArgumentNullException(nameof(neuronConnection));

            _numInputs++;
        }
    }
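
The Neuron above depends on a NeuronConnection class (and an Each extension method) that the article doesn't show. Below is a self-contained sketch of how the push-based pieces might fit together. It assumes that NeuronConnection's constructor registers itself with the target neuron, and that Fire multiplies the value by the connection's weight before passing it to the target's Receive method. SketchNeuron and SketchConnection are hypothetical stand-ins, and a plain foreach replaces Each.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for the NeuronConnection the Neuron above depends on.
// Assumption: firing scales the value by the weight and pushes it to the target.
public class SketchConnection
{
    public SketchConnection(SketchNeuron target)
    {
        Target = target ?? throw new ArgumentNullException(nameof(target));
        Target.RegisterIncomingConnection(); // tells the target how many inputs to average over
    }

    public SketchNeuron Target { get; }
    public decimal Weight { get; set; } = 1m;

    public void Fire(decimal value) => Target.Receive(value * Weight);
}

// Pared-down copy of the C# Neuron's push-based logic.
public class SketchNeuron
{
    private decimal _sum;
    private int _numInputs;

    public decimal Value { get; set; }
    public List<SketchConnection> OutgoingConnections { get; } = new List<SketchConnection>();

    public SketchConnection ConnectTo(SketchNeuron next)
    {
        var connection = new SketchConnection(next);
        OutgoingConnections.Add(connection);
        return connection;
    }

    public void RegisterIncomingConnection() => _numInputs++;

    public void Receive(decimal value) => _sum += value;

    public void Evaluate()
    {
        if (_numInputs > 0)
        {
            Value = _sum / _numInputs; // average of the weighted values received
            _sum = 0;                  // reset for the next evaluation
        }

        foreach (var connection in OutgoingConnections)
        {
            connection.Fire(Value); // push this neuron's value downstream
        }
    }
}
```

Order matters in this push model: upstream neurons must evaluate (and fire) before downstream ones, which is why NeuralNetLayer.Evaluate evaluates all of its own neurons before delegating to the next layer.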

F# Neuron

By contrast, the F# implementation is minimal: some brief property storage plus simple Connect and Evaluate methods.

/// Represents a node in a Neural Network
and Neuron ([<Optional>] ?initialValue: decimal) =
  let mutable value = defaultArg initialValue 0M;
  let mutable inputs: NeuronConnection seq = Seq.empty;

  /// Exposes the current calculated amount of the Neuron
  member this.Value
    with get () = value
    and set (newValue) = value <- newValue

  /// Incoming connections from other Neurons (if any)
  member this.Inputs: NeuronConnection seq = inputs;

  /// Adds an incoming connection from another Neuron
  member this.AddIncomingConnection c = inputs <- Seq.append this.Inputs [c];

  /// Adds all connections together, stores the result in Value, and returns the value
  member this.Evaluate(): decimal =
    if not (Seq.isEmpty this.Inputs) then do
      let numInputs = Seq.length this.Inputs |> decimal
      value <- Seq.sumBy (fun (c:NeuronConnection) -> c.Evaluate()) this.Inputs / numInputs;
    value;

  /// Connects this neuron to another and returns the connection
  member this.Connect(target: Neuron) =
    let connection = new NeuronConnection(this);
    target.AddIncomingConnection(connection);
    connection;

One of the things I like about this is that there isn't a lot of meaningless syntax, spacing, or irrelevant logic. The downside of this is that the functional syntax can be harder to read while scanning code.
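
The two Neuron classes also differ in direction. The C# Neuron pushes its value downstream through each connection's Fire method, while the F# Neuron pulls values from its Inputs when Evaluate is called. Here is a C# sketch of the pull approach. It assumes the F# NeuronConnection (not shown in the article) evaluates to its source neuron's Value multiplied by its weight. PullNeuron and PullConnection are hypothetical names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical connection that remembers its *source* neuron, as the F#
// `new NeuronConnection(this)` call suggests. Assumption: evaluating it
// returns the source's current value scaled by the weight.
public class PullConnection
{
    public PullConnection(PullNeuron source) => Source = source;

    public PullNeuron Source { get; }
    public decimal Weight { get; set; } = 1m;

    public decimal Evaluate() => Source.Value * Weight;
}

// C# rendering of the F# Neuron's pull-based Evaluate.
public class PullNeuron
{
    private readonly List<PullConnection> _inputs = new List<PullConnection>();

    public decimal Value { get; set; }
    public IEnumerable<PullConnection> Inputs => _inputs;

    // Mirrors the F# Connect: the connection is stored in the *target's* inputs.
    public PullConnection Connect(PullNeuron target)
    {
        var connection = new PullConnection(this);
        target._inputs.Add(connection);
        return connection;
    }

    // Averages the weighted inputs; neurons without inputs keep their assigned Value.
    public decimal Evaluate()
    {
        if (_inputs.Count > 0)
        {
            Value = _inputs.Sum(c => c.Evaluate()) / _inputs.Count;
        }

        return Value;
    }
}
```

Under that assumption, each upstream neuron needs an up-to-date Value before a downstream neuron evaluates, which fits the F# NeuralNet.Evaluate walking this.Layers from input to output.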

Layers

Layers are just collections of neurons in the same tier. The layer code is used for managing inter-connections between nodes in different layers.

C# Layer

The NeuralNetLayer is honestly fairly boring. It acts as glue between the neurons of different layers, yet the implementation takes nearly 100 lines of code to do that.

    /// <summary>
    /// Represents a layer in a neural network. This could be an input, output, or hidden layer.
    /// </summary>
    public class NeuralNetLayer : IEnumerable<Neuron>
    {
        private readonly IList<Neuron> _neurons;

        [CanBeNull]
        private NeuralNetLayer _nextLayer;

        /// <summary>
        /// Creates a new neural network layer with the given count of neurons.
        /// </summary>
        /// <param name="numNeurons">The number of neurons in the layer</param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// Thrown if <paramref name="numNeurons" /> was less than 1
        /// </exception>
        public NeuralNetLayer(int numNeurons)
        {
            if (numNeurons <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(numNeurons), "Each layer must have at least one Neuron");
            }

            _neurons = new List<Neuron>(numNeurons);

            numNeurons.Each(n => _neurons.Add(new Neuron()));
        }

        /// <summary>
        /// Gets the Neurons belonging to this layer.
        /// </summary>
        public IEnumerable<Neuron> Neurons => _neurons;

        /// <summary>
        /// Sets the values of the layer to the given values set. One value will be used for each neuron in the layer.
        /// </summary>
        /// <param name="values">The values to use.</param>
        /// <exception cref="ArgumentException">Thrown if <paramref name="values"/> did not have an expected values count.</exception>
        internal void SetValues([NotNull] IEnumerable<decimal> values)
        {
            if (values == null) throw new ArgumentNullException(nameof(values));
            if (values.Count() != _neurons.Count) throw new ArgumentException("The number of inputs must match the number of neurons in a layer", nameof(values));

            int i = 0;
            values.Each(v => _neurons[i++].Value = v);
        }

        /// <inheritdoc />
        public IEnumerator<Neuron> GetEnumerator() => _neurons.GetEnumerator();

        /// <inheritdoc />
        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

        /// <summary>
        /// Evaluates each node in the layer, as well as the next layer if one is present.
        /// </summary>
        /// <returns>The outputs from the Output layer</returns>
        internal IEnumerable<decimal> Evaluate()
        {
            // Calculate all neurons.
            _neurons.Each(n => n.Evaluate());

            // If this is the last layer, return its values, otherwise delegate to the next layer and return its results
            return _nextLayer == null 
                ? _neurons.Select(n => n.Value) 
                : _nextLayer.Evaluate();
        }

        /// <summary>
        /// Connects this layer to the <paramref name="nextLayer"/>, forming connections between each node in this
        /// layer and each node in the next layer.
        /// </summary>
        /// <param name="nextLayer">The layer to connect to</param>
        internal void ConnectTo([NotNull] NeuralNetLayer nextLayer)
        {
            _nextLayer = nextLayer ?? throw new ArgumentNullException(nameof(nextLayer));

            _neurons.Each(source => nextLayer.Each(source.ConnectTo));
        }

        /// <summary>
        /// Sets the weights in the layer to the values provided
        /// </summary>
        /// <param name="weights">The weights to use to set in the connections</param>
        [UsedImplicitly]
        public void SetWeights(IList<decimal> weights)
        {
            int weightIndex = 0;
            _neurons.Each(neuron => neuron.OutgoingConnections.Each(c => c.Weight = weights[weightIndex++]));
        }
    }
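
Both C# classes call an Each method on sequences (`_neurons.Each(...)`) and on integers (`numNeurons.Each(...)`). Neither method is part of the .NET base class library, and the article doesn't show them. Here is a minimal sketch of what such extension methods could look like. The library's real versions may differ, and passing a zero-based index from the integer overload is an assumption.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical versions of the Each helpers used by the C# code above.
public static class EachExtensions
{
    // Runs an action for every item, e.g. _neurons.Each(n => n.Evaluate()).
    public static void Each<T>(this IEnumerable<T> items, Action<T> action)
    {
        if (items == null) throw new ArgumentNullException(nameof(items));
        if (action == null) throw new ArgumentNullException(nameof(action));

        foreach (var item in items)
        {
            action(item);
        }
    }

    // Runs an action `count` times, passing a zero-based index,
    // e.g. numNeurons.Each(n => _neurons.Add(new Neuron())).
    public static void Each(this int count, Action<int> action)
    {
        if (action == null) throw new ArgumentNullException(nameof(action));

        for (int i = 0; i < count; i++)
        {
            action(i);
        }
    }
}
```

Because NeuralNetLayer implements `IEnumerable<Neuron>`, the sequence overload also covers calls like `nextLayer.Each(source.ConnectTo)`, where the method group converts to an `Action<Neuron>`.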

F# Layer

The F# version is shorter and uses Seq (sequence) functions to delegate responsibilities to individual Neurons.

/// A layer is just a series of Neurons in parallel that will link to every Neuron in the next layer (if any is present)
and NeuralNetLayer(numNeurons: int) =
  do if numNeurons <= 0 then invalidArg "numNeurons" "There must be at least one neuron in each layer";

  let neurons: Neuron seq = seq [ for i in 1 .. numNeurons -> new Neuron 0M]
  /// The neurons in this layer, created when the layer is constructed
  member this.Neurons: Neuron seq = neurons;

  /// Sets the value of every neuron in the sequence to the corresponding ordered value provided
  member this.SetValues (values: decimal seq) = 
    let assignValue (n:Neuron) (v:decimal) = n.Value <- v;
    Seq.iter2 assignValue this.Neurons values

  /// Evaluates the layer and returns the value of each node
  member this.Evaluate(): decimal seq =
    for n in this.Neurons do n.Evaluate() |> ignore;
    Seq.map (fun (n:Neuron) -> n.Value) this.Neurons;

  /// Connects every node in this layer to the target layer
  member this.Connect(layer: NeuralNetLayer): unit = 
    for nSource in neurons do
      for nTarget in layer.Neurons do
        nSource.Connect(nTarget) |> ignore;
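
There is one behavioral difference in SetValues. F#'s Seq.iter2 stops when the shorter of its two sequences runs out, so the F# layer silently ignores extra values (or leaves extra neurons untouched). The C# SetValues instead throws an ArgumentException on a count mismatch. The sketch below contrasts the two behaviors in C# using LINQ's Zip, which also stops at the shorter sequence. The LayerValues class and the plain decimal array standing in for a layer's neurons are illustrative assumptions.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helpers contrasting two ways to copy input values into a layer.
// A plain decimal array stands in for the layer's neuron values.
public static class LayerValues
{
    // Lenient: like F#'s Seq.iter2 (and LINQ's Zip), pairing stops at the end of
    // the shorter sequence, so extra values or extra neurons are silently skipped.
    public static decimal[] AssignLenient(int neuronCount, IEnumerable<decimal> values)
    {
        var neurons = new decimal[neuronCount];
        var pairs = Enumerable.Range(0, neuronCount).Zip(values, (i, v) => (i, v));
        foreach (var (index, value) in pairs)
        {
            neurons[index] = value;
        }
        return neurons;
    }

    // Strict: like the C# NeuralNetLayer.SetValues, reject a count mismatch up front.
    public static decimal[] AssignStrict(int neuronCount, IEnumerable<decimal> values)
    {
        var list = values.ToList();
        if (list.Count != neuronCount)
        {
            throw new ArgumentException("The number of inputs must match the number of neurons in a layer", nameof(values));
        }
        return list.ToArray();
    }
}
```

With three neurons and two values, the lenient version leaves the third neuron at 0, while the strict version throws.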

Neural Net

The Neural Net ties everything together into one wrapper. It arranges layers, exposes the inputs and outputs, and offers a way for callers to configure the network into a pre-determined arrangement.

C# Neural Net

True to form, the C# implementation does some basic iteration and enumeration, but devotes a pronounced amount of space purely to syntax.

    /// <summary>
    /// Represent a neural network consisting of an input layer, an output layer, and 0 to many hidden layers.
    /// Neural networks can compute values and return a set of output values, allowing for computation to occur
    /// between layers.
    /// </summary>
    public class NeuralNet
    {
        private readonly IList<NeuralNetLayer> _hiddenLayers = new List<NeuralNetLayer>();

        /// <summary>
        /// Creates a new instance of a <see cref="NeuralNet"/>
        /// </summary>
        /// <param name="numInputs">The number of nodes in the input layer</param>
        /// <param name="numOutputs">The number of nodes in the output layer</param>
        public NeuralNet(int numInputs, int numOutputs)
        {
            if (numInputs <= 0) throw new ArgumentOutOfRangeException(nameof(numInputs), "You must have at least one input node");
            if (numOutputs <= 0) throw new ArgumentOutOfRangeException(nameof(numOutputs), "You must have at least one output node");

            Inputs = new NeuralNetLayer(numInputs);
            Outputs = new NeuralNetLayer(numOutputs);
        }

        /// <summary>
        /// Adds a hidden layer with the given number of neurons to the neural net.
        /// </summary>
        /// <param name="numNeurons">The number of neurons in the layer</param>
        public void AddHiddenLayer(int numNeurons)
        {
            if (numNeurons <= 0) throw new ArgumentOutOfRangeException(nameof(numNeurons), "You cannot add a hidden layer without any nodes");
            if (IsConnected) throw new InvalidOperationException("Cannot add a new layer after the network has been connected.");

            var layer = new NeuralNetLayer(numNeurons);

            _hiddenLayers.Add(layer);
        }

        /// <summary>
        /// Evaluates the result of the neural network given the specified set of <paramref name="inputs"/>.
        /// </summary>
        /// <param name="inputs">The inputs to evaluate.</param>
        /// <returns>The values outputted from the output layer</returns>
        public IEnumerable<decimal> Evaluate(IEnumerable<decimal> inputs)
        {
            // Don't force people to explicitly connect
            EnsureConnected();

            // Pipe the inputs into the network and evaluate the results
            Inputs.SetValues(inputs);

            return Inputs.Evaluate();
        }

        /// <summary>
        /// Declares that the network is now complete and that connections should be created.
        /// </summary>
        public void Connect()
        {
            if (IsConnected) throw new InvalidOperationException("The Network has already been connected");

            if (_hiddenLayers.Any())
            {
                // Connect input to the first hidden layer
                Inputs.ConnectTo(_hiddenLayers.First());

                // Connect hidden layers to each other
                if (_hiddenLayers.Count > 1)
                {
                    for (int i = 0; i < _hiddenLayers.Count - 1; i++)
                    {
                        _hiddenLayers[i].ConnectTo(_hiddenLayers[i + 1]);
                    }
                }

                // Connect the last hidden layer to the output layer
                _hiddenLayers.Last().ConnectTo(Outputs);
            }
            else
            {
                // No hidden layers, connect the input layer to the output layer
                Inputs.ConnectTo(Outputs);
            }

            IsConnected = true;
        }

        /// <summary>
        /// Determines whether or not the nodes in the network have been connected.
        /// </summary>
        public bool IsConnected { get; private set; }

        /// <summary>
        /// The input layer
        /// </summary>
        public NeuralNetLayer Inputs { get; }

        /// <summary>
        /// The output layer
        /// </summary>
        public NeuralNetLayer Outputs { get; }

        /// <summary>
        /// Gets all layers in the network, in order from first to last, including the Input layer,
        /// output layer, and any hidden layers.
        /// </summary>
        public IEnumerable<NeuralNetLayer> Layers
        {
            get
            {
                yield return Inputs;

                foreach (var layer in _hiddenLayers)
                {
                    yield return layer;
                }

                yield return Outputs;
            }
        }

        /// <summary>
        /// Sets the weights of all connections in the network. This is a convenience method for loading
        /// weight values from JSON and restoring them into the network.
        /// This will connect the network if it is not currently connected.
        /// </summary>
        /// <param name="weights">The weight values from -1 to 1 for every connector in the network.</param>
        [UsedImplicitly]
        public void SetWeights(IList<decimal> weights)
        {
            // Setting weights makes no sense unless the network is connected, so ensure we're connected
            EnsureConnected();

            ConnectorCount = 0;

            int weightIndex = 0;

            foreach (var layer in Layers)
            {
                foreach (var neuron in layer.Neurons)
                {
                    foreach (var connection in neuron.OutgoingConnections)
                    {
                        // Stop early if we've run out of weights; return exits all three loops at once
                        if (weightIndex >= weights.Count)
                        {
                            return;
                        }

                        connection.Weight = weights[weightIndex++];
                        ConnectorCount++;
                    }
                }
            }
        }

        /// <summary>
        /// Connects the neural net if it has not yet been connected
        /// </summary>
        private void EnsureConnected()
        {
            if (IsConnected) return;

            Connect();
        }

        /// <summary>
        /// Gets the total connector count in the neural net.
        /// </summary>
        public int ConnectorCount { get; private set; }
    }
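
The Connect method above handles several cases separately (no hidden layers, one hidden layer, several hidden layers), but they all reduce to connecting each layer to the layer that follows it. Here is a sketch of that idea, which pairs a list with itself offset by one using LINQ's Zip and Skip. LayerWiring and AdjacentPairs are hypothetical names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class LayerWiring
{
    // Pairs every layer with the layer that follows it. For [input, hidden..., output]
    // this yields exactly the (current, next) pairs that need connecting,
    // whether or not there are any hidden layers.
    public static IEnumerable<(T Current, T Next)> AdjacentPairs<T>(IReadOnlyList<T> layers)
    {
        if (layers == null) throw new ArgumentNullException(nameof(layers));

        return layers.Zip(layers.Skip(1), (current, next) => (Current: current, Next: next));
    }
}
```

The existing Layers property already yields the input, hidden, and output layers in order. Under these assumptions, a loop such as `foreach (var (current, next) in LayerWiring.AdjacentPairs(Layers.ToList())) current.ConnectTo(next);` could replace the branching in Connect.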

F# Neural Net

The F# NeuralNet is the largest F# class, but its logic is still fairly concise, with small, focused methods.

/// A high-level encapsulation of a neural net
and NeuralNet(numInputs: int, numOutputs: int) =
  do 
    if numInputs <= 0 then invalidArg "numInputs" "There must be at least one neuron in the input layer";
    if numOutputs <= 0 then invalidArg "numOutputs" "There must be at least one neuron in the output layer";

  let inputLayer: NeuralNetLayer = new NeuralNetLayer(numInputs);
  let outputLayer: NeuralNetLayer = new NeuralNetLayer(numOutputs);
  let mutable hiddenLayers: NeuralNetLayer seq = Seq.empty;
  let mutable isConnected: bool = false;

  let connectLayers (n1:NeuralNetLayer) (n2:NeuralNetLayer) = n1.Connect(n2);

  let layersMinusInput: NeuralNetLayer seq =
    seq {
      for layer in hiddenLayers do yield layer;
      yield outputLayer;
    }

  let layersMinusOutput: NeuralNetLayer seq =
    seq {
      yield inputLayer;
      for layer in hiddenLayers do yield layer;
    }

  /// Yields all connections to nodes inside of the network
  let connections = Seq.collect (fun (l:NeuralNetLayer) -> l.Neurons) layersMinusInput 
                 |> Seq.collect (fun (n:Neuron) -> n.Inputs); 

  /// Gets the layers of the neural network, in sequential order
  member this.Layers: NeuralNetLayer seq =
    seq {
      yield inputLayer;
      for layer in hiddenLayers do
        yield layer;
      yield outputLayer;
    }

  /// Represents the input layer for the network which take in values from another system
  member this.InputLayer = inputLayer;

  /// Represents the last layer in the network which has the values that will be taken out of the network
  member this.OutputLayer = outputLayer;    

  /// Connects the various layers of the neural network
  member this.Connect() =
    if isConnected then invalidOp "The Neural Network has already been connected";

    Seq.iter2 (fun l lNext -> connectLayers l lNext) layersMinusOutput layersMinusInput 
    isConnected <- true;

  /// Determines whether or not the network has been connected. After the network is connected, it can no longer be added to
  member this.IsConnected = isConnected;

  /// Adds a hidden layer to the middle of the neural net
  member this.AddHiddenLayer(layer: NeuralNetLayer) = 
    if isConnected then invalidOp "Hidden layers cannot be added after the network has been connected.";
    hiddenLayers <- Seq.append hiddenLayers [layer];

  /// Sets the weights on all connections in the neural network
  member this.SetWeights(weights: decimal seq) = 
    if isConnected = false then do this.Connect();
    Seq.iter2 (fun (w:decimal) (c:NeuronConnection) -> c.Weight <- w) weights connections;      

  /// Evaluates the entire neural network and yields the result of the output layer
  member this.Evaluate(): decimal seq = 
    if not isConnected then do this.Connect();

    // Iterate through the layers and run calculations
    let mutable result: decimal seq = Seq.empty;
    for layer in this.Layers do
      result <- layer.Evaluate();
    result;
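
Comparing the two SetWeights methods, they appear to assign weights in different orders. The C# version walks each layer's neurons and their OutgoingConnections (source-major order). The F# version walks the connections sequence, which collects each non-input neuron's Inputs (target-major order). Whenever two adjacent layers both have more than one neuron, the same list of weights would therefore land on different connections. The sketch below uses hypothetical names and enumerates both orders for a fully connected network described only by its layer sizes. In either order, the network has one connection per pair of neurons in adjacent layers, so a 2-3-1 network has 2 * 3 + 3 * 1 = 9 connections.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helpers that list connections as (Layer, Source, Target) index triples,
// where Layer is the index of the layer the connection starts from.
public static class WeightOrder
{
    // Source-major: each source neuron's outgoing connections in turn,
    // matching how the C# SetWeights walks OutgoingConnections.
    public static List<(int Layer, int Source, int Target)> SourceMajor(IReadOnlyList<int> layerSizes)
    {
        if (layerSizes == null) throw new ArgumentNullException(nameof(layerSizes));

        var order = new List<(int Layer, int Source, int Target)>();
        for (int layer = 0; layer < layerSizes.Count - 1; layer++)
            for (int source = 0; source < layerSizes[layer]; source++)
                for (int target = 0; target < layerSizes[layer + 1]; target++)
                    order.Add((layer, source, target));
        return order;
    }

    // Target-major: each target neuron's incoming connections in turn,
    // matching how the F# connections sequence walks each neuron's Inputs.
    public static List<(int Layer, int Source, int Target)> TargetMajor(IReadOnlyList<int> layerSizes)
    {
        if (layerSizes == null) throw new ArgumentNullException(nameof(layerSizes));

        var order = new List<(int Layer, int Source, int Target)>();
        for (int layer = 0; layer < layerSizes.Count - 1; layer++)
            for (int target = 0; target < layerSizes[layer + 1]; target++)
                for (int source = 0; source < layerSizes[layer]; source++)
                    order.Add((layer, source, target));
        return order;
    }
}
```

For a 2-3-1 network, the second weight goes to input 0 to hidden 1 in source-major order, but to input 1 to hidden 0 in target-major order. The last three weights (hidden to the single output neuron) line up in both.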

Conclusion

Closing Thoughts

While the F# syntax is more concise, this example is close to ideal for a functional language, since it is largely self-contained calculation logic. It is also a good example of a component that C# code in other projects could consume.

If you were looking to add F# to a project, I'd recommend starting with a small isolated slice of your application that other areas depend on for calculations or other sorts of transformation logic.

I personally feel that functional programming, or at least core concepts from functional languages, can significantly improve software quality, so this idea is worth exploring.

Where can I find this code?

All code in these examples is hosted on GitHub.

IntegerMan / MattEland.AI: .NET artificial intelligence libraries not tied to a specific project.

If you're curious about MattEland.AI, it is available as a NuGet package at https://www.nuget.org/packages/MattEland.AI.Neural/
