DEV Community

Longtime Coder

Minimal Neural Network in Java

Neural Network in Java

Here is a minimal implementation of a neural network in Java. It is meant to be copy-pasted into your project, should you ever need a neural network, and to be easy to follow while still performing reasonably well. Check it out if you're interested. Helpful comments are humbly accepted: I am not well versed in Java and would be glad to learn how to make the code more idiomatic.

https://github.com/dlidstrom/NeuralNetworkInAllLangs/tree/main/Java

Example usage

import java.util.Locale;

// Assumes the repository's DataItem, Logical, Trainer and Network classes are available,
// and that `rand` is the random number source Trainer.create expects (declared elsewhere).
DataItem[] trainingData = {
    new DataItem(new double[]{0, 0}, new double[]{Logical.xor(0, 0), Logical.xnor(0, 0), Logical.or(0, 0), Logical.and(0, 0), Logical.nor(0, 0), Logical.nand(0, 0)}),
    new DataItem(new double[]{0, 1}, new double[]{Logical.xor(0, 1), Logical.xnor(0, 1), Logical.or(0, 1), Logical.and(0, 1), Logical.nor(0, 1), Logical.nand(0, 1)}),
    new DataItem(new double[]{1, 0}, new double[]{Logical.xor(1, 0), Logical.xnor(1, 0), Logical.or(1, 0), Logical.and(1, 0), Logical.nor(1, 0), Logical.nand(1, 0)}),
    new DataItem(new double[]{1, 1}, new double[]{Logical.xor(1, 1), Logical.xnor(1, 1), Logical.or(1, 1), Logical.and(1, 1), Logical.nor(1, 1), Logical.nand(1, 1)})
};

// 2 input neurons, 2 hidden neurons, 6 output neurons.
Trainer trainer = Trainer.create(2, 2, 6, rand);
double lr = 1.0; // learning rate
int iterations = 4000;
for (int e = 0; e < iterations; e++) {
    // Cycle through the four samples in a fixed round-robin order.
    var sample = trainingData[e % trainingData.length];
    trainer.train(sample.input(), sample.output(), lr);
}

Network network = trainer.network();
System.out.println("Result after " + iterations + " iterations");
System.out.println("        XOR   XNOR    OR   AND   NOR   NAND");
for (var sample : trainingData) {
    double[] pred = network.predict(sample.input());
    // Locale.ROOT keeps '.' as the decimal separator regardless of the system locale.
    System.out.printf(
        Locale.ROOT,
        "%d,%d = %.3f  %.3f %.3f %.3f %.3f  %.3f%n",
        (int) sample.input()[0], (int) sample.input()[1],
        pred[0], pred[1], pred[2], pred[3], pred[4], pred[5]);
}
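The four `new DataItem(...)` lines in the usage example repeat the same six `Logical` calls with different arguments. One possibly more idiomatic alternative is to generate the truth table in a loop from a list of lambdas. This is only a sketch: `Sample` and `TruthTable` are hypothetical stand-ins for the repository's `DataItem` and `Logical` classes, and the record assumes Java 16 or later.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntBinaryOperator;

// Hypothetical stand-in for the repository's DataItem (records need Java 16+).
record Sample(double[] input, double[] output) {}

// Hypothetical stand-in for the repository's Logical helpers.
class TruthTable {
    // Column order matches the example output: XOR, XNOR, OR, AND, NOR, NAND.
    static final List<IntBinaryOperator> FUNCTIONS = List.of(
        (a, b) -> a ^ b,
        (a, b) -> 1 - (a ^ b),
        (a, b) -> a | b,
        (a, b) -> a & b,
        (a, b) -> 1 - (a | b),
        (a, b) -> 1 - (a & b));

    // Builds the samples in the order (0,0), (0,1), (1,0), (1,1).
    static List<Sample> build() {
        List<Sample> samples = new ArrayList<>();
        for (int a = 0; a <= 1; a++) {
            for (int b = 0; b <= 1; b++) {
                double[] out = new double[FUNCTIONS.size()];
                for (int f = 0; f < out.length; f++) {
                    out[f] = FUNCTIONS.get(f).applyAsInt(a, b);
                }
                samples.add(new Sample(new double[]{a, b}, out));
            }
        }
        return samples;
    }
}
```

With this layout, adding a seventh logical function means adding one lambda rather than editing four lines.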

The usage example trains a single network to predict six logical functions at once: XOR, XNOR, OR, AND, NOR, and NAND. The network has two input neurons, two hidden neurons, and six output neurons, for a total of 24 trainable parameters (2 × 2 + 2 × 6 = 16 weights plus 2 + 6 = 8 biases), which are trained together until the network predicts all six functions correctly.
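The repository's `Network` and `Trainer` classes are not reproduced in this post. To make the count of 24 concrete, here is a minimal sketch of what a 2-2-6 network and a single training step could look like. It assumes sigmoid activations, one bias per hidden and output neuron, squared error, and plain stochastic gradient descent with backpropagation; the class `TinyNet` and its methods are hypothetical and are not the repository's API.

```java
import java.util.Random;

// Hypothetical 2-2-6 network; NOT the repository's Network/Trainer classes.
// Assumes sigmoid activations, per-neuron biases, squared error and plain SGD.
class TinyNet {
    static final int IN = 2, HID = 2, OUT = 6;

    final double[][] wHid = new double[HID][IN];  // input -> hidden weights (4)
    final double[] bHid = new double[HID];        // hidden biases (2)
    final double[][] wOut = new double[OUT][HID]; // hidden -> output weights (12)
    final double[] bOut = new double[OUT];        // output biases (6)

    TinyNet(Random rand) {
        // Small random weights in [-0.5, 0.5); biases start at zero.
        for (double[] row : wHid) for (int i = 0; i < IN; i++) row[i] = rand.nextDouble() - 0.5;
        for (double[] row : wOut) for (int j = 0; j < HID; j++) row[j] = rand.nextDouble() - 0.5;
    }

    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    // 4 + 2 + 12 + 6 = 24 trainable parameters.
    int parameterCount() { return HID * IN + HID + OUT * HID + OUT; }

    double[] hidden(double[] input) {
        double[] h = new double[HID];
        for (int j = 0; j < HID; j++) {
            double sum = bHid[j];
            for (int i = 0; i < IN; i++) sum += wHid[j][i] * input[i];
            h[j] = sigmoid(sum);
        }
        return h;
    }

    double[] output(double[] h) {
        double[] y = new double[OUT];
        for (int k = 0; k < OUT; k++) {
            double sum = bOut[k];
            for (int j = 0; j < HID; j++) sum += wOut[k][j] * h[j];
            y[k] = sigmoid(sum);
        }
        return y;
    }

    double[] predict(double[] input) { return output(hidden(input)); }

    // One SGD step of backpropagation on a single sample, for E = 0.5 * sum((y - t)^2).
    void train(double[] input, double[] target, double lr) {
        double[] h = hidden(input);
        double[] y = output(h);
        // Output deltas: dE/dz_k = (y_k - t_k) * y_k * (1 - y_k).
        double[] dOut = new double[OUT];
        for (int k = 0; k < OUT; k++) dOut[k] = (y[k] - target[k]) * y[k] * (1 - y[k]);
        // Hidden deltas, computed with the output weights before they are updated.
        double[] dHid = new double[HID];
        for (int j = 0; j < HID; j++) {
            double sum = 0;
            for (int k = 0; k < OUT; k++) sum += dOut[k] * wOut[k][j];
            dHid[j] = sum * h[j] * (1 - h[j]);
        }
        // Gradient-descent updates for both layers.
        for (int k = 0; k < OUT; k++) {
            for (int j = 0; j < HID; j++) wOut[k][j] -= lr * dOut[k] * h[j];
            bOut[k] -= lr * dOut[k];
        }
        for (int j = 0; j < HID; j++) {
            for (int i = 0; i < IN; i++) wHid[j][i] -= lr * dHid[j] * input[i];
            bHid[j] -= lr * dHid[j];
        }
    }
}
```

Calling `train` once per sample while cycling through the four truth-table rows, as the usage example does with `e % trainingData.length`, gives the same kind of round-robin training loop.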

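The same approach can, in principle, be applied to tasks such as handwriting recognition, game playing, and other kinds of prediction, although those typically need larger networks and much more training data than this toy example.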

Here's the output of the above sample:

Result after 4000 iterations
        XOR   XNOR    OR   AND   NOR   NAND
0,0 = 0.038  0.962 0.038 0.001 0.963  0.999
0,1 = 0.961  0.039 0.970 0.026 0.029  0.974
1,0 = 0.961  0.039 0.970 0.026 0.030  0.974
1,1 = 0.049  0.952 0.994 0.956 0.006  0.044

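The outputs are continuous values between 0 and 1 rather than exact 0s and 1s, which is typical of neural network output. Round them (for example, treat anything at or above 0.5 as 1) to get the exact truth-table answers.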
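A minimal sketch of that rounding step, assuming 0.5 as the decision boundary. `Threshold` and `toBits` are hypothetical names, not part of the repository; the example reuses the `0,0` row printed above.

```java
import java.util.Arrays;

// Hypothetical helper: turns network outputs in (0, 1) into exact 0/1 answers.
class Threshold {
    static int[] toBits(double[] predictions) {
        int[] bits = new int[predictions.length];
        for (int i = 0; i < predictions.length; i++) {
            bits[i] = predictions[i] >= 0.5 ? 1 : 0; // 0.5 is the assumed decision boundary
        }
        return bits;
    }

    public static void main(String[] args) {
        // Row "0,0" from the output above: XOR, XNOR, OR, AND, NOR, NAND.
        double[] row00 = {0.038, 0.962, 0.038, 0.001, 0.963, 0.999};
        System.out.println(Arrays.toString(toBits(row00))); // prints [0, 1, 0, 0, 1, 1]
    }
}
```

The rounded row matches the truth table for inputs 0 and 0: XOR 0, XNOR 1, OR 0, AND 0, NOR 1, NAND 1.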
