I am sure many of you have found that when you try to do machine learning with Java or Scala, there is no nice graphing tool comparable to Python's Matplotlib. Have you ever wondered whether you could draw Matplotlib charts from Java?
Matplotlib4j is a library which gives you the power!
- How to use
- Draw Graphs
- Save Image To File
- Switch Python with pyenv, pyenv-virtualenv
From here I will introduce Java examples. Of course, it can also be used from other JVM languages such as Scala and Kotlin. The examples will be described later.
First, add Matplotlib4j to the Java project where you want to use Matplotlib.
For Maven, add the following dependency.
<dependency>
    <groupId>com.github.sh0nk</groupId>
    <artifactId>matplotlib4j</artifactId>
    <version>0.5.0</version>
</dependency>
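Similarly, for Gradle (same coordinates as the Maven dependency above):
implementation 'com.github.sh0nk:matplotlib4j:0.5.0'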
That's all. Let's begin drawing!
The usage is similar to Matplotlib's API, so it can be written intuitively. First, create a Plot object, call a pyplot-style method on it to add a graph, and finally call the show() method. Each method follows the Builder pattern, so options can be chained after it with the help of IDE completion.
As a starting point, let's draw a scatter plot.
List<Double> x = NumpyUtils.linspace(-3, 3, 100);
List<Double> y = x.stream()
    .map(xi -> Math.sin(xi) + Math.random())
    .collect(Collectors.toList());

Plot plt = Plot.create();
plt.plot().add(x, y, "o").label("sin");
plt.legend().loc("upper right");
plt.title("scatter");
plt.show();
With the above Java code, we can draw the following graph.
Some NumPy functions, such as linspace and meshgrid, are provided in the NumpyUtils class to help with graph drawing. The first block generates the x and y data for plotting; here we add random noise to the sine curve. After that, we create a Plot object, pass the generated x and y data to the builder returned by plot(), and call show() at the end to draw the graph.
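To make it concrete what these helpers compute, here is a minimal plain-Java sketch of NumPy-style linspace and of evaluating a function over the grid that meshgrid describes. The class and method names (NumpyHelpersSketch, evalOnGrid) are hypothetical and this is not Matplotlib4j's own implementation; it only assumes the NumPy conventions of an inclusive endpoint and the default "xy" indexing, where row j of the result corresponds to y[j].

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleBinaryOperator;

// Hypothetical illustration only; not part of Matplotlib4j.
public class NumpyHelpersSketch {

    // Returns `num` evenly spaced values from start to end, both inclusive,
    // like numpy.linspace with its default endpoint=True.
    static List<Double> linspace(double start, double end, int num) {
        List<Double> values = new ArrayList<>(num);
        if (num == 1) {
            values.add(start);
            return values;
        }
        double step = (end - start) / (num - 1);
        for (int i = 0; i < num; i++) {
            values.add(start + step * i);
        }
        return values;
    }

    // Evaluates f on every (x, y) pair. The result has y.size() rows and
    // x.size() columns, so z.get(j).get(i) == f(x.get(i), y.get(j)).
    static List<List<Double>> evalOnGrid(List<Double> x, List<Double> y,
                                         DoubleBinaryOperator f) {
        List<List<Double>> z = new ArrayList<>(y.size());
        for (double yj : y) {
            List<Double> row = new ArrayList<>(x.size());
            for (double xi : x) {
                row.add(f.applyAsDouble(xi, yj));
            }
            z.add(row);
        }
        return z;
    }

    public static void main(String[] args) {
        List<Double> x = linspace(-1, 1, 5);
        List<Double> y = linspace(-1, 1, 3);
        // Distance from the origin, as in a typical contour example.
        List<List<Double>> z = evalOnGrid(x, y, (xi, yj) -> Math.sqrt(xi * xi + yj * yj));
        System.out.println(x);
        System.out.println(z.size() + " rows x " + z.get(0).size() + " columns");
    }
}
```

The nested list returned by evalOnGrid has the row-per-y layout that a contour plot expects for its z values.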
This is almost equivalent to the Python implementation below ("almost" because the NumPy data generation is not strictly identical). The method calls are similar, which makes the library easy to pick up if you are a Pythonista.
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-3, 3, 100)
y = np.sin(x) + np.random.rand(100)

plt.plot(x, y, "o", label="sin")
plt.legend(loc="upper right")
plt.title("scatter")
plt.show()
Next, let's draw a contour plot (contour line).
List<Double> x = NumpyUtils.linspace(-1, 1, 100);
List<Double> y = NumpyUtils.linspace(-1, 1, 100);
NumpyUtils.Grid<Double> grid = NumpyUtils.meshgrid(x, y);
List<List<Double>> zCalced = grid.calcZ((xi, yj) -> Math.sqrt(xi * xi + yj * yj));

Plot plt = Plot.create();
ContourBuilder contour = plt.contour().add(x, y, zCalced);
plt.clabel(contour)
    .inline(true)
    .fontsize(10);
plt.title("contour");
plt.show();
Histograms can be drawn in the same way.
Random rand = new Random();
List<Double> x1 = IntStream.range(0, 1000)
    .mapToObj(i -> rand.nextGaussian())
    .collect(Collectors.toList());
List<Double> x2 = IntStream.range(0, 1000)
    .mapToObj(i -> 4.0 + rand.nextGaussian())
    .collect(Collectors.toList());

Plot plt = Plot.create();
plt.hist()
    .add(x1).add(x2)
    .bins(20)
    .stacked(true)
    .color("#66DD66", "#6688FF");
plt.xlim(-6, 10);
plt.title("histogram");
plt.show();
Matplotlib4j also supports saving to a file. This is convenient for use cases that have no GUI, such as machine learning batch jobs on a server.
As in the original Matplotlib, using the .savefig() method instead of .show() saves the image to a file without popping up a plot window. The only difference is that plt.executeSilently() must be called after .savefig(). Because savefig can itself be part of a method chain, this final call is needed to actually run the plot.
Random rand = new Random();
List<Double> x = IntStream.range(0, 1000)
    .mapToObj(i -> rand.nextGaussian())
    .collect(Collectors.toList());

Plot plt = Plot.create();
plt.hist().add(x).orientation(HistBuilder.Orientation.horizontal);
plt.ylim(-5, 5);
plt.title("histogram");
plt.savefig("/tmp/histogram.png").dpi(200);

// Necessary to output the file
plt.executeSilently();
This will output an image like the one below.
To use Matplotlib4j, you need a Python environment with Matplotlib installed. By default, Matplotlib4j uses the Python found on your PATH, but in many cases Matplotlib is not installed in the system default Python.
In that case, you can switch to a Python environment with Matplotlib installed, such as Anaconda, using pyenv or pyenv-virtualenv.
To use the Python from a pyenv environment, specify a PythonConfig when creating the Plot object as follows.
Plot plot = Plot.create(PythonConfig.pyenvConfig("Arbitrary pyenv name"));
Similarly, you can specify the environment name of pyenv-virtualenv.
Plot plot = Plot.create(PythonConfig.pyenvVirtualenvConfig("Arbitrary pyenv name", "Arbitrary virtualenv name"));
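For orientation, the sketch below shows where pyenv and pyenv-virtualenv keep their interpreters on disk, which is what the two environment names ultimately select. It is an illustration of the standard pyenv directory layout, not a description of how Matplotlib4j resolves the path internally; the class and method names (PyenvPathSketch, pyenvPython, virtualenvPython) and the example root and version names are hypothetical.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical illustration only; not part of Matplotlib4j.
public class PyenvPathSketch {

    // pyenv installs each version under <pyenvRoot>/versions/<version>,
    // with the interpreter at bin/python inside it.
    static Path pyenvPython(Path pyenvRoot, String version) {
        return pyenvRoot.resolve("versions").resolve(version)
                .resolve("bin").resolve("python");
    }

    // pyenv-virtualenv creates environments under
    // <pyenvRoot>/versions/<version>/envs/<virtualenv>.
    static Path virtualenvPython(Path pyenvRoot, String version, String virtualenv) {
        return pyenvRoot.resolve("versions").resolve(version)
                .resolve("envs").resolve(virtualenv)
                .resolve("bin").resolve("python");
    }

    public static void main(String[] args) {
        // The pyenv root defaults to ~/.pyenv unless PYENV_ROOT is set.
        String configured = System.getenv("PYENV_ROOT");
        Path root = configured != null
                ? Paths.get(configured)
                : Paths.get(System.getProperty("user.home"), ".pyenv");
        // "3.9.7" and "plots" are made-up names for the example.
        System.out.println(pyenvPython(root, "3.9.7"));
        System.out.println(virtualenvPython(root, "3.9.7", "plots"));
    }
}
```

Checking that the printed path exists and can import matplotlib is a quick way to confirm that an environment name is usable before passing it to PythonConfig.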
From Scala, the scatter plot example above can be written as follows. You only need to pay attention to the boxing/unboxing of numbers and to the differences between the Scala and Java List classes.
import scala.collection.JavaConverters._

val x = NumpyUtils.linspace(-3, 3, 100).asScala.toList
val y = x.map(xi => Math.sin(xi) + Math.random()).map(Double.box)

val plt = Plot.create()
plt.plot().add(x.asJava, y.asJava, "o")
plt.title("scatter")
plt.show()
On the Tutorial page, you can find more cases explained step by step in Java, Scala and Kotlin.
I recently started reading a book on deep learning and decided to implement its examples in Scala, which I have been using a lot lately, since simply copying the book's Python code as it is would not have been interesting. I was happy to be able to write it in a functional style in Scala, but when I got to backpropagation with the steepest descent method, the loss was not dropping at all, and I thought, "What's wrong?"
Of course, the usual way to tackle this is to add more tests, but first I wanted a quick look at what was going on by displaying a graph like the ones in the book. I then found that there were no good graphing tools for Scala, and implementing one from scratch would have been too hard. So I decided to build on Matplotlib, a familiar Python library, and that is why I created this library.
Matplotlib4j calls Matplotlib by generating Python code, without using JNI or Jython. Initially I wanted to implement it with Jython, but Jython only supported Python up to version 2.7 and did not support numpy, so Matplotlib, which depends on numpy, would not work either. I abandoned that path.
There is also a library that allows you to use CPython from Java code, and it was a candidate because it works with both Python 3 and numpy. However, it requires installing a separate environment-dependent library to use JNI, plus a package from pip on the Python side, which is too much work for something as simple as drawing graphs. In the end I decided to implement Matplotlib4j without depending on any of these libraries.
Of course, since the Python code is executed via a file, I had to use some tricks for passing variables and handling return values. Fortunately, since the only purpose is to draw graphs, the basic functions can be covered by one-way output to a file, and I think the performance is within an acceptable range, with some latency.
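The general idea of "generate Python code, write it to a file, run it" can be sketched in a few lines of plain Java. This is a simplified illustration under my own assumptions, not Matplotlib4j's actual implementation: the class and method names (ScriptGenerationSketch, toPythonList, quote, buildScript, writeScript) are hypothetical, and the generated script simply embeds the data as Python list literals and selects Matplotlib's non-interactive Agg backend so that it can save a file without a GUI.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical illustration only; not Matplotlib4j's real code.
public class ScriptGenerationSketch {

    // Renders numbers as a Python list literal, e.g. [0.0, 1.5].
    // Finite values only: NaN or Infinity would not be valid Python here.
    static String toPythonList(List<? extends Number> values) {
        return values.stream()
                .map(v -> Double.toString(v.doubleValue()))
                .collect(Collectors.joining(", ", "[", "]"));
    }

    // Wraps a string in single quotes, escaping backslashes and quotes.
    static String quote(String s) {
        return "'" + s.replace("\\", "\\\\").replace("'", "\\'") + "'";
    }

    // Builds a complete Python script that plots the data and saves it.
    static String buildScript(List<Double> x, List<Double> y,
                              String title, String outputPath) {
        return String.join("\n",
                "import matplotlib",
                "matplotlib.use('Agg')",
                "import matplotlib.pyplot as plt",
                "plt.plot(" + toPythonList(x) + ", " + toPythonList(y) + ", 'o')",
                "plt.title(" + quote(title) + ")",
                "plt.savefig(" + quote(outputPath) + ")",
                "");
    }

    // Writes the script to a temporary .py file and returns its path.
    static Path writeScript(String script) throws IOException {
        Path file = Files.createTempFile("plot_sketch", ".py");
        Files.write(file, script.getBytes(StandardCharsets.UTF_8));
        return file;
    }

    public static void main(String[] args) throws Exception {
        String script = buildScript(
                Arrays.asList(0.0, 1.0, 2.0),
                Arrays.asList(0.0, 1.0, 4.0),
                "demo", "demo.png");
        Path file = writeScript(script);
        System.out.println(script);
        // Running the file needs a `python` on PATH with Matplotlib installed:
        // new ProcessBuilder("python", file.toString()).inheritIO().start().waitFor();
    }
}
```

Because the data flows one way, from Java into the script text, nothing has to be read back from Python, which is why this approach is sufficient for drawing graphs.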
Matplotlib for java: A simple graph plot library for java, scala and kotlin with powerful python matplotlib
A simple interface library that lets your Java project use Matplotlib.
Of course, it can also be imported into a Scala project. The API is designed to be similar to the original Matplotlib's.
How to use
Here is an example. Find more examples on
Plot plt = Plot.create();
plt.plot()
    .add(Arrays.asList(1.3, 2))
    .label("label")
    .linestyle("--");
plt.xlabel("xlabel");
plt.ylabel("ylabel");
plt.text(0.5, 0.2, "text");
plt.title("Title!");
plt.legend();
plt.show();
Another example to draw Contour.