DEV Community

Wlad Radchenko


Optimizing and Running Neural Networks on React Native: A Grass Case Study

During the long New Year holidays, I decided to start a new pet project. I wanted to experiment with neural networks, and the idea of a mobile app that recognizes plants (identifying species, age, and even diseases) directly on the device without a server, while also giving care and treatment recommendations, seemed exciting and new to me.

As I dug in, I realized that there are already quite a few similar apps. However, that didn’t diminish the value of the experiment: I wanted to understand how neural networks can be optimized for mobile devices and run directly on my phone. In this article, I’ll share what I managed to achieve — from selecting models to optimizing them and integrating them into React Native — as well as my experience with various methods for plant classification, disease detection, and estimating age and leaf count.

The code is available on GitHub, and the models are on Hugging Face.

SCOLD Model for Cross-Modal Disease Search

The first tool I decided to experiment with was the SCOLD model from Hugging Face. It was designed for cross-modal search, meaning it can match an image with a textual description. For example, you can feed in a photo of a corn leaf and a text description of a disease, and the model will estimate how well the text matches the image. SCOLD was trained on the LeafNet dataset.

Despite its interesting concept, I wasn’t entirely sure how to apply it in practice. On top of that, the model itself was about 1GB in size, which immediately makes running it on a mobile device challenging. I decided to modify the model slightly so that it could identify the plant disease directly.

Since I was only interested in working with images, I removed the RoBERTa text encoder and kept only the image encoder based on the Swin Transformer (see the source code). The result is a much smaller model that outputs a 512-dimensional feature vector for each image.

Originally, this vector was meant to be compared with RoBERTa’s embeddings of text descriptions, but for mobile deployment I took a different approach: instead of working with text, I compare the image vector with vectors precomputed for the entire dataset.

This setup also introduced two additional files: a 248MB binary file with the embeddings and an 18MB JSON file containing disease labels. This approach preserves the original task of disease identification without needing to process text, while significantly reducing the model size for mobile use.
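
To make the lookup concrete, here is a minimal sketch of matching one image embedding against precomputed dataset embeddings. It assumes, as the `LVL` model below does, that vectors are L2-normalized, so cosine similarity reduces to a dot product. The function names and the toy 3-dimensional data are hypothetical; the real vectors are 512-dimensional.

```python
import math
from typing import List, Sequence, Tuple


def l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (like nn.functional.normalize with p=2)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def top_k_matches(
    query: Sequence[float],
    embeddings: Sequence[Sequence[float]],
    labels: Sequence[str],
    k: int = 3,
) -> List[Tuple[str, float]]:
    """Return the k labels whose embeddings are most similar to the query."""
    q = l2_normalize(query)
    scored = []
    for emb, label in zip(embeddings, labels):
        e = l2_normalize(emb)
        # For unit-length vectors, cosine similarity equals the dot product
        similarity = sum(a * b for a, b in zip(q, e))
        scored.append((label, similarity))
    scored.sort(key=lambda item: item[1], reverse=True)  # most similar first
    return scored[:k]


# Toy example: three "dataset" vectors with hypothetical labels
db = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.7, 0.7, 0.0]]
names = ["healthy leaf", "leaf rust", "early blight"]
print(top_k_matches([0.9, 0.1, 0.0], db, names, k=2))
```

Precomputing and normalizing the dataset side once is what makes this practical on a phone: at query time only the image encoder runs, followed by a scan over stored vectors.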

from timm import create_model
import torch
import torch.nn as nn
import numpy as np

EMBEDDING_DIM = 512


class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        # Load the Swin Transformer with features_only=True
        self.swin = create_model("swin_base_patch4_window7_224.ms_in22k", pretrained=True, features_only=True)
        for param in self.swin.parameters():
            param.requires_grad = True

        # Get the feature size of the final stage
        self.swin_output_dim = self.swin.feature_info.channels()[-1]  # Last stage: 1024 channels

        # Define FC layer
        self.fc1 = nn.Linear(self.swin_output_dim * 7 * 7, EMBEDDING_DIM)  # Flattened input size
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        for param in self.fc1.parameters():
            param.requires_grad = True

    def forward(self, x):
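        # Extract features from Swin: for a 224x224 input the last stage has 1024 channels on a 7x7 grid
        # (depending on the timm version the layout may be channels-last, e.g. [B, 7, 7, 1024])
        swin_features = self.swin(x)[-1]

        # Flatten everything except the batch dimension (flatten also works on non-contiguous tensors)
        swin_features = torch.flatten(swin_features, 1)  # Shape: (B, 1024*7*7)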

        # Pass through FC layer
        output = self.fc1(swin_features)  # Shape: (B, embedding_dim)
        return output


class LVL(nn.Module):
    def __init__(self):
        super(LVL, self).__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = nn.Identity()
        self.t_prime = nn.Parameter(torch.ones([]) * np.log(0.07))
        self.b = nn.Parameter(torch.ones([]) * 0)

    def get_images_features(self, images):
        image_embeddings = self.image_encoder(images)  # (batch_size, EMBEDDING_DIM)
        image_embeddings = nn.functional.normalize(image_embeddings, p=2, dim=-1)
        return image_embeddings

    def get_texts_feature(self, input_ids=None, attention_mask=None):
        """
        Plug
        :param input_ids: Tensor of shape (batch_size, seq_length)
        :param attention_mask: Tensor of shape (batch_size, seq_length)
        :return:
        """
        return None

    def forward(self, images, input_ids=None, attention_mask=None):
        """
        Args:
            images: Tensor of shape (batch_size, 3, 224, 224)
            input_ids: Tensor of shape (batch_size, seq_length)
            attention_mask: Tensor of shape (batch_size, seq_length)

        Returns:
            Image and text embeddings normalized for similarity calculation
        """

        image_embeddings = self.get_images_features(images)
        return image_embeddings

The next step was converting the model to ONNX, a format that is well supported on mobile and lets neural networks run directly inside React Native. During this process, I had to account for the specifics of each model: creating a dummy input with exactly the shape the model expects (the exporter traces the model with it), specifying input and output names, defining dynamic_axes so the batch size can vary, and sometimes adjusting the opset version for compatibility with ONNX Runtime.
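
Why the dummy input shape matters for this model can be checked with a little arithmetic. `ImageEncoder` flattens the last Swin stage and feeds it to `fc1 = nn.Linear(1024 * 7 * 7, 512)`, so the export only works for inputs whose last-stage grid is 7x7. Assuming the usual Swin layout (a 4x4 patch embedding followed by three 2x patch-merging steps, a total downsampling of 32), a hypothetical helper shows which input sizes line up with `fc1`:

```python
SWIN_TOTAL_STRIDE = 32          # patch size 4, then three 2x patch-merging steps
LAST_STAGE_CHANNELS = 1024      # swin_base last stage
FC1_IN_FEATURES = 1024 * 7 * 7  # fixed by nn.Linear in ImageEncoder


def flattened_size(image_size: int) -> int:
    """Number of values the last Swin stage produces for a square input."""
    if image_size % SWIN_TOTAL_STRIDE != 0:
        raise ValueError("input size must be a multiple of 32")
    grid = image_size // SWIN_TOTAL_STRIDE
    return LAST_STAGE_CHANNELS * grid * grid


for size in (224, 512):
    matches = flattened_size(size) == FC1_IN_FEATURES
    print(size, flattened_size(size), "matches fc1" if matches else "does not match fc1")
```

Only a 224x224 input (the size in the `forward` docstring) produces the 50,176 values `fc1` was built for; a 512x512 input would yield a 16x16 grid and 262,144 values.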

import torch
# pip install onnxscript onnxruntime onnxruntime-tools
from model import LVL


def export(model_path: str, output_name: str):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Dummy input with the shape the model expects: fc1 is sized for a 7x7 final Swin stage, i.e. 224x224 images
    dummy_input = torch.randn(1, 3, 224, 224, device=device)

    # Load model
    model = LVL()
    model.to(device)
    model.eval()

    #  git clone https://huggingface.co/enalis/scold
    state_dict = torch.load(model_path, map_location=device)

    # Leave only the keys that are in the new model (only image_encoder)
    model_dict = model.state_dict()
    filtered_dict = {k: v for k, v in state_dict.items() if k in model_dict}

    # Load only image_encoder weights
    model.load_state_dict(filtered_dict, strict=False)  # <- strict=False important!

    # Export
    torch.onnx.export(
        model,                      # Model
        dummy_input,                # Input
        output_name,           # Name
        export_params=True,         # Save weights
        opset_version=18,           # Version
        do_constant_folding=True,   # Optim
        input_names=['images'],     # Name of inputs
        output_names=['image_embeddings'],  # Name of outputs
        dynamic_axes={'images': {0: 'batch_size'},  # Dynamic support batch size
                      'image_embeddings': {0: 'batch_size'}}
    )
    print("Finish!")


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to model pth")
    parser.add_argument("--output_name", type=str, default="disease_detector.onnx", help="Path to save onnx model with name model.onnx")
    args = parser.parse_args()

    export(model_path=args.model_path, output_name=args.output_name)


After that, I further reduced the model size by converting its weights from float32 to float16 while keeping the model’s inputs and outputs in float32. As a result, the model size dropped from 1 GB to 231 MB, which is much more practical for a mobile application.
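
The saving comes from storage per weight: float16 uses 2 bytes instead of 4. As a rough illustration (arithmetic, not a measurement of the real file), here is the calculation for the single `fc1` layer of `ImageEncoder`, plus a round-trip through half precision using the `'e'` format of Python’s standard `struct` module to show the precision that is given up. The helper names are hypothetical.

```python
import struct

FP32_BYTES = 4
FP16_BYTES = 2


def layer_size_mib(n_params: int, bytes_per_param: int) -> float:
    """Storage needed for n_params weights, in MiB."""
    return n_params * bytes_per_param / (1024 * 1024)


# fc1 in ImageEncoder: Linear(1024*7*7 -> 512), weights plus bias
fc1_params = 1024 * 7 * 7 * 512 + 512
print(f"fc1: {layer_size_mib(fc1_params, FP32_BYTES):.1f} MiB in fp32, "
      f"{layer_size_mib(fc1_params, FP16_BYTES):.1f} MiB in fp16")


def to_fp16(value: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


# Half precision keeps roughly 3 significant decimal digits
print(0.1, "->", to_fp16(0.1))
```

Because `keep_io_types=True` leaves the graph’s inputs and outputs in float32, the app can keep feeding float32 tensors; the precision loss is confined to the stored weights and internal computation.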

import onnx
from onnxconverter_common import float16


def export(model_path: str, output_name: str):
    # Load model
    model_fp32 = onnx.load(model_path)
    # Export to float16
    model_fp16 = float16.convert_float_to_float16(model_fp32, keep_io_types=True)
    # Save
    onnx.save(model_fp16, output_name)


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path model onnx")
    parser.add_argument("--output_name", type=str, default="fp16.onnx", help="Path to save fp16 mode with name model_fp16.onnx")
    args = parser.parse_args()

    export(model_path=args.model_path, output_name=args.output_name)


The disease_detection.onnx model itself and the code for running it can be found on GitHub.

Plant Classification from Images

It would be hard to imagine this project without a neural network for recognizing plants from photos, so the next step was to tackle exactly that. While looking for suitable data and models, I came across the PlantCLEF 2024 competition, which led me to the Pl@ntNet dataset. The dataset itself is huge, and training a model on the full dataset would require a lot of time and resources. Fortunately, instead of overloading my laptop, I found that a pre-trained model along with the code to run it had already been released on Zenodo.


The model takes a plant image as input and outputs a set of probabilities corresponding to specific plant species. Essentially, this is a standard image classification task without any extra complications. For my purposes, this was more than enough, so I followed the familiar path: exporting the model to ONNX and then converting it to float16 to reduce its size and make it more suitable for mobile deployment. The ready-to-use model and the code for running it are available in the project repository.

import torch
import timm
# pip install onnxscript onnxruntime onnxruntime-tools
import argparse


def load_class_mapping(class_list_file):
    with open(class_list_file) as f:
        class_index_to_class_name = {i: line.strip() for i, line in enumerate(f)}
    return class_index_to_class_name


def export(model_path: str, output_name: str, class_mapping_path: str):
    # Allowlist argparse.Namespace so torch.load (weights_only=True by default in recent PyTorch)
    # can unpickle the training arguments stored in the checkpoint
    torch.serialization.add_safe_globals([argparse.Namespace])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Class list and checkpoint produced by train.py
    class_mapping = load_class_mapping(class_mapping_path)

    # Load model
    model = timm.create_model('vit_base_patch14_reg4_dinov2.lvd142m', pretrained=False, num_classes=len(class_mapping))
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # Dummy input (ViT Base, 518x518 RGB: the default input size of this DINOv2 model)
    dummy_input = torch.randn(1, 3, 518, 518)

    # Export
    torch.onnx.export(
        model,                      # Model
        dummy_input,                # Input
        output_name,                # Name
        export_params=True,         # Save weights
        opset_version=18,           # Version
        input_names=['input'],      # Name of inputs
        output_names=['output'],    # Name of outputs
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )
    print("Finished!")


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--class_mapping", type=str, required=True, help="Path to class mapping")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model checkpoint")
    parser.add_argument("--output_name", type=str, default="plant_classificator.onnx", help="Path to save the ONNX model")
    args = parser.parse_args()

    export(model_path=args.model_path, output_name=args.output_name, class_mapping_path=args.class_mapping)
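
A timm classification head returns one raw score (logit) per class, and the names come from the same class-list file that `load_class_mapping` reads, one name per line. Assuming the exported `output` tensor holds those logits, a minimal post-processing sketch turns one row into the most likely species. `top_species` and the placeholder class names are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_species(
    logits: Sequence[float], class_mapping: Dict[int, str], k: int = 3
) -> List[Tuple[str, float]]:
    """Map one row of classifier logits to the k most likely class names."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(class_mapping[i], probs[i]) for i in ranked[:k]]


# Same structure load_class_mapping() returns: {index: class name}
mapping = {0: "species_a", 1: "species_b", 2: "species_c"}
print(top_species([2.0, 0.5, 3.1], mapping, k=2))
```

Subtracting the maximum logit does not change the result but keeps `math.exp` from overflowing when scores are large.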

Later, after studying the dataset more closely, I realized that it mostly contains rare and exotic plants. It barely includes the common ornamental flowers, vegetables, and fruits one might expect in such an app, or even typical bouquet flowers.

Because of this, I decided to go further and implement code for fine-tuning the model or training it from scratch. In practice, training from scratch turned out to be the preferable option, as it allows creating a smaller model, which is critical for mobile devices. Additionally, this approach gives users the flexibility to choose which model to load depending on their specific needs and the available resources on their device.

I moved the training process into a separate module and documented the dataset structure and the steps for running training in the project documentation. Training can be run with a batch size of 8 on a GPU with 8 GB of VRAM, or the batch size can be reduced if resources are limited. The dataset must be assembled manually by combining data from Kaggle and other open sources. This approach proved to be more flexible and aligns well with the concept of an experimental mobile ML project.

Regression Model for Estimating Plant Age and Leaf Count

The task of estimating a plant’s age and the number of its leaves may seem trivial at first, but in practice, I couldn’t find any ready-to-use models. There is a scientific work, GroMo: Plant Growth Modeling with Multiview Images, along with its corresponding repository, but the code there is somewhat messy and not easily scalable for practical use. Therefore, I decided not to adapt it, but to rewrite the solution from scratch, tailored to my own constraints.

In the original work, plants of different species, such as mustard, radish, wheat, and okra, were photographed from multiple angles and at various growth stages, with annotations indicating the plant’s age and the number of leaves. An important point is that visual features strongly depend on the species: leaf shape, size, and even the overall plant structure can vary significantly, and sometimes leaves are visually almost indistinguishable from stems.

In the original model, four images from different angles were fed into the network at once, whereas in my case, the input was simplified to a single photo, which is more aligned with a real mobile user scenario. There was also an idea to check whether the current number of leaves or stems is appropriate for the plant’s age—or if it’s below average—but I decided to leave that for future work.

At the core of this model is MobileNetV3 Large, used as a universal visual encoder. From the pretrained network, only the convolutional part is taken, which extracts compact and informative features from a leaf image. These features are then aggregated using Adaptive Average Pooling into a fixed-size vector of 960 dimensions, making the model independent of the input image size.


This shared feature vector is then used for two regression tasks simultaneously: estimating the number of leaves and the plant’s age. To accomplish this, two separate “heads” are added, each implemented as a small fully connected block. I named this model LeafNet.

import torch.nn as nn
from torchvision.models import mobilenet_v3_large

class LeafNet(nn.Module):
    def __init__(self):
        super().__init__()
        backbone = mobilenet_v3_large(weights="DEFAULT")
        self.encoder = backbone.features

        self.pool = nn.AdaptiveAvgPool2d(1)

        self.count_head = nn.Sequential(
            nn.Linear(960, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

        self.age_head = nn.Sequential(
            nn.Linear(960, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        feat = self.encoder(x)
        pooled = self.pool(feat).flatten(1)
        count = self.count_head(pooled)
        age = self.age_head(pooled)
        return feat, count, age
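
Adaptive average pooling to size 1 is just a per-channel mean over the spatial grid, which is why the pooled vector has 960 values whatever the input resolution. The stdlib sketch below reproduces that step and counts the parameters of one regression head (`Linear(960, 128)` then `Linear(128, 1)`, as in `LeafNet`), showing how little the two heads add on top of the backbone. The helper names are hypothetical.

```python
from typing import List

FeatureMap = List[List[List[float]]]  # [channels][height][width]


def global_avg_pool(feat: FeatureMap) -> List[float]:
    """Equivalent of nn.AdaptiveAvgPool2d(1) followed by flatten, for one image."""
    pooled = []
    for channel in feat:
        values = [v for row in channel for v in row]
        pooled.append(sum(values) / len(values))
    return pooled


def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus bias of an nn.Linear layer."""
    return in_features * out_features + out_features


# Feature maps with different spatial sizes give vectors of the same length
small = [[[1.0, 3.0]] for _ in range(960)]                   # 960 x 1 x 2
large = [[[2.0] * 5 for _ in range(4)] for _ in range(960)]  # 960 x 4 x 5
print(len(global_avg_pool(small)), len(global_avg_pool(large)))

head_params = linear_params(960, 128) + linear_params(128, 1)
print("parameters per head:", head_params)
```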

Training was performed on a combined dataset without separating by plant species, allowing the model to learn from a wider variety of visual patterns. The total dataset size was around 380 GB. Full training took approximately three days, and during training, the model is automatically exported to ONNX format for use on mobile devices. The training documentation, along with all the code for running experiments, is available in the project repository.

Integrating and Running Neural Networks in React Native

In this part of the article, I’ll discuss running neural networks directly on a mobile device and the code required to make it work. I intentionally skip basic setup steps like installing Node.js, running an Android emulator, or setting up the development environment. Those are well-documented elsewhere and not directly relevant to the task at hand. Here, we’ll focus on the code and the limitations encountered when working with neural networks on mobile platforms.

For convenience, the project repository includes a ready-to-use APK for those who don’t want to build the app themselves (it’s large because it includes the models). However, for development and experimentation, I highly recommend building the project manually. Testing was performed on a Pixel 8 Pro emulator, so behavior on other devices may vary. Detailed build and run instructions are available in the repository under the mobile folder.

The first major problem, which cost me about a day of development, was choosing the right mobile stack. I initially started with Expo since it allows quickly spinning up a prototype. However, I discovered that running ONNX models in this environment is either unstable or completely impossible. Even building a custom Expo client didn’t help. Eventually, I abandoned Expo and switched to pure React Native, which proved to be the correct choice for further native integration and neural network usage.

Next, I ran into a fundamental limitation of mobile platforms: the JavaScript engine on Android and iOS is not designed to handle large binary files. Files like captions.json and especially embeddings.bin (248 MB) cannot simply be loaded into memory from JavaScript without risking OutOfMemory errors or extremely slow performance due to constant data copying. Yet, these files are essential for the disease detection model, where searches are performed over precomputed embedding vectors.

To solve this, all work with large data and vector search is moved to the native layer on Android by implementing a custom Kotlin native module, called FaissSearchModule. Its main purpose is to load the embedding file, store the data outside the JavaScript heap, and perform vector search entirely in native memory. Instead of loading embeddings.bin into RAM, a memory-mapped file is used via MappedByteBuffer. This allows the operating system to manage memory and load data on demand, without holding the entire file in the app’s memory. Essentially, the embedding file becomes part of the process’s virtual memory, which is crucial for handling large datasets on mobile devices.
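
The memory-mapping idea can be tried on a desktop before touching Kotlin. This sketch uses Python’s standard `mmap` and `struct` modules and assumes, like `FaissSearchModule`, that `embeddings.bin` is a headerless sequence of little-endian float32 values, 512 per vector. With that layout each vector takes 2,048 bytes, so a 248 MB file holds more than 120,000 vectors, and reading vector *i* only touches its own 2,048-byte slice. The function names are hypothetical.

```python
import mmap
import os
import struct
import tempfile
from typing import List

EMBEDDING_DIM = 512
BYTES_PER_FLOAT = 4  # little-endian float32, no header (assumed layout)


def open_embeddings(path: str) -> mmap.mmap:
    """Memory-map the file read-only; the OS pages data in on demand."""
    with open(path, "rb") as f:
        # The mapping stays valid after the file object is closed
        return mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)


def num_vectors(mm: mmap.mmap, dim: int = EMBEDDING_DIM) -> int:
    """Same formula as loadEmbeddings: file size / 4 bytes / dim."""
    return len(mm) // (dim * BYTES_PER_FLOAT)


def get_vector(mm: mmap.mmap, index: int, dim: int = EMBEDDING_DIM) -> List[float]:
    """Decode one vector; only its dim * 4 bytes are read from the mapping."""
    start = index * dim * BYTES_PER_FLOAT
    return list(struct.unpack_from(f"<{dim}f", mm, start))


# Usage with a tiny file: 3 vectors of dimension 4 instead of 512
fd, path = tempfile.mkstemp(suffix=".bin")
with os.fdopen(fd, "wb") as out:
    out.write(struct.pack("<12f", *[float(i) for i in range(12)]))

mm = open_embeddings(path)
count = num_vectors(mm, dim=4)
third = get_vector(mm, 2, dim=4)
print(count, third)
mm.close()
os.remove(path)
```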

The nearest-vector search is also performed entirely in native code, as an exhaustive scan that computes the cosine distance to every stored vector. Only the query embedding is sent from JavaScript, and the returned results are compact: indexes, distances, and disease labels. This minimizes the load on the JavaScript-to-native bridge and avoids transferring large arrays back and forth. A similar approach is used for captions.json: instead of parsing it as a JSON document, the file is read as a character stream and only the label strings are kept, so the raw file never has to be held in memory as one large string. These files are located in the Android project directory so that the native code can access them.
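
The captions reader in the module below is a small hand-written scanner rather than a full JSON parser: it assumes `captions.json` is a flat array of strings and collects every quoted string it sees. A Python port of that idea, reading one character at a time from a text stream, is handy for checking results against `json.loads` on a desktop. This version also decodes escape sequences such as `\n` and `\uXXXX` instead of copying the escape letter. `iter_json_strings` is a hypothetical name.

```python
import io
import json
from typing import Iterator, TextIO

_SIMPLE_ESCAPES = {'"': '"', "\\": "\\", "/": "/", "b": "\b",
                   "f": "\f", "n": "\n", "r": "\r", "t": "\t"}


def iter_json_strings(stream: TextIO) -> Iterator[str]:
    """Yield each string of a flat JSON array of strings, reading char by char."""
    buf = []
    in_string = False
    while True:
        ch = stream.read(1)
        if ch == "":
            break  # end of file
        if not in_string:
            if ch == '"':
                in_string = True
            elif ch == "]":
                break  # end of the array
            continue  # skip '[', ',' and whitespace
        if ch == "\\":
            esc = stream.read(1)
            if esc == "u":
                # \uXXXX; surrogate pairs are not combined in this sketch
                buf.append(chr(int(stream.read(4), 16)))
            else:
                buf.append(_SIMPLE_ESCAPES.get(esc, esc))
        elif ch == '"':
            yield "".join(buf)
            buf = []
            in_string = False
        else:
            buf.append(ch)


sample = json.dumps(["Leaf rust", 'Say "hi"', "line1\nline2", "caf\u00e9"])
print(list(iter_json_strings(io.StringIO(sample))))
```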

package com.berkano

import com.facebook.react.bridge.*
import java.io.*
import java.nio.ByteOrder
import kotlin.math.sqrt

/**
 * Native module for exhaustive (brute-force) cosine-similarity vector search.
 * Despite the name, it does not use the Faiss library.
 * Data is stored in native memory and does not enter the JavaScript heap
 */
class FaissSearchModule(
    private val reactContext: ReactApplicationContext
) : ReactContextBaseJavaModule(reactContext) {

    // Store data in native memory
    // Use MappedByteBuffer instead of FloatArray for large files
    // This allows avoiding loading all data into RAM at once
    private var embeddingsBuffer: java.nio.MappedByteBuffer? = null
    private var captions: List<String>? = null
    private var numVectors: Int = 0
    private var embeddingSize: Int = 512
    private var embeddingsFile: File? = null

    override fun getName(): String {
        return "FaissSearch"
    }

    /**
     * Loads embeddings from a binary file using a memory-mapped file
     * Does NOT load all data into RAM, reads on demand
     */
    @ReactMethod
    fun loadEmbeddings(embeddingsPath: String, promise: Promise) {
        try {
            val file = File(embeddingsPath)
            if (!file.exists()) {
                promise.reject("FILE_NOT_FOUND", "Embeddings file not found: $embeddingsPath")
                return
            }

            val fileSize = file.length()
            numVectors = (fileSize / 4 / embeddingSize).toInt() // 4 bytes per float

            // Use MappedByteBuffer — this does NOT fully load the file into RAM
            // The operating system manages memory and loads data on demand
            FileInputStream(file).use { fis ->
                val channel = fis.channel
                val byteBuffer = channel.map(
                    java.nio.channels.FileChannel.MapMode.READ_ONLY,
                    0,
                    fileSize
                )
                byteBuffer.order(ByteOrder.LITTLE_ENDIAN)
                embeddingsBuffer = byteBuffer
            }

            embeddingsFile = file

            val result = Arguments.createMap()
            result.putInt("numVectors", numVectors)
            result.putInt("embeddingSize", embeddingSize)
            result.putLong("fileSize", fileSize)

            promise.resolve(result)
            println("Mapped $numVectors vectors of size $embeddingSize from file (${fileSize / 1024 / 1024}MB)")
        } catch (e: OutOfMemoryError) {
            promise.reject(
                "OUT_OF_MEMORY",
                "Not enough memory to map embeddings file. File too large: ${e.message}",
                e
            )
        } catch (e: Exception) {
            promise.reject("LOAD_ERROR", "Failed to load embeddings: ${e.message}", e)
        }
    }

    /**
     * Reads a vector by index from the memory-mapped file
     * Does not load all data into memory
     */
    private fun getVector(index: Int): FloatArray {
        if (embeddingsBuffer == null) {
            throw IllegalStateException("Embeddings not loaded")
        }

        val startByte = index * embeddingSize * 4 // 4 bytes per float
        val vector = FloatArray(embeddingSize)

        embeddingsBuffer!!.position(startByte)
        embeddingsBuffer!!.asFloatBuffer().get(vector)

        return vector
    }

    /**
     * Loads captions from a JSON file
     * Uses streaming parsing for large files
     */
    @ReactMethod
    fun loadCaptions(captionsPath: String, promise: Promise) {
        try {
            val file = File(captionsPath)
            if (!file.exists()) {
                promise.reject("FILE_NOT_FOUND", "Captions file not found: $captionsPath")
                return
            }

            // Use BufferedReader for efficient reading of large files
            val captionsList = mutableListOf<String>()
            var buffer = StringBuilder()
            var inString = false
            var escapeNext = false

            BufferedReader(FileReader(file), 8192).use { reader ->
                var char: Int
                while (reader.read().also { char = it } != -1) {
                    when {
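                        escapeNext -> {
                            // Decode JSON escape sequences instead of copying the escape letter verbatim
                            when (val c = char.toChar()) {
                                'n' -> buffer.append('\n')
                                't' -> buffer.append('\t')
                                'r' -> buffer.append('\r')
                                'b' -> buffer.append('\b')
                                'f' -> buffer.append('\u000C')
                                'u' -> {
                                    // \uXXXX: read the next four hex digits
                                    val hex = CharArray(4)
                                    reader.read(hex, 0, 4)
                                    buffer.append(String(hex).toInt(16).toChar())
                                }
                                else -> buffer.append(c) // \" \\ \/
                            }
                            escapeNext = false
                        }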
                        char == '\\'.code -> {
                            escapeNext = true
                        }
                        char == '"'.code -> {
                            if (inString) {
                                // End of string
                                captionsList.add(buffer.toString())
                                buffer = StringBuilder()
                                inString = false
                            } else {
                                // Start of string
                                inString = true
                            }
                        }
                        inString -> {
                            buffer.append(char.toChar())
                        }
                        char == ']'.code -> {
                            // End of array
                            break
                        }
                    }
                }
            }

            captions = captionsList
            val result = Arguments.createMap()
            result.putInt("count", captionsList.size)

            promise.resolve(result)
            println("Loaded ${captionsList.size} captions")
        } catch (e: OutOfMemoryError) {
            promise.reject(
                "OUT_OF_MEMORY",
                "Not enough memory to load captions. File too large: ${e.message}",
                e
            )
        } catch (e: Exception) {
            promise.reject("LOAD_ERROR", "Failed to load captions: ${e.message}", e)
        }
    }

    /**
     * Computes cosine distance between two vectors
     */
    private fun cosineDistance(vecA: FloatArray, vecB: FloatArray): Float {
        var dotProduct = 0f
        var normA = 0f
        var normB = 0f

        for (i in vecA.indices) {
            dotProduct += vecA[i] * vecB[i]
            normA += vecA[i] * vecA[i]
            normB += vecB[i] * vecB[i]
        }

        normA = sqrt(normA)
        normB = sqrt(normB)

        if (normA == 0f || normB == 0f) {
            return 1f
        }

        val similarity = dotProduct / (normA * normB)
        return 1f - similarity
    }

    /**
     * Finds top-K nearest vectors to the query
     * Search is performed in native memory, only results are sent to JS
     */
    @ReactMethod
    fun search(
        queryEmbedding: ReadableArray,
        topK: Int,
        promise: Promise
    ) {
        try {
            if (embeddingsBuffer == null) {
                promise.reject("NOT_LOADED", "Embeddings not loaded. Call loadEmbeddings first.")
                return
            }

            if (captions == null) {
                promise.reject("NOT_LOADED", "Captions not loaded. Call loadCaptions first.")
                return
            }

            // Convert query from ReadableArray to FloatArray
            val query = FloatArray(queryEmbedding.size())
            for (i in 0 until queryEmbedding.size()) {
                query[i] = queryEmbedding.getDouble(i).toFloat()
            }

            if (query.size != embeddingSize) {
                promise.reject(
                    "INVALID_SIZE",
                    "Query embedding size ${query.size} != $embeddingSize"
                )
                return
            }

            // Compute distances for all vectors
            // Vectors are read on demand from the memory-mapped file
            val distances = mutableListOf<Pair<Int, Float>>()

            for (i in 0 until numVectors) {
                // Read vector from file on demand
                val vector = getVector(i)
                val distance = cosineDistance(query, vector)
                distances.add(Pair(i, distance))
            }

            // Sort and take top-K
            distances.sortBy { it.second }
            val topResults = distances.take(topK)

            // Prepare results to send to JS
            val results = Arguments.createArray()
            for ((index, distance) in topResults) {
                val resultItem = Arguments.createMap()
                resultItem.putInt("index", index)
                resultItem.putDouble("distance", distance.toDouble())
                resultItem.putString("caption", captions!![index])
                results.pushMap(resultItem)
            }

            promise.resolve(results)
        } catch (e: Exception) {
            promise.reject("SEARCH_ERROR", "Search failed: ${e.message}", e)
        }
    }

    /**
     * Clears loaded data from memory
     */
    @ReactMethod
    fun clearCache(promise: Promise) {
        embeddingsBuffer = null
        embeddingsFile = null
        captions = null
        numVectors = 0
        promise.resolve(null)
    }

    /**
     * Checks whether data is loaded
     */
    @ReactMethod
    fun isLoaded(promise: Promise) {
        val result = Arguments.createMap()
        result.putBoolean("embeddingsLoaded", embeddingsBuffer != null)
        result.putBoolean("captionsLoaded", captions != null)
        result.putInt("numVectors", numVectors)
        promise.resolve(result)
    }
}

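
The `search` method above scores every vector and then sorts the full list of distances, which costs O(n log n) time and keeps n pairs in memory. Since only `topK` results are returned, a bounded selection is enough. Here is a sketch of that alternative in Python with `heapq.nsmallest` (O(n log k)); it is a possible refinement, not what the module currently does, and `top_k_nearest` is a hypothetical name. In Kotlin, the same idea maps to a size-limited `java.util.PriorityQueue`.

```python
import heapq
import math
from typing import Iterable, List, Sequence, Tuple


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity; returns 1.0 for zero vectors, like the Kotlin code."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 1.0
    return 1.0 - dot / (norm_a * norm_b)


def top_k_nearest(
    query: Sequence[float], vectors: Iterable[Sequence[float]], k: int
) -> List[Tuple[int, float]]:
    """Return (index, distance) pairs for the k smallest distances, closest first."""
    scored = ((i, cosine_distance(query, v)) for i, v in enumerate(vectors))
    # nsmallest keeps at most k candidates in a heap instead of sorting everything
    return heapq.nsmallest(k, scored, key=lambda item: item[1])


vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
print(top_k_nearest([1.0, 0.1], vectors, k=2))
```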

The FaissSearchPackage plays a supporting role in this setup. Its purpose is to register the native module in React Native and inform the JavaScript part of the app that FaissSearchModule is available. All the actual logic and computation happens inside the module itself; the package serves purely as an infrastructure layer to connect it to the React Native framework.

package com.berkano

import com.facebook.react.ReactPackage
import com.facebook.react.bridge.NativeModule
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.uimanager.ViewManager

class FaissSearchPackage : ReactPackage {
    override fun createNativeModules(
        reactContext: ReactApplicationContext
    ): List<NativeModule> {
        return listOf(FaissSearchModule(reactContext))
    }

    override fun createViewManagers(
        reactContext: ReactApplicationContext
    ): List<ViewManager<*, *>> {
        return emptyList()
    }
}

Besides working with the embedding files, there’s another problem that can’t be solved purely in JavaScript. All ONNX models in the project expect a tensor of fixed shape and format as input, whereas in React Native an image usually comes as a URI or a base64 string. Converting an image to a numeric tensor in JavaScript is possible, but in practice it either runs slowly or consumes a lot of memory, especially for camera images. For this reason, input preprocessing is also moved to the native layer.

For this, a separate native module, ImageDecoderModule, was implemented. Its role is to take an image URI, load it via ContentResolver, decode it into a Bitmap, resize it to the required dimensions, and convert it into a CHW-format tensor expected by the PyTorch-trained models exported to ONNX. At this stage, pixel values are normalized to the [0,1] range, and RGB channels are explicitly separated so that the resulting array exactly matches the model input.
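
The per-pixel loop in `decodeToTensor` (in the module below) unpacks Android’s packed ARGB integers and writes each channel into its own plane. A Python transcription of that loop is useful for producing reference tensors on a desktop. The optional mean/std step is an assumption, not part of the module: many ImageNet-pretrained backbones expect it, so check what each model was trained with before enabling it. `argb_to_chw` is a hypothetical name.

```python
from typing import List, Optional, Sequence

# Commonly used ImageNet statistics (assumption: the module itself only scales to [0, 1])
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def argb_to_chw(
    pixels: Sequence[int],
    width: int,
    height: int,
    mean: Optional[Sequence[float]] = None,
    std: Optional[Sequence[float]] = None,
) -> List[float]:
    """Convert row-major packed ARGB ints to a flat [3, H, W] float list."""
    plane = width * height
    tensor = [0.0] * (3 * plane)
    for idx, color in enumerate(pixels[:plane]):
        rgb = ((color >> 16) & 0xFF, (color >> 8) & 0xFF, color & 0xFF)
        for c, value in enumerate(rgb):
            v = value / 255.0
            if mean is not None and std is not None:
                v = (v - mean[c]) / std[c]
            # Channel c occupies the slice [c * plane, (c + 1) * plane)
            tensor[c * plane + idx] = v
    return tensor


# A 2x1 image: one opaque red pixel, one opaque blue pixel
print(argb_to_chw([0xFFFF0000, 0xFF0000FF], width=2, height=1))
```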

Importantly, all of this logic runs on the Android side before any data reaches JavaScript. React Native cannot pass a FloatArray across the bridge directly, so the tensor is copied into a WritableArray, which the bridge can transfer. The values still arrive in JavaScript as a regular array of numbers, but decoding, resizing, and channel separation never touch the JS engine, so no image bytes or intermediate pixel buffers are allocated there. This significantly reduces memory usage and speeds up data preparation for inference.
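
How much data crosses the bridge is easy to estimate: a [3, 224, 224] tensor is 150,528 values and a [3, 518, 518] tensor (the classifier’s input size) is 804,972, each pushed as a separate double. One alternative, not used in this project, is to send the raw little-endian float32 bytes as a single base64 string and decode it back into a `Float32Array` in JavaScript (this needs a base64 decoder on the JS side). The sketch below shows the packing and its size cost with the Python standard library; the function names are hypothetical.

```python
import base64
import struct
from typing import List, Sequence


def tensor_elements(channels: int, height: int, width: int) -> int:
    """Number of values in a CHW tensor."""
    return channels * height * width


def pack_float32_base64(values: Sequence[float]) -> str:
    """Little-endian float32 bytes, base64-encoded (4 chars per 3 bytes)."""
    raw = struct.pack(f"<{len(values)}f", *values)
    return base64.b64encode(raw).decode("ascii")


def unpack_float32_base64(text: str) -> List[float]:
    """Inverse of pack_float32_base64."""
    raw = base64.b64decode(text)
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))


for side in (224, 518):
    n = tensor_elements(3, side, side)
    print(f"{side}x{side}: {n} values, {len(pack_float32_base64([0.0] * n))} base64 chars")

encoded = pack_float32_base64([0.0, 0.5, 1.0])
print(encoded, unpack_float32_base64(encoded))
```

The trade-off is one string of about 5.3 characters per value versus hundreds of thousands of individually boxed doubles; whether it pays off depends on the React Native version and bridge in use.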

package com.berkano

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import com.facebook.react.bridge.*
import java.io.InputStream

class ImageDecoderModule(
  private val reactContext: ReactApplicationContext
) : ReactContextBaseJavaModule(reactContext) {

  override fun getName(): String {
    return "ImageDecoder"
  }

  @ReactMethod
  fun decodeToTensor(
    uriString: String,
    targetWidth: Int,
    targetHeight: Int,
    promise: Promise
  ) {
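    try {
      val uri = Uri.parse(uriString)
      val inputStream: InputStream? =
        reactContext.contentResolver.openInputStream(uri)

      if (inputStream == null) {
        promise.reject("ERROR", "Cannot open image stream")
        return
      }

      // use {} closes the stream whether decoding succeeds or throws
      val decoded: Bitmap? = inputStream.use { BitmapFactory.decodeStream(it) }
      if (decoded == null) {
        promise.reject("ERROR", "Cannot decode image")
        return
      }

      // Resize to the model input size (the aspect ratio is not preserved)
      val bitmap = Bitmap.createScaledBitmap(decoded, targetWidth, targetHeight, true)
      if (bitmap !== decoded) {
        // createScaledBitmap returned a new bitmap, so the full-size one can be freed now
        decoded.recycle()
      }

      val width = bitmap.width
      val height = bitmap.height

      val pixels = IntArray(width * height)
      bitmap.getPixels(pixels, 0, width, 0, 0, width, height)
      // The pixels are now copied into the IntArray, so the scaled bitmap can be released
      bitmap.recycle()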

      // CHW: [3, H, W] — format used for passing data to JavaScript
      val tensor = FloatArray(3 * width * height)

      for (y in 0 until height) {
        for (x in 0 until width) {
          val color = pixels[y * width + x]

          val r = (color shr 16 and 0xFF) / 255f
          val g = (color shr 8 and 0xFF) / 255f
          val b = (color and 0xFF) / 255f

          val idx = y * width + x
          tensor[idx] = r
          tensor[width * height + idx] = g
          tensor[2 * width * height + idx] = b
        }
      }

      // Convert FloatArray to WritableArray for proper transfer to JavaScript
      // React Native cannot directly convert FloatArray, so WritableArray is used
      val writableArray = Arguments.createArray()
      for (value in tensor) {
        // pushDouble accepts double, so convert float to double
        writableArray.pushDouble(value.toDouble())
      }

      promise.resolve(writableArray)
    } catch (e: Exception) {
      promise.reject("ERROR_DECODING_IMAGE", e)
    }
  }
}


Just like with the vector search module, ImageDecoderModule is registered via a separate ImageDecoderPackage, which is then included in MainApplication.kt. Without this, React Native won’t recognize the module. Additionally, the required permissions must be explicitly declared in AndroidManifest.xml; otherwise, the app won’t be able to access images. All these details are available in the project repository, in the Android part of the mobile application.

package com.berkano;

import com.facebook.react.ReactPackage
import com.facebook.react.bridge.NativeModule
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.uimanager.ViewManager

class ImageDecoderPackage : ReactPackage {

  override fun createNativeModules(
    reactContext: ReactApplicationContext
  ): List<NativeModule> {
    return listOf(ImageDecoderModule(reactContext))
  }

  override fun createViewManagers(
    reactContext: ReactApplicationContext
  ): List<ViewManager<*, *>> {
    return emptyList()
  }
}

In the end, all heavy and memory-sensitive operations remain on the native side, while JavaScript only works with preprocessed data and inference results.
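On the JavaScript side, the module can be accessed as NativeModules.ImageDecoder, which is the name returned by getName(). Below is a minimal sketch of the remaining JS step. It uses a hypothetical helper named toImageTensor, which is not part of the repository. The helper checks that the native array contains 3 * size * size values and copies them into the Float32Array that the inference code expects.

```typescript
/**
 * Hypothetical helper: validates the number[] returned by
 * NativeModules.ImageDecoder.decodeToTensor(uri, size, size) and converts it into
 * a Float32Array suitable for new ort.Tensor('float32', data, [1, 3, size, size]).
 */
export function toImageTensor(values: ReadonlyArray<number>, size: number): Float32Array {
  const expected = 3 * size * size;
  if (values.length !== expected) {
    throw new Error(`Expected ${expected} values (3 x ${size} x ${size}), got ${values.length}`);
  }
  // The bridge delivers a plain JS array of doubles; this is the single copy into typed memory
  return Float32Array.from(values);
}

// Usage inside the app (requires React Native, so it is shown as a comment):
// const raw: number[] = await NativeModules.ImageDecoder.decodeToTensor(uri, 224, 224);
// const imageTensor = toImageTensor(raw, 224);
```

If targetWidth or targetHeight does not match the model, the length check reports it as a clear error before the data reaches ort.Tensor.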

To run the models on the device, I created a separate file that copies the model assets, initializes the models, and runs inference. On the first launch, the code checks whether the model files and related assets already exist in the app's document directory. If they don't, copyAssetIfNeeded copies them there, from the assets folder on Android or from the main bundle on iOS. The copy step matters on Android: files in the assets folder are packed inside the APK rather than stored as regular files, while the ONNX session here is created from a file path. RNFS.copyFileAssets is the react-native-fs call that extracts them.

The prepareModelAssets function gathers the paths to the model, embeddings, and auxiliary files and returns them as an object for later use. initializeModel then creates an ONNX Runtime session with NNAPI listed as the first execution provider. This lets supported operations be offloaded to the device's GPU, DSP, or NPU through Android's NNAPI drivers, while the CPU provider handles everything else.

runInference takes a Float32Array holding the image tensor prepared by ImageDecoderModule. It wraps the array in a tensor of the shape the model expects ([1, 3, 224, 224] for disease detection), runs the model, and returns the L2-normalized embedding.

import * as ort from 'onnxruntime-react-native';
import RNFS from 'react-native-fs';
import { Platform } from 'react-native';

async function copyAssetIfNeeded(assetName: string, folder: string): Promise<string> {
  const localPath = `${RNFS.DocumentDirectoryPath}/${assetName}`;

  const exists = await RNFS.exists(localPath);
  if (!exists) {
    try {
      if (Platform.OS === 'android') {
        // Android: read from assets
        // Path must be relative to the assets folder: folder/assetName
        const assetPath = `${folder}/${assetName}`;
        console.log(`Attempting to copy from assets: ${assetPath} to ${localPath}`);
        await RNFS.copyFileAssets(assetPath, localPath);
        console.log(`Successfully copied ${assetName} to local FS: ${localPath}`);
      } else {
        // iOS: read from the main bundle
        const source = `${RNFS.MainBundlePath}/${folder}/${assetName}`;
        console.log(`Attempting to copy from bundle: ${source} to ${localPath}`);
        await RNFS.copyFile(source, localPath);
        console.log(`Successfully copied ${assetName} to local FS: ${localPath}`);
      }
    } catch (err: any) {
      console.error(`Failed to copy ${assetName} from ${folder}/${assetName}:`, err);
      console.error(`Error details:`, {
        message: err?.message,
        code: err?.code,
        platform: Platform.OS,
        localPath,
        assetPath: `${folder}/${assetName}`,
      });
      // Try an alternative path without the folder (in case the file is in the root of assets)
      if (Platform.OS === 'android') {
        try {
          console.log(`Trying alternative path: ${assetName}`);
          await RNFS.copyFileAssets(assetName, localPath);
          console.log(`Successfully copied using alternative path`);
        } catch (altErr) {
          console.error(`Alternative path also failed:`, altErr);
          throw new Error(
            `Failed to copy ${assetName}. Please verify that the file is located in assets/${folder}/`
          );
        }
      } else {
        throw err;
      }
    }
  } else {
    console.log(`${assetName} already exists at ${localPath}`);
  }

  return localPath;
}

export async function prepareModelAssets() {
  const modelPath = await copyAssetIfNeeded('disease_detection.onnx', 'models');
  const embeddingsPath = await copyAssetIfNeeded('embeddings.bin', 'files');
  const captionsPath = await copyAssetIfNeeded('captions.json', 'files');
  const classMappingPath = await copyAssetIfNeeded('class_mapping.txt', 'files');
  const speciesMappingPath = await copyAssetIfNeeded('species_id_to_name.txt', 'files');

  return { 
    modelPath, 
    embeddingsPath, 
    captionsPath,
    classMappingPath,
    speciesMappingPath,
  };
}

export interface ModelSession {
  session: ort.InferenceSession;
  inputName: string;
  outputName: string;
}

let modelSession: ModelSession | null = null;

/**
 * Initializes the ONNX model with GPU/NPU priority via NNAPI
 */
export async function initializeModel(): Promise<ModelSession> {
  if (modelSession) {
    return modelSession;
  }

  try {
    // Use the path obtained from prepareModelAssets
    const { modelPath } = await prepareModelAssets();

    // Use the local model path
    const modelUri = modelPath;

    // Session options with NNAPI priority (GPU/NPU)
    const sessionOptions: ort.InferenceSession.SessionOptions = {
      executionProviders: ['nnapi', 'cpu'], // NNAPI for GPU/NPU, CPU as fallback
      graphOptimizationLevel: 'all',
    };

    // Create the inference session
    // onnxruntime-react-native can work directly with file paths
    const session = await ort.InferenceSession.create(modelUri, sessionOptions);

    // Retrieve input and output names
    const inputNames = session.inputNames;
    const outputNames = session.outputNames;

    if (inputNames.length === 0 || outputNames.length === 0) {
      throw new Error('The model has no inputs or outputs');
    }

    modelSession = {
      session,
      inputName: inputNames[0],  // Assume the first input
      outputName: outputNames[0], // Assume the first output
    };

    console.log('Model successfully loaded');
    console.log('Input name:', modelSession.inputName);
    console.log('Output name:', modelSession.outputName);

    return modelSession;
  } catch (error) {
    console.error('Error while loading the model:', error);
    throw error;
  }
}

/**
 * Runs model inference on an image
 * @param imageTensor Image tensor in [1, 3, 224, 224] format
 * @returns Image embedding
 */
export async function runInference(
  imageTensor: Float32Array,
): Promise<Float32Array> {
  if (!modelSession) {
    throw new Error('Model is not initialized. Call initializeModel() first.');
  }

  try {
    // Create input tensor
    // Format: [batch, channels, height, width] = [1, 3, 224, 224]
    const inputTensor = new ort.Tensor('float32', imageTensor, [1, 3, 224, 224]);

    // Run inference
    const feeds = { [modelSession.inputName]: inputTensor };
    const results = await modelSession.session.run(feeds);

    // Retrieve output tensor
    const outputTensor = results[modelSession.outputName];
    const embedding = outputTensor.data as Float32Array;

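    // L2-normalize the embedding (a plain loop avoids copying it into a JS array first)
    let sumOfSquares = 0;
    for (let i = 0; i < embedding.length; i++) {
      sumOfSquares += embedding[i] * embedding[i];
    }
    // Guard against division by zero for an all-zero output
    const norm = Math.sqrt(sumOfSquares) || 1;
    const normalizedEmbedding = new Float32Array(embedding.length);
    for (let i = 0; i < embedding.length; i++) {
      normalizedEmbedding[i] = embedding[i] / norm;
    }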

    return normalizedEmbedding;
  } catch (error) {
    console.error('Error during inference:', error);
    throw error;
  }
}

/**
 * Releases model resources
 */
export function disposeModel(): void {
  if (modelSession) {
    modelSession.session.release();
    modelSession = null;
  }
}


To manage multiple models, I created modelManager.ts. It registers the different models (disease, plant, age) and loads them into memory on demand. At most two models are kept loaded at the same time to avoid overloading the device. When the limit is reached, the least recently used model is released before the next one is loaded.
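The eviction policy can also be isolated from ONNX Runtime. The following is a design alternative to the timestamp scan in modelManager.ts, not the project's code. It relies on the fact that a JavaScript Map iterates in insertion order, so re-inserting a key on every access keeps the least recently used entry first. LruPool is a hypothetical name, and its release callback stands in for InferenceSession.release().

```typescript
/** Hypothetical policy-only pool: keeps at most `capacity` entries, evicting the least recently used. */
export class LruPool<K, V> {
  private readonly items = new Map<K, V>();
  private readonly capacity: number;
  private readonly release: (key: K, value: V) => void;

  constructor(capacity: number, release: (key: K, value: V) => void) {
    this.capacity = capacity;
    this.release = release;
  }

  get(key: K): V | undefined {
    const value = this.items.get(key);
    if (value !== undefined) {
      // Re-insert to mark the entry as most recently used
      this.items.delete(key);
      this.items.set(key, value);
    }
    return value;
  }

  set(key: K, value: V): void {
    if (this.items.has(key)) {
      // Refreshing an existing key only updates its position (the old value is not released)
      this.items.delete(key);
    } else if (this.items.size >= this.capacity) {
      // The first key in iteration order is the least recently used one
      const oldest = this.items.keys().next().value as K;
      this.release(oldest, this.items.get(oldest) as V);
      this.items.delete(oldest);
    }
    this.items.set(key, value);
  }

  keys(): K[] {
    return Array.from(this.items.keys());
  }
}
```

The Map version evicts in constant time and can be unit-tested without onnxruntime. With only three registered models, the linear scan over timestamps in modelManager.ts works just as well.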

import * as ort from 'onnxruntime-react-native';
import RNFS from 'react-native-fs';
import { Platform } from 'react-native';

export type ModelType = 'disease' | 'plant' | 'age';

export interface ModelInfo {
  type: ModelType;
  path: string;
  session: ort.InferenceSession | null;
  inputName: string;
  outputName: string;
  isLoaded: boolean;
  lastUsed: number;
}

// Maximum number of models allowed in memory at the same time
const MAX_MODELS_IN_MEMORY = 2;

// Model cache
const modelCache: Map<ModelType, ModelInfo> = new Map();

/**
 * Copies a file from assets to the local file system
 */
async function copyAssetIfNeeded(assetName: string, folder: string): Promise<string> {
  const localPath = `${RNFS.DocumentDirectoryPath}/${assetName}`;

  const exists = await RNFS.exists(localPath);
  if (!exists) {
    try {
      if (Platform.OS === 'android') {
        const assetPath = `${folder}/${assetName}`;
        console.log(`Copying ${assetName} from assets...`);
        await RNFS.copyFileAssets(assetPath, localPath);
        console.log(`Successfully copied ${assetName}`);
      } else {
        const source = `${RNFS.MainBundlePath}/${folder}/${assetName}`;
        await RNFS.copyFile(source, localPath);
      }
    } catch (err: any) {
      console.error(`Failed to copy ${assetName}:`, err);
      if (Platform.OS === 'android') {
        try {
          await RNFS.copyFileAssets(assetName, localPath);
        } catch (altErr) {
          throw new Error(`Failed to copy ${assetName}`);
        }
      } else {
        throw err;
      }
    }
  }

  return localPath;
}

/**
 * Unloads the least recently used model from memory
 */
function unloadLeastUsedModel(): void {
  let leastUsed: ModelInfo | null = null;
  // Start from +Infinity so any loaded model qualifies, even one used in this same millisecond
  let leastUsedTime = Number.POSITIVE_INFINITY;

  for (const model of modelCache.values()) {
    if (model.isLoaded && model.lastUsed < leastUsedTime) {
      leastUsed = model;
      leastUsedTime = model.lastUsed;
    }
  }

  if (leastUsed) {
    console.log(`Unloading model: ${leastUsed.type}`);
    // Release the session (if any) and always mark the slot as free
    leastUsed.session?.release();
    leastUsed.session = null;
    leastUsed.isLoaded = false;
  }
}

/**
 * Loads a model into memory
 */
async function loadModel(modelInfo: ModelInfo): Promise<void> {
  if (modelInfo.isLoaded && modelInfo.session) {
    modelInfo.lastUsed = Date.now();
    return;
  }

  // Check if the memory limit for loaded models has been reached
  const loadedModels = Array.from(modelCache.values()).filter(m => m.isLoaded);
  if (loadedModels.length >= MAX_MODELS_IN_MEMORY) {
    console.log('Memory limit reached, unloading least recently used model...');
    unloadLeastUsedModel();
  }

  try {
    console.log(`Loading model: ${modelInfo.type}`);

    const sessionOptions: ort.InferenceSession.SessionOptions = {
      executionProviders: ['nnapi', 'cpu'],
      graphOptimizationLevel: 'all',
    };

    const session = await ort.InferenceSession.create(modelInfo.path, sessionOptions);

    const inputNames = session.inputNames;
    const outputNames = session.outputNames;

    if (inputNames.length === 0 || outputNames.length === 0) {
      throw new Error('The model has no inputs or outputs');
    }

    modelInfo.session = session;
    modelInfo.inputName = inputNames[0];
    modelInfo.outputName = outputNames[0];
    modelInfo.isLoaded = true;
    modelInfo.lastUsed = Date.now();

    console.log(`Model ${modelInfo.type} loaded successfully`);
  } catch (error) {
    console.error(`Error loading model ${modelInfo.type}:`, error);
    throw error;
  }
}

/**
 * Initializes the model management system
 */
export async function initializeModelManager(): Promise<void> {
  // Prepare model paths
  const diseaseModelPath = await copyAssetIfNeeded('disease_detection.onnx', 'models');
  const plantModelPath = await copyAssetIfNeeded('plant_classification.onnx', 'models');
  const ageModelPath = await copyAssetIfNeeded('plant_analysis.onnx', 'models');

  // Register models
  modelCache.set('disease', {
    type: 'disease',
    path: diseaseModelPath,
    session: null,
    inputName: '',
    outputName: '',
    isLoaded: false,
    lastUsed: 0,
  });

  modelCache.set('plant', {
    type: 'plant',
    path: plantModelPath,
    session: null,
    inputName: '',
    outputName: '',
    isLoaded: false,
    lastUsed: 0,
  });

  modelCache.set('age', {
    type: 'age',
    path: ageModelPath,
    session: null,
    inputName: '',
    outputName: '',
    isLoaded: false,
    lastUsed: 0,
  });

  console.log('Model manager initialized');
}

/**
 * Retrieves a model (loads it if necessary)
 */
export async function getModel(modelType: ModelType): Promise<ModelInfo> {
  const modelInfo = modelCache.get(modelType);
  if (!modelInfo) {
    throw new Error(`Model ${modelType} not found`);
  }

  await loadModel(modelInfo);
  return modelInfo;
}

/**
 * Unloads a model from memory
 */
export function unloadModel(modelType: ModelType): void {
  const modelInfo = modelCache.get(modelType);
  if (modelInfo && modelInfo.session) {
    console.log(`Unloading model: ${modelType}`);
    modelInfo.session.release();
    modelInfo.session = null;
    modelInfo.isLoaded = false;
  }
}

/**
 * Unloads all models from memory
 */
export function unloadAllModels(): void {
  for (const modelType of modelCache.keys()) {
    unloadModel(modelType);
  }
}

/**
 * Checks whether a model is loaded
 */
export function isModelLoaded(modelType: ModelType): boolean {
  const modelInfo = modelCache.get(modelType);
  return modelInfo?.isLoaded ?? false;
}


For loading and running the models, I use the same approach with onnxruntime-react-native. After obtaining the modelInfo object via getModel:

const modelInfo = await getModel('age');

a tensor of shape [1, 3, n, n] is built from the image data, where n is the input size the given model expects:

const inputTensor = new ort.Tensor('float32', imageTensor, [1, 3, n, n]);

This tensor is passed to the model in a feeds object keyed by the input name. session.run resolves with the output tensors, keyed by output name.

const feeds = { [modelInfo.inputName]: inputTensor };
const results = await modelInfo.session!.run(feeds); // session is non-null after getModel()

The plantClassificationService.ts and diseaseSearchService.ts files show how to obtain embeddings from the models, normalize them, and run a vector search with Faiss. The same approach carries over to any mobile project that combines large models with vector search.
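The sketch below is a brute-force search in TypeScript that illustrates why the embeddings are normalized before the search. For unit-length vectors, the inner product equals cosine similarity. Ranking by inner product therefore gives the same order a Faiss inner-product index (such as IndexFlatIP) would produce. This is a linear scan, not Faiss and not the project's code, and the function names are hypothetical. The database is assumed to be a flat Float32Array with dim values per entry. How embeddings.bin is actually laid out is not shown here.

```typescript
/** Hypothetical helper: L2-normalizes a vector in place; an all-zero vector is left unchanged. */
export function l2NormalizeInPlace(v: Float32Array): Float32Array {
  let sum = 0;
  for (let i = 0; i < v.length; i++) sum += v[i] * v[i];
  const norm = Math.sqrt(sum);
  if (norm > 0) {
    for (let i = 0; i < v.length; i++) v[i] /= norm;
  }
  return v;
}

/**
 * Hypothetical brute-force inner-product search.
 * Assumes `database` stores count = database.length / dim normalized vectors, row after row.
 */
export function topKInnerProduct(
  query: Float32Array,
  database: Float32Array,
  dim: number,
  k: number,
): Array<{ index: number; score: number }> {
  if (query.length !== dim || database.length % dim !== 0) {
    throw new Error("Dimension mismatch between query and database");
  }
  const count = database.length / dim;
  const scored: Array<{ index: number; score: number }> = [];
  for (let row = 0; row < count; row++) {
    const offset = row * dim;
    let dot = 0;
    for (let j = 0; j < dim; j++) dot += query[j] * database[offset + j];
    scored.push({ index: row, score: dot });
  }
  // Highest similarity first
  scored.sort((a, b) => b.score - a.score);
  return scored.slice(0, k);
}
```

A linear scan costs count * dim multiply-adds per query. That should stay cheap for a few thousand 512-dimensional vectors, and a proper index pays off as the collection grows.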

A few words about the project itself and its purpose. At first glance, it might seem like just a collection of neural networks for plants, running on mobile devices with some optimizations. In reality, it's an experiment in porting complex models to resource-constrained devices, managing memory efficiently, and building unified code for different types of tasks. I wrote this article not only for readers but also as notes for myself.

If you’ve made it to the end of this article, it means you enjoy diving into the details and aren’t afraid of challenges. The project is open-source under the MIT license, and all code is available on GitHub, so anyone can try it out, fine-tune models, optimize performance, or suggest new features. Finally, I invite everyone interested to join as contributors: every idea, optimization, or small feature makes the project better and more interesting. Sometimes a small experiment leads to unexpected discoveries, and this project is all about experimenting and finding those discoveries.
