On-the-fly zstandard (de)compression in Elixir

Notes on building Elixir bindings for the zstandard C library.

We use Zig/Zigler to wrap the zstandard streaming primitives and integrate them into Elixir.

Why Zig? We have the amazing Zigler library, which automatically generates NIFs for us from Zig code, and Zig integrates seamlessly with C libraries.

Why zstd? It offers excellent compression ratios and very fast decompression.

What are we doing? We focus on the streaming implementation. We will:

  • stream files with on-the-fly (de)compression
  • make HTTP streams with on-the-fly (de)compression

Streaming implementation: trade-offs and dirty schedulers

Zstd compression: the computational cost depends on the compression level and strategy, even with a straightforward streaming implementation.
For example, compressing a 100 kB PNG image takes 6.6 ms at the highest compression level 22, 0.7 ms at level 9, and 0.07 ms at the lowest level 1 (with the fastest algorithm). Since a regular NIF should return within about 1 ms, the compression NIF needs to be declared as dirty_cpu, and this setting is fixed at compile time.

Zstd decompression: significantly faster, but streaming requires careful implementation because of the misalignment between chunks and frames, explained below.

In terms of speed, decompression is significantly faster than compression: it takes roughly 0.1 ms to decompress the previously compressed PNG image (90 kB -> 100 kB).

Whilst decompressing chunks of ~100 kB stays well below 1 ms - which is our focus here - a one-shot decompression of a 9 MB file can take up to 19 ms. That is still reasonably fast, but it cannot safely run as a regular NIF. You may need to set dirty_cpu for one-shot decompression, whilst chunk-by-chunk decompression can safely run synchronously.
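
To check these numbers on your own data, you can time a NIF call from the Elixir side with :timer.tc (a minimal sketch; it assumes a cctx created with cctx_init/2, shown later):

# :timer.tc returns the duration in microseconds
data = File.read!("image.png")

{micros, {:ok, {_compressed, _consumed, _remaining}}} =
  :timer.tc(fn -> ExZstandard.compress_stream(cctx, data, :flush) end)

IO.puts("compress_stream took #{micros / 1000} ms")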

The Chunk/Frame misalignment problem

We use zstd's recommended chunk size (~128KB) obtained programmatically.
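
On the Elixir side this is a single call to a tiny NIF (recommended_c_in_size/0 in our module; the underlying zstd primitive is ZSTD_CStreamInSize()):

# ~128 kB: zstd's recommended input buffer size for streaming
chunk_size = ExZstandard.recommended_c_in_size()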

Chunk boundaries never align with zstd frame boundaries — you might read half a frame or multiple frames in one chunk.

The drawings below explain this.

▶️ ITERATION 1: Read first chunk

Read 128 kB from HTTP/file
Input: decompress_stream(data = 128 kB)

Result:

Frame 1 (whole): 45 kB | Frame 2 (whole): 60 kB | Frame 3 (start): 23 kB
  • output: 200 kB decompressed
  • consumed: 105 kB (Frame 1 + Frame 2)
  • remaining: 23 kB => [Buffer for next iteration]

▶️ ITERATION 2: concatenate the remaining data with the next chunk

Read 128 kB from HTTP/file
Input: decompress_stream(data = 23 kB + 128 kB)

Result:

Frame 3 (rest): 23 kB + 24 kB | Frame 4 (whole): 55 kB | Frame 5 (whole): 40 kB | Frame 6 (start): 9 kB
  • output: 200 kB decompressed
  • consumed: 142 kB (Frame 3 + Frame 4 + Frame 5)
  • remaining: 9 kB => [Buffer for next iteration]

▶️ ... continues until EOF ...
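
Stripped of the streaming plumbing, each iteration boils down to a few lines (a sketch; buffer, chunk and dctx come from the surrounding loop):

# concatenate the leftover from the previous iteration with the new chunk
data = buffer <> chunk
{:ok, {decompressed, bytes_consumed}} = ExZstandard.decompress_stream(dctx, data)

# whatever zstd did not consume is a partial frame: carry it over
buffer = binary_part(data, bytes_consumed, byte_size(data) - bytes_consumed)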

Build Elixir bindings using Zig via Zigler

We assume you have zstandard installed on your machine; it is declared as a system dependency.

We use the C_ABI interop.

The following Elixir module is where we configure Zigler to process our Zig code; Zigler then builds the NIFs for us.
Naturally, the Zig code is adapted to use Zigler (see it at the end).

We first declare:

  • our dependencies (the system zstd library),
  • how to compile the Zig code (the release mode),
  • our NIFs (we use the placeholder ..., which lets Zigler do the work for us),
  • the resources, if any.

Notice that the compress_stream NIF is declared as dirty_cpu. As explained before, this is because it can take more than 1 ms to run. Since decompress_stream is significantly and consistently faster - it typically processes a 100 kB chunk in less than 400 µs - it can safely run synchronously on a regular BEAM scheduler.

We also declare resources. These are long-lived objects, managed by the BEAM, that we want to reuse between runs: our zstd (de)compression contexts.

defmodule ExZstandard do
  @check_leak Mix.env() in [:test, :dev]
  @release if Mix.env() == :prod, do: :fast, else: :debug

  use Zig,
    otp_app: :ex_zstandard,
    c: [link_lib: {:system, "zstd"}],
    zig_code_path: "lib.zig",
    release_mode: @release,
    leak_check: @check_leak,
    nifs: [
      ...,
      compress_stream: [:dirty_cpu]
    ],
    resources: [:ZstdCResource, :ZstdDResource]
end

Stream a file with real-time compression

def compress_file(input_path, output_path, cctx) do
    # use a ZSTD primitive to determine a chunk size
    chunk_size = recommended_c_in_size()

    Stream.resource(
      # start_fun: state= {file_pid, true|false}
      fn ->
        {File.open!(input_path, [:read, :binary]), false}
      end,

      # read_fun: the "element" is emitted and consumed by the next stream (Stream.File here)
      # (state = {pid, false}) -> {[element], state} 
      # (state = {pid, true})   -> {:halt, state}

      fn
        {file_pid, true} ->
          {:halt, file_pid}

        {file_pid, false} ->
          case IO.binread(file_pid, chunk_size) do
            :eof ->
              {:ok, {final, _, _}} = compress_stream(cctx, <<>>, :end_frame)
              {[final], {file_pid, true}}

            {:error, reason} ->
              raise "Failed to read file: #{inspect(reason)}"

            data ->
              {:ok, {compressed, _, _}} = compress_stream(cctx, data, :flush)
              {[compressed], {file_pid, false}}
          end
      end,

      # after_fun
      fn file_pid -> File.close(file_pid) end
    )
    |> Stream.into(File.stream!(output_path, [:append]))
    |> Stream.run()

    :ok
  end
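A quick call looks like this (with hypothetical paths; the full usage recap comes at the end):

{:ok, cctx} = ExZstandard.cctx_init(3, 2)
:ok = ExZstandard.compress_file("input.bin", "output.bin.zst", cctx)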

Stream HTTP download with real-time compression

We use Req, one of the best Elixir HTTP libraries. We hand the response body to the :into option, where a function - which respects Req's specs - processes the chunks as they stream in.

def download_compress(url, output_path, cctx) do
  compressed_pid = File.open!(output_path, [:write, :binary])

  Req.get!(url,
    into: fn
      {:data, chunk}, {req, resp} ->
        {:ok, {compressed, _, _}} =
          ExZstandard.compress_stream(cctx, chunk, :flush)

        :ok = IO.binwrite(compressed_pid, compressed)
        {:cont, {req, resp}}
    end
  )

  # finalize the last zstd frame once the download is complete
  {:ok, {final, _, _}} = ExZstandard.compress_stream(cctx, <<>>, :end_frame)
  :ok = IO.binwrite(compressed_pid, final)

  :ok = File.close(compressed_pid)
end

Stream a file with real-time decompression

The loop is:

  • At each "next step", we feed the data to decompress_stream. It returns a decompressed chunk and the number of bytes consumed.

  • We extract the leftover bytes with binary_part/3 and keep them in the state.

  • On the next step, we concatenate the leftover with the new chunk and process the result.

When we receive :eof, there is usually still unconsumed data (a partial frame).

The "drain" helper function recursively calls decompress_stream on this buffer.
Once the buffer is empty, the state is marked :done so that the stream can halt properly and proceed to the "after" step.

def decompress_file(input_path, output_path, dctx) do
  # use a ZSTD primitive to determine a chunk size
  chunk_size = recommended_c_in_size()

  Stream.resource(
    # start_fun: () -> state = {pid, unconsumed_buffer = <<>>}
    fn -> {File.open!(input_path, [:read, :binary]), <<>>} end,

    # read_fun: element "decompressed" is emitted to the next stream step (File.stream! here)
    # state = {pid, unconsumed_buffer} -> {[decompressed], {pid, unconsumed_buffer}}
    # state = {:done, pid}             -> {:halt, state}
    fn
      # Already processed EOF and emitted final data, now halt and pass the pid to the "after_step" to conclude
      {:done, file_pid} ->
        {:halt, file_pid}

      {file_pid, buffer} ->
        case IO.binread(file_pid, chunk_size) do
          :eof when buffer == <<>> ->
            # No more data to process
            {:halt, file_pid}

          :eof ->
            # Process all remaining buffered data.
            # One decompress_stream call may not consume everything, so loop until empty
            decompressed_chunks = drain_buffer(dctx, buffer, [])
            # Emit the chunks and mark the state as done (the stream halts on the next call)
            {Enum.reverse(decompressed_chunks), {:done, file_pid}}

          {:error, reason} ->
            raise "Failed to read file: #{inspect(reason)}"

          chunk ->
            # append the new chunk to the leftover buffer
            data = buffer <> chunk
            {:ok, {decompressed, bytes_consumed}} = ExZstandard.decompress_stream(dctx, data)

            # Keep unconsumed bytes for next iteration
            remaining = binary_part(data, bytes_consumed, byte_size(data) - bytes_consumed)

            {[decompressed], {file_pid, remaining}}
        end
    end,

    # after_fun: (acc) -> ()
    fn file_pid -> File.close(file_pid) end
  )
  |> Stream.into(File.stream!(output_path, [:append]))
  |> Stream.run()
end

and we "drain" the residual buffer:

defp drain_buffer(_dctx, <<>>, acc), do: acc

defp drain_buffer(dctx, buffer, acc) do
  {:ok, {decompressed, bytes_consumed}} = decompress_stream(dctx, buffer)

  if bytes_consumed == 0 do
    # Can't make progress - corrupted data
    raise "Decompression stalled with #{byte_size(buffer)} bytes remaining"
  end

  remaining = binary_part(buffer, bytes_consumed, byte_size(buffer) - bytes_consumed)

  drain_buffer(dctx, remaining, [decompressed | acc])
end

Stream HTTP download with real-time decompression

The Req.Response struct has a built-in :private field that can be used to save the unconsumed data as temporary state. We retrieve it, consume it, and update it every time we receive a chunk.
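
The pattern in isolation (a sketch; Req.Response.put_private/3 writes the same :private map that get_private/3 reads - our code below uses update_private/4, which achieves the same):

resp = Req.Response.new()
resp = Req.Response.put_private(resp, :buffer, <<1, 2, 3>>)
Req.Response.get_private(resp, :buffer, <<>>)
#=> <<1, 2, 3>>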

def download_decompress(url, path, dctx) do
    decompressed_pid = File.open!(path, [:write, :binary])

    # Download and decompress chunks as they arrive
    result =
      Req.get!(url,
        into: fn
          {:data, chunk}, {req, resp} ->
            # Get the buffer from the previous iteration (unconsumed bytes)
            buffer = Req.Response.get_private(resp, :buffer, <<>>)

            # Concatenate with the new chunk
            data = buffer <> chunk

            # Decompress what we can
            {:ok, {decompressed, bytes_consumed}} = ExZstandard.decompress_stream(dctx, data)

            :ok = IO.binwrite(decompressed_pid, decompressed)

            # Keep unconsumed bytes for next iteration
            remaining = binary_part(data, bytes_consumed, byte_size(data) - bytes_consumed)

            # Store buffer in response private for next iteration
            updated_resp = Req.Response.update_private(resp, :buffer, <<>>, fn _ -> remaining end)

            {:cont, {req, updated_resp}}
        end
      )

    # Process any remaining buffered data after stream ends (handle connection close!)
    final_buffer = Req.Response.get_private(result, :buffer, <<>>)

    if byte_size(final_buffer) > 0 do
      # Drain the final buffer (like in decompress_file)
      decompressed_chunks = drain_buffer(dctx, final_buffer, [])

      Enum.each(Enum.reverse(decompressed_chunks), fn chunk ->
        :ok = IO.binwrite(decompressed_pid, chunk)
      end)
    end

    :ok = File.close(decompressed_pid)
  end

Elixir module usage

# compression
{:ok, cctx} = ExZstandard.cctx_init(3, 2) # level 3, strategy 2 (ZSTD_dfast)

path
|> ExZstandard.compress_file(output_path, cctx)

url_of_data_to_compress
|> ExZstandard.download_compress(output_path, cctx)

# decompression
{:ok, dctx} = ExZstandard.dctx_init(nil) # no max_window limitation

path
|> ExZstandard.decompress_file(output_path, dctx)

url_of_compressed_data
|> ExZstandard.download_decompress(output_path, dctx)

Zig code

We declared above that the Zig module lives in the file lib/lib.zig (so we can take advantage of the Zig language server).

For brevity, we expose here only the "compression" code. The "decompression" is similar.

// lib.zig
const std = @import("std");
const beam = @import("beam");

// import the system zstd C library
const z = @cImport({
    @cInclude("zstd.h");
});

// this will be our resource: an optional pointer
// to the opaque zstandard context
const ZstdCCtx = struct {
    cctx: ?*z.ZSTD_CCtx,
};

/// compression context resource
pub const ZstdCResource = beam.Resource(
    *ZstdCCtx,
    @import("root"),
    .{ .Callbacks = ZstdCCtxCallback },
);

/// NIF callback to free the ZstdCCtx when the resource
/// is garbage collected
pub const ZstdCCtxCallback = struct {
    pub fn dtor(handle: **ZstdCCtx) void {
        _ = z.ZSTD_freeCCtx(handle.*.cctx);
        beam.allocator.destroy(handle.*);
        if (@import("builtin").mode == .Debug) std.debug.print("CDOTR called\n", .{});
    }
};

/// Instantiate the reusable compression context. 
/// Args: the compression level and the compression strategy (algorithm).
/// Returns {:ok, cctx} or {:error, reason}
pub fn cctx_init(level: i32, strategy: i16) ZstdError!beam.term {
    if (level < z.ZSTD_minCLevel() or level > z.ZSTD_maxCLevel()) {
        return beam.make_error_pair(ZstdError.InvalidCompressionLevel, .{});
    }

    // Create the ZstdCCtx struct
    const ctx = beam.allocator.create(ZstdCCtx) catch {
        return beam.make_error_pair(ZstdError.OutOfMemory, .{});
    };
    // zigler will free it in the resource callback
    // but if any error, we free it.
    errdefer beam.allocator.destroy(ctx);

    // Create the libzstd compression context
    ctx.cctx = z.ZSTD_createCCtx() orelse {
        return beam.make_error_pair(ZstdError.OutOfMemory, .{});
    };
    errdefer _ = z.ZSTD_freeCCtx(ctx.cctx);

    // Set compression level
    var result = z.ZSTD_CCtx_setParameter(
        ctx.cctx,
        z.ZSTD_c_compressionLevel,
        level,
    );
    if (z.ZSTD_isError(result) != 0) {
        const err_name = std.mem.span(z.ZSTD_getErrorName(result));
        std.log.err("Failed to set compression level: {s}", .{err_name});
        return beam.make_error_pair(ZstdError.InvalidCompressionLevel, .{});
    }

    // Set the strategy passed by the caller (e.g. 2 == ZSTD_dfast)
    result = z.ZSTD_CCtx_setParameter(
        ctx.cctx,
        z.ZSTD_c_strategy,
        strategy,
    );
    if (z.ZSTD_isError(result) != 0) {
        const err_name = std.mem.span(z.ZSTD_getErrorName(result));
        std.log.err("Failed to set strategy: {s}", .{err_name});
        return beam.make_error_pair(ZstdError.InvalidCompressionLevel, .{});
    }

    // Wrap in resource
    const resource = ZstdCResource.create(ctx, .{}) catch {
        return beam.make_error_pair(ZstdError.OutOfMemory, .{});
    };
    return beam.make(.{ .ok, resource }, .{});
}

pub fn compress_stream(ctx: ZstdCResource, input: []const u8, end_op: EndOp) ZstdError!beam.term {
    const cctx = ctx.unpack().*.cctx.?;

    // Allocate output buffer using recommended size
    const out_buf_size = z.ZSTD_CStreamOutSize();
    const output_data = beam.allocator.alloc(u8, out_buf_size) catch {
        return beam.make_error_pair(ZstdError.OutOfMemory, .{});
    };
    errdefer beam.allocator.free(output_data);

    // Setup buffers
    var in_buf = z.ZSTD_inBuffer_s{
        .src = input.ptr,
        .size = input.len,
        .pos = 0,
    };

    var out_buf = z.ZSTD_outBuffer_s{
        .dst = output_data.ptr,
        .size = out_buf_size,
        .pos = 0,
    };

    // Compress
    const remaining = z.ZSTD_compressStream2(
        cctx,
        &out_buf,
        &in_buf,
        end_op.toZstd(),
    );

    if (z.ZSTD_isError(remaining) != 0) {
        const err_name = std.mem.span(z.ZSTD_getErrorName(remaining));
        std.log.err("Stream compression failed: {s}", .{err_name});
        return beam.make_error_pair(ZstdError.StreamCompressionFailed, .{});
    }

    // Resize output to actual size
    const compressed = beam.allocator.realloc(output_data, out_buf.pos) catch {
        return beam.make_error_pair(ZstdError.OutOfMemory, .{});
    };
    // beam.make copies the slice into the returned term, so we can free our copy
    defer beam.allocator.free(compressed);

    return beam.make(
        .{ .ok, .{ compressed, in_buf.pos, remaining } },
        .{},
    );
}

If you want to measure the execution time, you can add the lines below to the function you want to monitor:

const start_time = std.time.nanoTimestamp();
defer {
    const elapsed_ns = std.time.nanoTimestamp() - start_time;
    const elapsed_ms = @as(f64, @floatFromInt(elapsed_ns)) / 1_000_000.0;

    // `input` is the function's input slice; adapt the name to your function
    std.debug.print("[NIF duration]: {d:.3} ms for {d} bytes\n", .{ elapsed_ms, input.len });
}
Enter fullscreen mode Exit fullscreen mode
