NDREAN

Mandelbrot set with Numerical Elixir and Zigler

The code below is a Livebook.


Mix.install(
  [
    {:nx, "~> 0.9.1"},
    {:exla, "~> 0.9.1"},
    {:kino, "~> 0.14.2"},
    {:zigler, "~> 0.13.3"},
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Nx.Defn.global_default_options(compiler: EXLA, client: :host)

Warning

If the cell above raises an error, you probably don't have Zig installed. Comment out the Zigler package above.
You then won't be able to run the very last module, where we use inline Zig code.

Introduction

We want to produce an image that represents the beautiful Mandelbrot set.

Source: https://en.wikipedia.org/wiki/Mandelbrot_set

We use the Nx library with the EXLA backend to speed up the computations.

We also show how to run the equivalent code in Zig in this Livebook if you want extra speed. This is possible thanks to the Zigler library. The Zig code returns a binary that Nx can consume and Kino can display.

What is a Mandelbrot set?

In a "Mandelbrot image", each pixel has a colour representing how fast the underlying point "escapes" when we compute its iterates under a certain function.

What is an underlying point?

A pixel has some coordinates [i,j]. For example, in a 1024 × 768 image (WIDTH x HEIGHT), the row number i varies from 0 to 767 and the column number j from 0 to 1023.

We transform these pairs of integers (i,j) into points of a 2D plane. This map "quantizes" the 2D plane.

Here, the 2D "real" plane is defined by the upper left corner, say (-2,1), and bottom right corner, say (1,-1).

How? We have a linear mapping between the pair (i,j) and a point (x,y) in the defined zone. For example, the pixel (0,0) becomes (-2,1) and the pixel (767, 1023) becomes (1,-1).
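As a rough illustration of this linear mapping, here is a minimal sketch in plain Elixir (no Nx). The module name `PlaneSketch` and the function `to_point/3` are hypothetical helpers introduced only for this example; the Livebook itself does this with a `defn` further down.

```elixir
defmodule PlaneSketch do
  # Hypothetical helper, not part of the Livebook.
  # Maps the pixel {row, col} of an h x w image to the point {x, y} of the
  # window given by its top-left {x0, y0} and bottom-right {x1, y1} corners.
  def to_point({row, col}, {h, w}, {{x0, y0}, {x1, y1}}) do
    x = x0 + col * (x1 - x0) / (w - 1)
    y = y0 + row * (y1 - y0) / (h - 1)
    {x, y}
  end
end

window = {{-2, 1}, {1, -1}}
# {height, width} of a 1024 x 768 (WIDTH x HEIGHT) image
dim = {768, 1024}

# top-left pixel, lands on the corner (-2, 1)
PlaneSketch.to_point({0, 0}, dim, window)
# bottom-right pixel, lands on the corner (1, -1)
PlaneSketch.to_point({767, 1023}, dim, window)
```

The column index drives the x coordinate and the row index drives the y coordinate, which is why y decreases as the row number grows.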

What is iterating?

We will iterate the function: x -> x*x + c where c is a given number and x the variable.

We start with:

z0 = f(0) = c
z1 = f(z0) = z0 * z0 + c
z2 = f(z1) = z1 * z1 + c
...

Let's take an example. The module below calculates the iterations x(n) = f(x(n-1)) by a simple recursion.

The set of these iterates of c is called its orbit.

defmodule Simple do
  def p(x,c), do: x**2 + c

  # initial value
  def iterate(1,c), do: c
  # the n-th step
  def iterate(n,c), do: p(iterate(n-1, c), c)
end

We calculate the first elements of its orbit to see how the point c=1 behaves. It looks like it will diverge to infinity.

c = 1
{ c, 
  Simple.iterate(1,c), Simple.iterate(2,c), Simple.iterate(3,c), Simple.iterate(4,c), Simple.iterate(5,c), Simple.iterate(6,c),
}

gives:

{1, 1, 2, 5, 26, 677, 458330}

On the other hand, the point c=-1 seems well behaved: the orbit has only two values, 0 and -1, and is periodic.

c = -1

{ c, 
  Simple.iterate(1,c), Simple.iterate(2,c), Simple.iterate(3,c), Simple.iterate(4,c),Simple.iterate(5,c),Simple.iterate(6,c),
}

gives:

{-1, -1, 0, -1, 0, -1, 0}

In the examples above, we took a simple "real" number.

For the Mandelbrot set, we use the complex representation of a point: (x,y) -> x + y*i where i is the imaginary number (i * i = -1).

So, each pixel (i,j) is mapped to a complex number c = projection(i,j), and we want to evaluate how the iterates of c behave under this iteration starting at z0 = 0.

Iteration number?

We are interested in assigning an iteration number to each c.

The number of iterations needs to be bounded (think of a periodic orbit). Let max_iter be the maximum number of iterations, for example 100.

If the orbit of c remains bounded, we assign max_iter as its iteration number.

If it escapes, meaning that some iterate has a modulus greater than 2 (its absolute value as a complex number, or its norm as a point), the iteration number is the first index at which this happens. Once an iterate leaves the disk of radius 2, the orbit is known to diverge, so we can stop there.
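To make the definition concrete, here is a minimal plain-Elixir sketch (no Nx) where a complex number is a `{re, im}` tuple. The module `EscapeSketch` is a hypothetical helper written for this illustration. Its indexing starts at z0 = 0, so it may differ by a small offset from the numbers returned by the `defn` version used later in the Livebook.

```elixir
defmodule EscapeSketch do
  # Hypothetical helper, not part of the Livebook.
  # One step of z -> z*z + c on {re, im} tuples.
  def step({x, y}, {cx, cy}), do: {x * x - y * y + cx, 2 * x * y + cy}

  # Squared modulus, compared against 4 to avoid a square root.
  def sq_norm({x, y}), do: x * x + y * y

  # First index n such that |z_n| > 2 (with z_0 = 0),
  # or max_iter when no iterate escapes within max_iter steps.
  def number(c, max_iter), do: loop({0.0, 0.0}, c, 0, max_iter)

  defp loop(_z, _c, n, max_iter) when n >= max_iter, do: max_iter

  defp loop(z, c, n, max_iter) do
    if sq_norm(z) > 4, do: n, else: loop(step(z, c), c, n + 1, max_iter)
  end
end

# periodic orbit 0, -1, 0, -1, ... never escapes: returns 100
EscapeSketch.number({-1.0, 0.0}, 100)
# orbit 0, 1, 2, 5, ...: returns 3, since z_3 = 5 is the first iterate outside the disk
EscapeSketch.number({1.0, 0.0}, 100)
```

Comparing the squared modulus with 4 is equivalent to comparing the modulus with 2 and is the same trick the Livebook uses in its `sq_norm` helper.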

Complex calculus interface

We will use two types of functions:

  • Elixir functions using def
  • Nx functions using defn; these use a special backend (EXLA with CPU or GPU if any)

The points of the 2D plane will be represented as complex numbers as the Mandelbrot map works with complex numbers.

The function z(n+1) = z(n) * z(n) + c takes a complex number and returns a complex number.

Below is a helper module to work with complex numbers in Numerical Elixir.

We use numerical functions, declared with defn. All the arguments are treated as tensors.

defmodule Ncx do
  import Nx.Defn

  defn i() do
    Nx.Constants.i()
  end

  # primitive to build a complex scalar tensor
  defn new(x,y) do
    x + i() * y
  end

  # square norm
  defn sq_norm(z) do
    Nx.conjugate(z) |> Nx.dot(z) |> Nx.real()
  end
end

Algorithm

Source: https://en.wikipedia.org/wiki/Plotting_algorithms_for_the_Mandelbrot_set

Input: image dimensions (e.g. w x h of 1500 x 1000), max iteration (e.g. 100)

Iterate over each pixel (i,j):

  • map it into the 2D plane: compute its "complex coordinates"
  • compute the iteration number
  • compute a colour

Finally, assemble the results into a tensor and draw it with Kino.

Pixel to complex plane mapping

This module transforms a pair (i,j) into a complex number.

› Notice that once you are in a numerical function, the arguments become "tensors", and a tensor can be of the complex type "c64". It natively understands complex numbers.

defmodule Pixel do
  import Nx.Defn

  defn map(index, {h,w}, {top_left_x, top_left_y, bottom_right_x,bottom_right_y}) do

    scale_x = Nx.divide(bottom_right_x-top_left_x, w-1)
    scale_y = Nx.divide(bottom_right_y-top_left_y, h-1)
    # building a complex typed tensor
    Ncx.new(
      top_left_x + Nx.dot(index[1],scale_x),
      top_left_y + Nx.dot(index[0], scale_y)
    )
  end
end

Orbit and iteration number

This module computes the iteration number for a given input c.

If |c| > 2, then this point is unstable. Otherwise, we have to compute for each point whether it stays bounded or not.

If it is bounded, we get max_iter (or max_iter - 1 when the loop runs to completion), otherwise a lower value.

It also uses numerical functions via defn.

We cannot use the recursive form we wrote earlier because, unlike plain Elixir functions, numerical functions don't accept several clauses. Instead we run a special while loop. Note how we use the Nx version of cond, and how the double condition is handled by Nx.logical_and. Also, true is 1.

defmodule Orbit do
  import Nx.Defn

  defn poly(z,c) do
    z*z + c
  end

  defn number(c,max_iter) do
    condition = (Nx.real(c) +1) ** 2 + (Nx.imag(c)**2)
    cond do
      # points in the period-2 bulb (disk of radius 1/4 centred at -1) are all stable. Save on iterations
      Nx.less(condition, 0.0625) ->
        max_iter
      # these points are unbounded whenever the norm is > 2
      Nx.greater(Ncx.sq_norm(c), 4) ->
        0
      # all other points must be iterated: they may or may not stay within the disk of radius 2
      1 ->
          {_, _, j} =
            while {z=c, c, i=max_iter}, Nx.logical_and(Nx.greater(i,1), Nx.less(Ncx.sq_norm(z), 4)) do
                {poly(z,c), c,i-1}
            end
          max_iter - j
    end
  end
end

Examples:

st = Ncx.new(0.2, 0.2)
dv1 = Ncx.new(0.4, 0.4)
dv2 = Ncx.new(0.3, 0.6)
dv3 = Ncx.new(2,2)

iter_max = 100

iter_dv1 = Orbit.number(dv1, iter_max) #<- we should find 8 iterations before z_n escapes from the disk of radius 2
iter_dv2 = Orbit.number(dv2, iter_max) #<- we should find 14 iterations before z_n escapes from the disk of radius 2
iter_dv3 = Orbit.number(dv3, iter_max) #<- already outside the disk of radius 2: 0
iter_st  = Orbit.number(st, iter_max) #<- this point is stable: the loop runs to completion and returns iter_max - 1

%{
  "unstable/2:    #{Nx.to_number(dv2)}" => iter_dv2 |> Nx.to_number(),
  "unstable/1:    #{Nx.to_number(dv1)}" => iter_dv1 |> Nx.to_number(),
  "out_of_disk2:  #{Nx.to_number(dv3)}" => iter_dv3 |> Nx.to_number(),
  "stable:        #{Nx.to_number(st)}" => iter_st |> Nx.to_number(),
}
%{
  "out_of_disk2:  2.0+2.0i" => 0,
  "stable:        0.20000000298023224+0.20000000298023224i" => 99,
  "unstable/1:    0.4000000059604645+0.4000000059604645i" => 8,
  "unstable/2:    0.30000001192092896+0.6000000238418579i" => 14
}

A Colour palette

Each iteration number is an integer n. We want to associate a colour [r(n),g(n),b(n)].

This will help us to visualise which point of the complex plane is stable, and if not how fast it escapes.

The choice below is just an example. Other choices can be made.

defmodule Colour do
  import Nx.Defn

  defn normalize(n, max_iter) do
    n / max_iter
  end

  defn rgb(n) do
    cond do
      Nx.equal(n, 0) -> 
        Nx.stack([255, 255, 0]) |> Nx.as_type(:u8)
      Nx.less(n, 0.5) ->
        scaled = n * 2
        r = 255 * (1 - scaled)
        g = 255 * (1 - scaled/2)
        b = 127 * scaled
        Nx.stack([r, g, b]) |> Nx.as_type(:u8)
      true ->
        scaled = (n - 0.5) * 2
        r = 255 * (1 + scaled / 2)
        g = 128 * (1 + scaled / 2)
        b = 255 * (1 - scaled)
        Nx.stack([r, g, b]) |> Nx.as_type(:u8)
    end
  end
end
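Since other palettes are possible, here is one deliberately simple alternative, sketched in plain Elixir rather than with defn. The module `PaletteSketch` and its grey-level scheme are assumptions made for this example and are not the palette used in the rest of the Livebook.

```elixir
defmodule PaletteSketch do
  # Hypothetical alternative palette, not part of the Livebook.
  # Points that never escaped (n reached max_iter) are drawn in black.
  def rgb(n, max_iter) when n >= max_iter, do: {0, 0, 0}

  # Escaping points get a grey level: the sooner they escape, the lighter.
  def rgb(n, max_iter) do
    level = 255 - trunc(255 * n / max_iter)
    {level, level, level}
  end
end

# escapes immediately: {255, 255, 255}
PaletteSketch.rgb(0, 100)
# escapes half-way: {128, 128, 128}
PaletteSketch.rgb(50, 100)
# never escaped: {0, 0, 0}
PaletteSketch.rgb(100, 100)
```

Every channel stays within 0..255 by construction, which matters because the final image is stored as unsigned 8-bit values.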

Computing the Mandelbrot set

We will now reassemble our modules.

Firstly, an example.

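dim = {500,500}; iter_max = 100
# Assumption: same window as in Mandelbrot.compute. This cell needs
# defining_points to be bound; the output shown may come from another window.
defining_points = {-2, 1.2, 0.6, -1.2}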

p = Nx.tensor([30,40])
c_i_j = Pixel.map(p,dim, defining_points)
n_i_j = Orbit.number(c_i_j, iter_max)
nm_i_j = Colour.normalize(n_i_j, iter_max)

{Nx.to_number(n_i_j), Colour.rgb(nm_i_j)} |> dbg()

p = Nx.tensor([40,70])
c_i_j = Pixel.map(p,dim, defining_points)
n_i_j = Orbit.number(c_i_j, iter_max)
nm_i_j = Colour.normalize(n_i_j, iter_max)

{Nx.to_number(n_i_j), Colour.rgb(nm_i_j)} |> dbg()


We found that this pixel maps to a point of the complex plane that escapes rather quickly from the disk of radius 2. It gets stamped with some colour.

{4,
 #Nx.Tensor<
   u8[3]
   EXLA.Backend<host:0, 0.3807096825.1655832596.50891>
   [234, 244, 10]
 >}

The final module

We then reassemble the tensor into the format Kino needs to display it.

Note that if you want to pass arguments to a defn function without having them treated as tensors, you need to use a keyword list or a map.

defmodule Mandelbrot do
  import Nx.Defn

  defn compute(opts) do
    top_left_x = -2; top_left_y = 1.2; bottom_right_x = 0.6; bottom_right_y = - 1.2;
    defining_points = {top_left_x, top_left_y, bottom_right_x, bottom_right_y}

    h = opts[:h]
    w = opts[:w]
    max_iter = opts[:max_iter]

    # build the tensor [[0,0], ..., [0,m], [1,0], ..., [n,m]]. Thanks to PValente
    iota_rows = Nx.iota({h}, type: :u16) |> Nx.vectorize(:rows)
    iota_cols = Nx.iota({w}, type: :u16) |> Nx.vectorize(:cols)
    cross_product = Nx.stack([iota_rows, iota_cols])

    Pixel.map(cross_product,{h,w}, defining_points)
      |> Orbit.number(max_iter)
      |> Colour.normalize(max_iter)
      |> Colour.rgb()
      |> Nx.devectorize()
      |> Nx.reshape({h, w, 3})
      |> Nx.as_type(:u8)
  end
end


Depending on your machine, the computation below can be lengthy.
If you just want to try it out, set h = w = 400.

h = w = 400;
Mandelbrot.compute(h: h, w: w, max_iter: 100)
|> Kino.Image.new()

Parallelise it with async_stream

As the resolution of the image increases, it becomes worthwhile to parallelise the computations.

We divide the image into horizontal bands, as many as there are CPU cores on the machine.

When you use async_stream, the BEAM (the VM that runs this code) spreads the tasks across the cores.

This is only worthwhile if the image is large enough, as it comes with non-negligible overhead.

We also set ordered: true as we need to assemble the results in order.

Another possible optimisation is to note that the image is symmetric about the horizontal axis. You can compute half of the image (redefine h to be h - rem(h, cpus*2)), but you then need to reverse a tensor to build the other half.
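The symmetry idea can be sketched on plain Elixir lists, where each "row" is whatever a row function returns. The module `MirrorSketch` and the toy row function are hypothetical and only meant to show the mirroring step. With tensors, `Nx.reverse` and `Nx.concatenate` would presumably play the roles of `Enum.reverse` and `++`.

```elixir
defmodule MirrorSketch do
  # Hypothetical helper, not part of the Livebook.
  # `row_fun` computes one row of the image from its row index.
  # Only the top half (plus the middle row when h is odd) is computed;
  # the bottom half is the top half in reverse order.
  def render(h, row_fun) do
    half = div(h, 2)
    top = Enum.map(0..(half - 1)//1, row_fun)
    middle = if rem(h, 2) == 1, do: [row_fun.(half)], else: []
    top ++ middle ++ Enum.reverse(top)
  end
end

# Toy row function: a row depends only on its distance to the horizontal axis.
h = 5
row_fun = fn row -> abs(row - div(h - 1, 2)) end

# returns [2, 1, 0, 1, 2], the same list as Enum.map(0..4, row_fun)
MirrorSketch.render(h, row_fun)
```

This only works when the window is symmetric about the real axis, as it is with top and bottom at 1.2 and -1.2.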

defmodule StreamMandelbrot do
  import Nx.Defn

    @doc """
    Example: 42 rows, 8 cpus
    42 rows = 8 cpus * 5 + 2
    We run 8 tasks consuming 5 rows each
    We just ignore the last 2 rows.
    """
    def run(%{h: h, w: w} = opts) do
      cpus = :erlang.system_info(:logical_processors_available)
      # we eliminate a few rows from the final image, cpus - 1 at most.
      h = h - rem(h,cpus)
      rows_per_cpu = div(h, cpus)

      Task.async_stream(0..cpus-1, fn cpu_count ->
          # we shift the start index by the number of rows already consumed
          iota_rows = Nx.iota({rows_per_cpu}, type: :u16) |> Nx.add(cpu_count * rows_per_cpu)|> Nx.vectorize(:rows)
          # full width
          iota_cols = Nx.iota({w}, type: :u16) |> Nx.vectorize(:cols)
          cross_product = Nx.stack([iota_rows, iota_cols])
          Nx.Defn.jit_apply(fn t -> 
            compute(t, opts) end, [cross_product])
          end, 
          timeout: :infinity, ordered: true
      )
      |> Enum.map(fn {:ok, t} -> t end) #&elem(&1, 1)
      |> Nx.stack()
      |> Nx.reshape({h,w,3})
  end



  defn compute(cross_product, %{h: h, w: w, max_iter: max_iter}) do
    top_left_x = -2; top_left_y = 1.2; bottom_right_x = 0.6; bottom_right_y = -1.2;
    defining_points = {top_left_x, top_left_y, bottom_right_x, bottom_right_y}

    Pixel.map(cross_product,{h,w}, defining_points)
    |> Orbit.number(max_iter)
    |> Colour.normalize(max_iter)
    |> Colour.rgb()
    |> Nx.devectorize()
    |> Nx.as_type(:u8)
  end
end

When we run the code, we get results much faster. On my machine, it took 44s to draw a 1M-pixel image. We get the expected performance boost.

h= w = 400;

StreamMandelbrot.run( %{h: h, w: w, max_iter: 200})   
|> Kino.Image.new()
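The band arithmetic (42 rows on 8 cores giving 8 bands of 5 rows, with 2 rows dropped) can be checked in isolation with a small plain-Elixir sketch. `BandSketch` is a hypothetical module written for this illustration; it uses lists of rows instead of tensors and assumes at least one band.

```elixir
defmodule BandSketch do
  # Hypothetical helper, not part of the Livebook.
  # Splits rows 0..h-1 into `bands` ranges of equal size (bands >= 1),
  # dropping the remainder rows.
  def ranges(h, bands) do
    rows_per_band = div(h, bands)

    for b <- 0..(bands - 1)//1 do
      (b * rows_per_band)..((b + 1) * rows_per_band - 1)//1
    end
  end

  # Runs `row_fun` on every row, one task per band, keeping the band order.
  def run(h, bands, row_fun) do
    ranges(h, bands)
    |> Task.async_stream(fn range -> Enum.map(range, row_fun) end,
      ordered: true,
      timeout: :infinity
    )
    |> Enum.flat_map(fn {:ok, rows} -> rows end)
  end
end

# 8 bands of 5 rows: 0..4, 5..9, ..., 35..39 (rows 40 and 41 are dropped)
BandSketch.ranges(42, 8)
# 40 results, in row order
BandSketch.run(42, 8, fn row -> row * row end)
```

Because the stream is ordered, flattening the per-band results gives the rows back in their original order, which is the property the final reshape relies on.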

Run embedded Zig code

If we still need or want extra speed, we can also embed Zig code in Elixir within a Livebook.

Zigler has remarkable documentation.

❗ You may need to have Zig installed on your machine.

In the Livebook, we add the dependencies (in the first cell):

Mix.install([{:zigler, "~> 0.13.3"},{:zig_get, "~> 0.13.1"},])

With Zigler, we can even inline Zig code.

The code below runs the same algorithm and uses OS threads for concurrency.

We use the BEAM memory allocator from the library.

The slice is returned as a binary, typed as beam.term, so that it is easily consumed by Nx and then Kino.

defmodule Zigit do
  use Zig, otp_app: :zigler, 
    nifs: [..., generate_mandelbrot: [:threaded]]
    # release_mode: :fast

  ~Z"""
    const beam = @import("beam");
    const std = @import("std");
    const Cx = std.math.Complex(f64);

    const topLeft = Cx{ .re = -2.1, .im = 1.2 };
    const bottomRight = Cx{ .re = 0.6, .im = -1.2 };
    const w = bottomRight.re - topLeft.re;
    const h = bottomRight.im - topLeft.im;

    const Context = struct {res_x: usize, res_y: usize, imax: usize};

    /// nif: generate_mandelbrot/3 Threaded
    pub fn generate_mandelbrot(res_x: usize, res_y: usize, max_iter: usize) !beam.term {
        const pixels = try beam.allocator.alloc(u8, res_x * res_y * 3);
        defer beam.allocator.free(pixels);

        const resolution = Context{ .res_x = res_x, .res_y = res_y, .imax = max_iter };

        const res = try createBands(pixels, resolution);
        return beam.make(res, .{ .as = .binary });
    }

    // <--- threaded version
    fn createBands(pixels: []u8, ctx: Context) ![]u8 {
        const cpus = try std.Thread.getCpuCount();
        var threads = try beam.allocator.alloc(std.Thread, cpus);
        defer beam.allocator.free(threads);

        // half of the total rows
        const rows_to_process = ctx.res_y / 2 + ctx.res_y % 2;
        // one band is one count of cpus
        // const nb_rows_per_band = rows_to_process / cpus + rows_to_process % cpus;
        const rows_per_band = (rows_to_process + cpus - 1) / cpus;

        // number of threads actually started (can be lower than cpus)
        var spawned: usize = 0;
        for (0..cpus) |cpu_count| {
            const start_row = cpu_count * rows_per_band;

            // Stop if there are no rows to process
            if (start_row >= rows_to_process) break;

            const end_row = @min(start_row + rows_per_band, rows_to_process);
            const args = .{ ctx, pixels, start_row, end_row };
            threads[cpu_count] = try std.Thread.spawn(.{}, processRows, args);
            spawned += 1;
        }
        // join only the threads that were spawned
        for (threads[0..spawned]) |thread| {
            thread.join();
        }

        return pixels;
    }

    fn processRows(ctx: Context, pixels: []u8, start_row: usize, end_row: usize) void {
        for (start_row..end_row) |current_row| {
            processRow(ctx, pixels, current_row);
        }
    }

    fn processRow(ctx: Context, pixels: []u8, row_id: usize) void {
        // Calculate the symmetric row
        const sym_row_id = ctx.res_y - 1 - row_id;

        if (row_id <= sym_row_id) {
            // loop over columns
            for (0..ctx.res_x) |col_id| {
                const c = mapPixel(.{ @as(usize, @intCast(row_id)), @as(usize, @intCast(col_id)) }, ctx);
                const iter = iterationNumber(c, ctx.imax);
                const colour = createRgb(iter, ctx.imax);

                const p_idx = (row_id * ctx.res_x + col_id) * 3;
                pixels[p_idx + 0] = colour[0];
                pixels[p_idx + 1] = colour[1];
                pixels[p_idx + 2] = colour[2];

                // Process the symmetric row (if it's different from current row)
                if (row_id != sym_row_id) {
                    const sym_p_idx = (sym_row_id * ctx.res_x + col_id) * 3;
                    pixels[sym_p_idx + 0] = colour[0];
                    pixels[sym_p_idx + 1] = colour[1];
                    pixels[sym_p_idx + 2] = colour[2];
                }
            }
        }
    }

    fn mapPixel(pixel: [2]usize, ctx: Context) Cx {
        const px_width = ctx.res_x - 1;
        const px_height = ctx.res_y - 1;
        const scale_x = w / @as(f64, @floatFromInt(px_width));
        const scale_y = h / @as(f64, @floatFromInt(px_height));

        const re = topLeft.re + scale_x * @as(f64, @floatFromInt(pixel[1]));
        const im = topLeft.im + scale_y * @as(f64, @floatFromInt(pixel[0]));
        return Cx{ .re = re, .im = im };
    }

    fn iterationNumber(c: Cx, imax: usize) ?usize {
        if (c.re > 0.6 or c.re < -2.1) return 0;
        if (c.im > 1.2 or c.im < -1.2) return 0;

        // period-2 bulb: disk of radius 1/4 centred at -1
        if ((c.re + 1) * (c.re + 1) + c.im * c.im < 0.0625) return null;

        var z = Cx{ .re = 0.0, .im = 0.0 };
        for (0..imax) |j| {
            if (sqnorm(z) > 4) return j;
            z = Cx.mul(z, z).add(c);
        }
        return null;
    }

    fn sqnorm(z: Cx) f64 {
        return z.re * z.re + z.im * z.im;
    }

    fn createRgb(iter: ?usize, imax: usize) [3]u8 {
        // If it didn't escape, return black
        if (iter == null) return [_]u8{ 0, 0, 0 };

        // Normalize time to [0,1[ now that we know it isn't "null"
        const normalized = @as(f64, @floatFromInt(iter.?)) / @as(f64, @floatFromInt(imax));

        if (normalized < 0.5) {
            const scaled = normalized * 2;
            return [_]u8{ @as(u8, @intFromFloat(255 * (1 - scaled))), @as(u8, @intFromFloat(255.0 * (1 - scaled / 2))), @as(u8, @intFromFloat(127 + 128 * scaled)) };
        } else {
            const scaled = (normalized - 0.5) * 2.0;
            return [_]u8{ 0, @as(u8, @intFromFloat(127 * (1 - scaled / 2))), @as(u8, @intFromFloat(255 * (1 - scaled))) };
        }
    }

  """
end

We run the Zig code. It returns a binary that we can load into Nx to display the image.

Drawing a 1M-pixel image takes a few milliseconds. Feels like magic.

h = w = 5_000
max_iter = 300

# the NIF takes the width (res_x) first, then the height (res_y)
Zigit.generate_mandelbrot(w, h, max_iter)
|> Nx.from_binary(:u8)
|> Nx.reshape({h, w, 3})
|> Kino.Image.new()
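The binary returned by the NIF is a flat, row-major sequence of bytes, 3 per pixel (r, g, b), which is why the reshape to {h, w, 3} works. As a small illustration without Nx, the sketch below decodes a hand-written 2 x 2 buffer with a binary comprehension; the buffer content is made up for the example.

```elixir
# A tiny hand-written RGB buffer: 2 x 2 pixels, 3 bytes per pixel, row-major.
binary = <<255, 0, 0, 0, 255, 0, 0, 0, 255, 255, 255, 255>>
w = 2

# Read the bytes three at a time, one {r, g, b} tuple per pixel.
pixels = for <<r, g, b <- binary>>, do: {r, g, b}

# Group the pixels into rows of width w.
rows = Enum.chunk_every(pixels, w)
# rows is [[{255, 0, 0}, {0, 255, 0}], [{0, 0, 255}, {255, 255, 255}]]
```

The byte size of the binary must equal h * w * 3, otherwise the reshape in the cell above would fail.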
