Writing a Genetic Algorithm in Nx

Numerical Computing for Elixir

Genetic Algorithms in Elixir

Introducing Nx for Elixir

Requirements for Our Elixir Nx Project

Image via https://github.com/elixir-nx/nx/blob/main/nx/nx.png
# Install the notebook's dependencies. EXLA supplies the `compiler: EXLA`
# used with `Nx.Defn.jit/3` further down; Scidata supplies the MNIST download.
# The extracted original used typographic quotes (“ ”), which are not valid
# Elixir string delimiters; plain ASCII double quotes are required.
Mix.install([
  {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
  {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
  {:scidata, "~> 0.1.0"}
])

Setting Up the Data

# Fetch MNIST and keep only the first element of the result (the images);
# the second element is discarded.
{images, _} = Scidata.MNIST.download()

# Unpack the raw bytes, the element type and the 4-D shape in one match.
# The leading dimension (presumably the number of images) is ignored.
{data, type, {_, channels, height, width} = shape} = images

# Element count of one image; also the genome length of each individual.
image_size = channels * height * width

# Peel the first image off the front of the concatenated binary.
<<image::binary-size(image_size), _::binary>> = data

# Decode it into a {channels, height, width} tensor scaled down by 255.
example_image =
  Nx.divide(
    Nx.reshape(Nx.from_binary(image, type), {channels, height, width}),
    255.0
  )

Nx.to_heatmap(example_image)

Defining the Algorithm

Image by hobbit on Shutterstock

Initialization

# Number of individuals (candidate images) in every generation.
# (The extracted original fused this with the next statement as
# `1000initialize = ...`, which does not parse.)
population_size = 1000

# Builds the initial population: a {population_size, image_size} tensor of
# values from `Nx.random_uniform/2`. NOTE(review): the default sampling range
# is presumably [0, 1) — confirm against the pinned Nx version.
# `backend: Nx.Defn.Expr` is kept as written; it appears to be needed so the
# call can be traced when this fn is compiled with `Nx.Defn.jit/3` — verify.
initialize = fn ->
  Nx.random_uniform(
    {population_size, image_size},
    backend: Nx.Defn.Expr
  )
end

Evaluation

# Fitness function: mean squared error between each individual (row) of
# `population` and the target `example`. The mean is taken over the last
# axis, so the result holds one score per individual; lower is better.
evaluate = fn population, example ->
  # Turn the target into a single flat row so it broadcasts across every
  # individual in the population.
  target_row = Nx.new_axis(Nx.flatten(example), 0)

  squared_error = Nx.power(Nx.subtract(population, target_row), 2)
  Nx.mean(squared_error, axes: [-1])
end

Selection

# Ranks the population by fitness. Returns the same individuals reordered by
# ascending `evaluate` score (best match to `target_image` first).
select = fn population, target_image ->
  ranking =
    population
    |> evaluate.(target_image)
    |> Nx.argsort()

  Nx.take(population, ranking)
end

Crossover

# Breeds the next generation. Neighbouring rows are paired (0 with 1,
# 2 with 3, ...) and each pair is averaged into one child. The children are
# then stacked twice, giving 2 * div(rows, 2) rows in the result.
# NOTE(review): with an odd row count the last individual is dropped, so the
# population shrinks by one per generation.
crossover = fn population ->
  {rows, _} = Nx.shape(population)
  pair_count = div(rows, 2)
  left_idx = Nx.multiply(Nx.iota({pair_count}), 2)

  left_parents = Nx.take(population, left_idx)
  right_parents = Nx.take(population, Nx.add(left_idx, 1))

  offspring = Nx.divide(Nx.add(left_parents, right_parents), 2)
  Nx.concatenate([offspring, offspring], axis: 0)
end

Mutation

# Mutation step. Each gene whose mask draw is below 0.4 gets uniform noise
# between -0.15 and 0.15 added to it; every value is then clipped into [0, 1].
mutate = fn population ->
  dims = Nx.shape(population)

  # Draw order (mask first, then noise) matches the original.
  mask = Nx.random_uniform(dims, backend: Nx.Defn.Expr)
  noise = Nx.random_uniform(dims, -0.15, 0.15, backend: Nx.Defn.Expr)

  mutated =
    mask
    |> Nx.less(0.4)
    |> Nx.select(Nx.add(population, noise), population)

  Nx.clip(mutated, 0, 1)
end

The Algorithm

# Runs one generation of the genetic algorithm: rank the individuals
# against `target_image`, breed them, then mutate the offspring.
evolve = fn population, target_image ->
  ranked = select.(population, target_image)
  offspring = crossover.(ranked)
  mutate.(offspring)
end
# Create the initial population with the EXLA compiler.
# (The extracted original fused this line with `final_population =`.)
population = Nx.Defn.jit(initialize, [], compiler: EXLA)

# Best (smallest) MSE in a population. Built once, outside the loop, rather
# than re-creating an identical closure on every generation.
best_fitness = fn population, example_image ->
  population
  |> evaluate.(example_image)
  |> Nx.reduce_min()
end

# Evolve for 2500 generations. Progress is written on one line that "\r"
# overwrites each generation. Returns the last generation's population.
final_population =
  Enum.reduce(1..2500, population, fn i, population ->
    population = Nx.Defn.jit(evolve, [population, example_image], compiler: EXLA)

    best =
      Nx.Defn.jit(best_fitness, [population, example_image], compiler: EXLA)
      |> Nx.to_scalar()

    # ASCII quotes restored: the original used typographic “ ” and ‘ ’,
    # which do not parse as a string or charlist literal.
    IO.write("\rGeneration: #{i} Best: #{:io_lib.format('~.5f', [best])}")

    # The accumulator must be the last expression, on its own line.
    population
  end)
# Visualize the top 3: rank the final population by fitness, keep the best
# `top_k` rows and reshape each flat row back into an image. `height` and
# `width` come from the dataset shape instead of hard-coded 28s.
# NOTE(review): this reshape assumes single-channel images (channels == 1),
# as the original {3, 28, 28} did.
top_k = 3

final_population
|> select.(example_image)
|> Nx.slice_axis(0, top_k, 0)
|> Nx.reshape({top_k, height, width})
|> Nx.to_heatmap()