Nithin Bharadwaj

**How to Run Machine Learning Models in Java Applications with ONNX Runtime**

Machine learning models often feel like they live in a separate world, usually one built with Python. For a long time, if you had a Java application that needed to make a prediction, you faced a complex choice: build a separate service, deal with cumbersome language bridges, or miss out entirely. This changed for me when I started using ONNX Runtime. It's a high-performance inference engine that runs models trained in frameworks like PyTorch or TensorFlow, once they have been exported to the ONNX format, directly inside my Java applications. I don't need to call out to another service or rewrite everything. The model runs in the same process as the rest of my system.

Let’s start at the beginning: loading a model and setting it up to run. Think of an ONNX model file as a blueprint. The ONNX Runtime is the construction crew that can read that blueprint and execute it. In Java, your first job is to create a session, which is your working instance of the model.

When I create this session, I can tell the engine how to work. Do I want it to go as fast as possible? Should it use a specific number of threads? Maybe I have a GPU available and want to use that. This setup happens once, when the application starts, and this session becomes the workhorse for all predictions.

Here’s how I typically do it. I create a service class that handles the lifecycle of the model session.

import ai.onnxruntime.*;

public class ModelLoader {
    private OrtEnvironment env;
    private OrtSession session;

    public ModelLoader(String pathToModel) throws OrtException {
        // Get the shared environment
        env = OrtEnvironment.getEnvironment();

        // Create options to configure the session
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();

        // Apply all available optimizations to the model graph
        sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);

        // Control threading: intra-op threads parallelize the work inside a single operator;
        // inter-op threads run independent operators concurrently (only used in parallel execution mode)
        sessionOptions.setInterOpNumThreads(2);
        sessionOptions.setIntraOpNumThreads(4);

        // If I had an NVIDIA GPU, I could enable it like this:
        // sessionOptions.addCUDA(0);

        // Finally, load the model file and create the session
        session = env.createSession(pathToModel, sessionOptions);
    }

    public OrtSession getSession() {
        return session;
    }
}

This ModelLoader is a foundation. I initialize it when my application starts, and it gives me a ready-to-use OrtSession. The options are important. Setting the optimization level tells the runtime which graph optimizations to apply, such as constant folding and fusing adjacent operators, so the model runs faster. The thread settings help me balance speed against resource usage in a server environment. Because the session and its options hold native memory, I close them when the application shuts down.

Now, a model doesn't understand Java objects like String or BufferedImage. It understands tensors: multi-dimensional arrays of numbers. So, my second technique is all about translation. I need to build a pipeline that takes my real-world data—a user's comment, a product image, a sensor reading—and converts it into the exact format of numbers the model expects.

Similarly, the model's output is just a tensor. I need to translate that back into something meaningful: a sentiment label, a list of detected objects, or a forecast number. I think of this as building an adapter between my business logic and the mathematical world of the model.

Let's take a common example: processing an image for a classification model. The model might expect a 224x224 pixel image, with three color channels (Red, Green, Blue), each pixel value normalized between 0 and 1, and arranged in a specific order.

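import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;

import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.nio.FloatBuffer;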

public class ImageTransformer {

    public static OnnxTensor prepareImageTensor(BufferedImage originalImage) throws OrtException {
        // 1. Resize the image to the required dimensions
        int targetSize = 224;
        BufferedImage resizedImage = new BufferedImage(targetSize, targetSize, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = resizedImage.createGraphics();
        g.drawImage(originalImage, 0, 0, targetSize, targetSize, null);
        g.dispose();

        // 2. Extract pixel data and normalize it
        float[] pixelData = new float[3 * targetSize * targetSize]; // 3 channels * width * height
        int index = 0;

        // ONNX often expects "CHW" format: Channel, Height, Width.
        // First, all Red values, then all Green, then all Blue.
        for (int c = 0; c < 3; c++) { // Loop for each channel: R, G, B
            for (int y = 0; y < targetSize; y++) {
                for (int x = 0; x < targetSize; x++) {
                    int rgb = resizedImage.getRGB(x, y);
                    float value = 0;
                    // Extract the correct channel
                    if (c == 0) { // Red
                        value = ((rgb >> 16) & 0xFF) / 255.0f;
                    } else if (c == 1) { // Green
                        value = ((rgb >> 8) & 0xFF) / 255.0f;
                    } else { // Blue
                        value = (rgb & 0xFF) / 255.0f;
                    }
                    pixelData[index++] = value;
                }
            }
        }

        // 3. Create the tensor. Shape is [Batch=1, Channels=3, Height=224, Width=224]
        long[] shape = {1, 3, targetSize, targetSize};
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        return OnnxTensor.createTensor(env, FloatBuffer.wrap(pixelData), shape);
    }
}
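
The channel-first layout can also be filled in a single pass over a packed pixel array, and many pretrained classifiers expect per-channel standardization on top of the 0-to-1 scaling. Here is a minimal sketch of that variant. The names `ChwNormalizer` and `toChw` are hypothetical, and the mean and standard deviation are the ImageNet statistics commonly used with pretrained image models, which is an assumption about how the model was trained. A model trained on plain 0-to-1 input should skip that step.

```java
final class ChwNormalizer {
    // Commonly used ImageNet channel statistics (assumption: the model was trained with them)
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    /** Converts row-major packed RGB pixels into a CHW float array: all R, then all G, then all B. */
    static float[] toChw(int[] rgbPixels, int width, int height) {
        int plane = width * height;
        if (rgbPixels.length != plane) {
            throw new IllegalArgumentException("expected " + plane + " pixels, got " + rgbPixels.length);
        }
        float[] out = new float[3 * plane];
        for (int i = 0; i < plane; i++) {
            int rgb = rgbPixels[i];
            float r = ((rgb >> 16) & 0xFF) / 255.0f;
            float g = ((rgb >> 8) & 0xFF) / 255.0f;
            float b = (rgb & 0xFF) / 255.0f;
            // Pixel i of channel c lives at offset c * plane + i
            out[i] = (r - MEAN[0]) / STD[0];
            out[plane + i] = (g - MEAN[1]) / STD[1];
            out[2 * plane + i] = (b - MEAN[2]) / STD[2];
        }
        return out;
    }
}
```

The packed array can come from `resizedImage.getRGB(0, 0, targetSize, targetSize, null, 0, targetSize)`, and the result can be wrapped in a `FloatBuffer` exactly as in step 3 above.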

On the output side, a classification model returns an array of raw, unnormalized scores called logits, one per class. To get a useful result, I need to process them, which usually means converting them into probabilities.

import java.util.*;
import java.util.stream.Collectors;

public class PredictionInterpreter {

    public static Map<String, Float> interpretClassification(float[] logits, List<String> classLabels) {
        // First, apply the softmax function to get probabilities
        float[] probabilities = softmax(logits);

        Map<String, Float> results = new HashMap<>();
        for (int i = 0; i < probabilities.length; i++) {
            results.put(classLabels.get(i), probabilities[i]);
        }

        // Sort by probability, highest first, and return the top 5
        return results.entrySet().stream()
            .sorted(Map.Entry.<String, Float>comparingByValue().reversed())
            .limit(5)
            .collect(Collectors.toMap(
                Map.Entry::getKey,
                Map.Entry::getValue,
                (e1, e2) -> e1,
                LinkedHashMap::new // Preserves order
            ));
    }

    private static float[] softmax(float[] z) {
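        // Arrays.stream has no float[] overload, so find the max with a plain loop
        float max = Float.NEGATIVE_INFINITY;
        for (float v : z) {
            max = Math.max(max, v);
        }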
        float sum = 0.0f;
        float[] expZ = new float[z.length];

        // For numerical stability, subtract the max before exponentiating
        for (int i = 0; i < z.length; i++) {
            expZ[i] = (float) Math.exp(z[i] - max);
            sum += expZ[i];
        }

        // Normalize to get probabilities
        for (int i = 0; i < expZ.length; i++) {
            expZ[i] = expZ[i] / sum;
        }
        return expZ;
    }
}
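
The max subtraction in softmax is not cosmetic. A small sketch shows what happens without it; `SoftmaxStability`, `naiveSoftmax`, and `stableSoftmax` are hypothetical names used only for this comparison. With large logits, `Math.exp` overflows to infinity, and dividing infinity by infinity yields NaN. The shifted version produces the same probabilities mathematically but never exponentiates a positive number.

```java
final class SoftmaxStability {
    /** Direct formula: exp(z_i) / sum(exp(z_j)). Overflows for large logits. */
    static float[] naiveSoftmax(float[] z) {
        float[] out = new float[z.length];
        float sum = 0f;
        for (int i = 0; i < z.length; i++) {
            out[i] = (float) Math.exp(z[i]);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Same formula after shifting every logit by the maximum, so the largest exponent is 0. */
    static float[] stableSoftmax(float[] z) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : z) {
            max = Math.max(max, v);
        }
        float[] out = new float[z.length];
        float sum = 0f;
        for (int i = 0; i < z.length; i++) {
            out[i] = (float) Math.exp(z[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        float[] logits = {1000f, 999f, 998f};
        // The naive version prints NaN entries because exp(1000) overflows
        System.out.println(java.util.Arrays.toString(naiveSoftmax(logits)));
        // The stable version prints three probabilities that sum to 1
        System.out.println(java.util.Arrays.toString(stableSoftmax(logits)));
    }
}
```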

With these translators in place, my main application code stays clean. It just deals with images and maps of labels, not tensors and logits.

When I first deployed a model, I handled predictions one at a time. A request comes in, I run the model, I send the answer back. This works, but it's slow if many requests arrive at once. Each prediction has a small overhead from communicating with the ONNX Runtime. My third technique solves this: dynamic batching.

The idea is simple. Instead of serving each request immediately, I wait a very short time, maybe a few milliseconds, to see if other requests are coming. Then I group several inputs into a single batch and run them through the model at once. The model is often much faster at processing one batch of 10 items than 10 separate items, so throughput goes up in exchange for a few milliseconds of added latency per request.

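Here's a simplified version of a batching predictor. It covers the core step, packing inputs that are already in hand into batches and splitting the outputs again; the waiting and queuing are left out.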

import ai.onnxruntime.*;

import java.nio.FloatBuffer;
import java.util.*;

public class BatchInferenceHandler {
    private final OrtSession session;
    private final int maxBatchSize;

    public BatchInferenceHandler(OrtSession session, int maxBatchSize) {
        this.session = session;
        this.maxBatchSize = maxBatchSize; // e.g., 16 or 32
    }

    public List<float[]> runBatchPrediction(List<float[]> individualInputs) throws OrtException {
        List<List<float[]>> batchedLists = createBatches(individualInputs, maxBatchSize);
        List<float[]> allResults = new ArrayList<>();

        for (List<float[]> singleBatch : batchedLists) {
            allResults.addAll(runSingleBatch(singleBatch));
        }
        return allResults;
    }

    private List<float[]> runSingleBatch(List<float[]> batch) throws OrtException {
        int batchCount = batch.size();
        int featuresPerInput = batch.get(0).length;

        // Combine all inputs into one big flat array
        float[] combinedInput = new float[batchCount * featuresPerInput];
        int copyPosition = 0;
        for (float[] singleInput : batch) {
            System.arraycopy(singleInput, 0, combinedInput, copyPosition, featuresPerInput);
            copyPosition += featuresPerInput;
        }

        // Shape is now [batchCount, featuresPerInput]
        long[] shape = {batchCount, featuresPerInput};
        // Assume the model has one input; get its name.
        String inputName = session.getInputNames().iterator().next();

        // Tensors and results hold native memory, so close both when done
        try (OnnxTensor batchTensor = OnnxTensor.createTensor(
                 OrtEnvironment.getEnvironment(),
                 FloatBuffer.wrap(combinedInput),
                 shape);
             OrtSession.Result outputs = session.run(Collections.singletonMap(inputName, batchTensor))) {
            OnnxTensor resultTensor = (OnnxTensor) outputs.get(0);

            // A [batchCount, numClasses] output comes back from getValue() as float[][],
            // so read the flat buffer instead and split it per input below.
            FloatBuffer outputBuffer = resultTensor.getFloatBuffer();
            float[] batchOutput = new float[outputBuffer.remaining()];
            outputBuffer.get(batchOutput);

            // The output is also a batch. Split it back into individual results.
            return splitBatchOutput(batchOutput, batchCount);
        }
    }

    private List<float[]> splitBatchOutput(float[] flatOutput, int batchCount) {
        List<float[]> separatedResults = new ArrayList<>();
        int singleResultSize = flatOutput.length / batchCount; // e.g., number of classes

        for (int i = 0; i < batchCount; i++) {
            float[] singleResult = new float[singleResultSize];
            int start = i * singleResultSize;
            System.arraycopy(flatOutput, start, singleResult, 0, singleResultSize);
            separatedResults.add(singleResult);
        }
        return separatedResults;
    }

    private List<List<float[]>> createBatches(List<float[]> list, int chunkSize) {
        List<List<float[]>> batches = new ArrayList<>();
        for (int i = 0; i < list.size(); i += chunkSize) {
            batches.add(list.subList(i, Math.min(list.size(), i + chunkSize)));
        }
        return batches;
    }
}

In a real service, I would wrap this with a queue and a separate thread that processes the batch on a timer, sending results back to the waiting requests. This pattern is a game-changer for high-throughput applications.
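
The queue-and-timer wrapper could be sketched roughly as follows. This is an illustration under assumptions, not a production component: `MicroBatcher` and its members are hypothetical names, the model call is abstracted as a plain `Function` from a list of inputs to a list of outputs in the same order, and error handling is kept to the minimum needed to avoid leaving callers waiting.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

final class MicroBatcher implements AutoCloseable {
    private static final class Pending {
        final float[] input;
        final CompletableFuture<float[]> future = new CompletableFuture<>();
        Pending(float[] input) { this.input = input; }
    }

    private final BlockingQueue<Pending> queue = new LinkedBlockingQueue<>();
    private final Function<List<float[]>, List<float[]>> batchFn;
    private final int maxBatchSize;
    private final long maxWaitNanos;
    private final Thread worker;
    private volatile boolean running = true;

    MicroBatcher(Function<List<float[]>, List<float[]>> batchFn, int maxBatchSize, long maxWaitMillis) {
        this.batchFn = batchFn;
        this.maxBatchSize = maxBatchSize;
        this.maxWaitNanos = TimeUnit.MILLISECONDS.toNanos(maxWaitMillis);
        this.worker = new Thread(this::loop, "micro-batcher");
        this.worker.setDaemon(true);
        this.worker.start();
    }

    /** Enqueues one input; the returned future completes once its batch has run. */
    CompletableFuture<float[]> submit(float[] input) {
        Pending pending = new Pending(input);
        queue.add(pending);
        return pending.future;
    }

    private void loop() {
        while (running) {
            try {
                // Wait for the first request, waking periodically to re-check the running flag
                Pending first = queue.poll(100, TimeUnit.MILLISECONDS);
                if (first == null) continue;
                List<Pending> batch = new ArrayList<>();
                batch.add(first);
                long deadline = System.nanoTime() + maxWaitNanos;
                // Collect more requests until the batch is full or the wait window closes
                while (batch.size() < maxBatchSize) {
                    long remaining = deadline - System.nanoTime();
                    if (remaining <= 0) break;
                    Pending next = queue.poll(remaining, TimeUnit.NANOSECONDS);
                    if (next == null) break;
                    batch.add(next);
                }
                runBatch(batch);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    private void runBatch(List<Pending> batch) {
        try {
            List<float[]> inputs = new ArrayList<>();
            for (Pending p : batch) inputs.add(p.input);
            List<float[]> outputs = batchFn.apply(inputs);
            for (int i = 0; i < batch.size(); i++) {
                batch.get(i).future.complete(outputs.get(i));
            }
        } catch (RuntimeException e) {
            // Fail every request that has not already received a result
            for (Pending p : batch) p.future.completeExceptionally(e);
        }
    }

    @Override
    public void close() {
        running = false;
        worker.interrupt();
    }
}
```

To connect this to the handler above, the function passed in would call `runBatchPrediction` and wrap its checked `OrtException` in an unchecked one. The two knobs trade latency for throughput: a larger batch size or a longer wait window gives fuller batches, while each request may wait up to the full window before it runs.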

Once my model is deployed and handling traffic, I can't just assume it's working correctly. I need to watch it. My fourth technique is about building observability into the inference process. How long does each prediction take? Is the latency consistent? Are the input values drifting over time, which might mean the model's performance is decaying?

I use metrics libraries like Micrometer to instrument my prediction code. This lets me track everything and see it on a dashboard.

import io.micrometer.core.instrument.*;

import java.util.function.Supplier;

public class ModelMetrics {
    private final MeterRegistry registry;
    private final Timer inferenceTimer;
    private final Timer failureTimer;
    private final Counter successCounter;
    private final Counter failureCounter;
    private final DistributionSummary inputSummary;

    public ModelMetrics(MeterRegistry registry, String modelName) {
        this.registry = registry;

        // Timers for latency, split by outcome so failures do not skew the success percentiles
        inferenceTimer = Timer.builder("model.inference.duration")
                .tags("model", modelName, "status", "success")
                .publishPercentiles(0.5, 0.9, 0.99) // Median, 90th, 99th percentile
                .register(registry);
        failureTimer = Timer.builder("model.inference.duration")
                .tags("model", modelName, "status", "failure")
                .register(registry);

        // Counters for success/failure
        successCounter = Counter.builder("model.inference")
                .tags("model", modelName, "status", "success")
                .register(registry);
        failureCounter = Counter.builder("model.inference")
                .tags("model", modelName, "status", "failure")
                .register(registry);

        // Track distribution of a key input feature to detect drift
        inputSummary = DistributionSummary.builder("model.input.feature1")
                .tags("model", modelName)
                .register(registry);
    }

    public float[] predictWithMetrics(float[] input, Supplier<float[]> predictionLogic) {
        Timer.Sample sample = Timer.start(registry); // Start timing with the registry's clock

        try {
            float[] result = predictionLogic.get(); // This calls the actual model
            sample.stop(inferenceTimer); // Stop timer and record
            successCounter.increment();

            // Record the first feature value to monitor its distribution
            if (input.length > 0) {
                inputSummary.record(input[0]);
            }

            return result;

        } catch (RuntimeException e) {
            failureCounter.increment();
            // Record failed calls in the same registry, tagged status=failure
            sample.stop(failureTimer);
            throw e; // Re-throw the exception
        }
    }
}

In my application, I’d wrap my model call with predictWithMetrics. This gives me concrete data. If the 99th percentile latency jumps from 50ms to 500ms, I know something is wrong. If the average value of an input feature slowly changes, it might be a signal that the real-world data is drifting from the data the model was trained on.
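
Dashboards show the drift, but the check can also be made explicit in code. The following is a minimal sketch, assuming I know the mean and standard deviation of the feature in the training data. `FeatureDriftMonitor` and its methods are hypothetical names, and the z-score test on the running mean is one simple choice among many drift tests, not something Micrometer provides.

```java
final class FeatureDriftMonitor {
    private final double baselineMean;
    private final double baselineStd;
    private long count;
    private double mean;
    private double m2; // running sum of squared deviations (Welford's algorithm)

    FeatureDriftMonitor(double baselineMean, double baselineStd) {
        if (baselineStd <= 0) {
            throw new IllegalArgumentException("baselineStd must be positive");
        }
        this.baselineMean = baselineMean;
        this.baselineStd = baselineStd;
    }

    /** Folds one observed feature value into the running statistics. */
    synchronized void record(double value) {
        count++;
        double delta = value - mean;
        mean += delta / count;
        m2 += delta * (value - mean);
    }

    synchronized double mean() {
        return mean;
    }

    /** Sample variance of the live values seen so far. */
    synchronized double variance() {
        return count > 1 ? m2 / (count - 1) : 0.0;
    }

    /** Distance of the live mean from the training mean, in baseline standard errors. */
    synchronized double zScore() {
        if (count == 0) return 0.0;
        double standardError = baselineStd / Math.sqrt(count);
        return (mean - baselineMean) / standardError;
    }

    /** True when the live mean is further from the baseline than the threshold allows. */
    synchronized boolean hasDrifted(double threshold) {
        return Math.abs(zScore()) > threshold;
    }
}
```

I would call `record` at the same place where the input summary is updated, and check `hasDrifted` on a schedule. With a lot of traffic, even tiny shifts become statistically significant, so the threshold and the window of data it covers need tuning for each feature.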

Models aren't static. They get retrained, improved, or sometimes rolled back. My fifth and final technique is about managing this lifecycle. I treat models like any other deployable artifact: they have versions, and I need a system to manage them. This allows for safe updates and even A/B testing, where I can compare a new model against the current one on a small portion of live traffic.

I implement a simple model registry that holds the active version and can load a new one.

import ai.onnxruntime.*;

import java.nio.file.Path;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class ModelRegistry {
    // Maps a model name (e.g., "fraud-detector") to its currently loaded version
    private final Map<String, ModelInstance> registry = new ConcurrentHashMap<>();

    public void loadNewVersion(String modelName, String versionId, Path modelFilePath) throws OrtException {
        ModelInstance newInstance = new ModelInstance(versionId, modelFilePath);
        registry.put(modelName, newInstance);
        System.out.println("Loaded model " + modelName + " version " + versionId);
    }

    public PredictionResult runPrediction(String modelName, float[] input, String userId) throws OrtException {
        ModelInstance instance = registry.get(modelName);
        if (instance == null) {
            throw new IllegalArgumentException("Model not found: " + modelName);
        }

        // Simple A/B test routing: use user ID hash to decide
        boolean useExperimental = Math.abs(userId.hashCode() % 100) < 10; // 10% of traffic

        if (useExperimental && instance.hasExperimentalVersion()) {
            return instance.runExperimental(input);
        } else {
            return instance.run(input);
        }
    }

    private static class ModelInstance {
        private final String version;
        private final OrtSession productionSession;
        private OrtSession experimentalSession = null; // Optional new version

        public ModelInstance(String version, Path modelPath) throws OrtException {
            this.version = version;
            OrtEnvironment env = OrtEnvironment.getEnvironment();
            OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
            this.productionSession = env.createSession(modelPath.toString(), opts);
        }

        public void setExperimentalVersion(Path experimentalModelPath) throws OrtException {
            OrtEnvironment env = OrtEnvironment.getEnvironment();
            OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
            this.experimentalSession = env.createSession(experimentalModelPath.toString(), opts);
        }

        public boolean hasExperimentalVersion() {
            return experimentalSession != null;
        }

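        public PredictionResult run(float[] input) throws OrtException {
            return new PredictionResult("production", version, infer(productionSession, input));
        }

        public PredictionResult runExperimental(float[] input) throws OrtException {
            return new PredictionResult("experimental", version + "-exp", infer(experimentalSession, input));
        }

        // Runs one input as a batch of size 1. Assumes the model takes a single float
        // input of shape [1, n] and returns a float tensor. PredictionResult is assumed
        // to be a simple holder (variant, version, float[] output) defined elsewhere.
        private static float[] infer(OrtSession session, float[] input) throws OrtException {
            OrtEnvironment env = OrtEnvironment.getEnvironment();
            String inputName = session.getInputNames().iterator().next();
            try (OnnxTensor tensor = OnnxTensor.createTensor(
                     env, java.nio.FloatBuffer.wrap(input), new long[] {1, input.length});
                 OrtSession.Result outputs =
                     session.run(java.util.Collections.singletonMap(inputName, tensor))) {
                java.nio.FloatBuffer buffer = ((OnnxTensor) outputs.get(0)).getFloatBuffer();
                float[] result = new float[buffer.remaining()];
                buffer.get(result);
                return result;
            }
        }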
    }
}

With this registry, I can push a new model version by calling loadNewVersion. I can also set an experimental version for A/B testing. The routing logic in runPrediction is simple—here it’s based on a user ID hash—but it lets me compare performance and accuracy between two models in a controlled way before committing to a full rollout.
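
The routing rule is worth isolating so it can be tested on its own. Here is a small sketch of one way to do that; `TrafficSplitter`, `bucketOf`, and `useExperimental` are hypothetical names, and mixing an experiment name into the hash is my own addition so that different experiments do not always pick the same users.

```java
final class TrafficSplitter {
    /** Maps a user to a stable bucket in [0, 100) for the given experiment. */
    static int bucketOf(String experiment, String userId) {
        // floorMod never returns a negative value, even when hashCode() is negative
        return Math.floorMod((experiment + ":" + userId).hashCode(), 100);
    }

    /** Routes roughly experimentalPercent of users to the experimental model. */
    static boolean useExperimental(String experiment, String userId, int experimentalPercent) {
        if (experimentalPercent < 0 || experimentalPercent > 100) {
            throw new IllegalArgumentException("experimentalPercent must be between 0 and 100");
        }
        return bucketOf(experiment, userId) < experimentalPercent;
    }
}
```

Because the bucket depends only on the experiment name and the user ID, a given user sees the same model on every request, which keeps the comparison between the two versions clean. `String.hashCode` is not a strong hash, so the split is only approximately even; a production system might prefer a better-distributed hash function.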

These five techniques have allowed me to integrate machine learning into Java applications in a way that is robust, scalable, and maintainable. I start by loading a model efficiently. I build careful translators for data. I batch requests to handle high volume. I monitor everything to ensure health. And finally, I manage models as versioned artifacts that can evolve. The ONNX Runtime provides the core engine, but it’s these patterns that let me build it into a true production system. My Java application can now make intelligent decisions directly, quickly, and reliably, closing the gap between the data science team’s work and the real-world needs of our users.
