Image by Pramoon Design on Shutterstock

Working with Nx Containers

Nx.Container for Nested Data Structures in Elixir

Sean Moriarity
The Pragmatic Programmers
Jun 29, 2022


In numerical computing, a common task is working with collections of tensors. These collections could come in the form of a tuple, map, or other possibly nested data structure. For example, you might store the parameters of a model such as a neural network in a map, so that it’s easy to identify which parameters belong to which layers. Nx makes working with arbitrarily nested containers easy with the Nx.Container API.

If you’re familiar with the Python machine learning ecosystem, you may know the tf.nest API or JAX’s pytrees. Both tf.nest and pytrees create abstractions for working with tree-like data structures. In both APIs, tensors are treated as leaf (or atomic) values. Functions such as tf.nest.map_structure and JAX’s tree_map work similarly to Elixir’s Enum.map, but they apply a function to each leaf value in a nested structure while preserving the original tree-like structure. For example, if you had the collection:

{'foo': (1, {'bar': 2, 'baz': (3, 4, (5,))})}

Then applying map_structure or tree_map to the nested structure with a function that squares each leaf (x * x) would result in a new structure:

{'foo': (1, {'bar': 4, 'baz': (9, 16, (25,))})}
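The same idea can be sketched in plain Elixir, without Nx, to show what "map over the leaves, keep the shape" means. The module name PlainTree and its map/2 function are hypothetical helpers written for this illustration; they treat maps and tuples as branches and everything else as a leaf.

```elixir
defmodule PlainTree do
  # Hypothetical helper, not part of Nx: applies `fun` to every leaf of a
  # structure built from maps and tuples and keeps the shape intact.
  def map(tree, fun) when is_map(tree) and not is_struct(tree) do
    Map.new(tree, fn {key, value} -> {key, map(value, fun)} end)
  end

  def map(tree, fun) when is_tuple(tree) do
    tree
    |> Tuple.to_list()
    |> Enum.map(&map(&1, fun))
    |> List.to_tuple()
  end

  # Anything that is not a plain map or a tuple is treated as a leaf.
  def map(leaf, fun), do: fun.(leaf)
end

tree = %{foo: {1, %{bar: 2, baz: {3, 4, {5}}}}}

tree
|> PlainTree.map(fn x -> x * x end)
|> IO.inspect()
# prints %{foo: {1, %{bar: 4, baz: {9, 16, {25}}}}}
```

The result has the same keys and the same tuple sizes as the input; only the leaf values change.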

Nx offers a similar API with Nx.Container. Nx.Container allows users to traverse arbitrarily nested containers and to register custom data structures as containers for use with Nx’s numerical definitions. Each Nx container must either implement the two functions of the Nx.Container protocol or @derive the default implementation. The two functions are reduce and traverse, which have the following signatures:

reduce(t(), acc, (Nx.Tensor.t(), acc -> acc)) :: acc
traverse(t(), acc, (Nx.Tensor.t(), acc -> {term(), acc})) :: {t(), acc}

reduce applies a function to the leaf values within the container, carrying and returning an accumulator. traverse does the same, but it also returns the container with updated leaf values along with the updated accumulator. Note that neither function recurses into nested containers: the callback may receive a nested container instead of a tensor, and handling nested traversal is left to the user. Given these two functions, you can implement functionality similar to tree_map or tf.nest.map_structure:

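defmodule Tree do
  # Applies fun to every tensor leaf and returns a container of the same shape.
  def tree_map(container, fun) do
    {container, :ok} = Nx.Container.traverse(container, :ok, &do_tree_map(&1, &2, fun))
    container
  end

  # traverse/3 is not recursive, so nested containers are handled here.
  defp do_tree_map(container_or_tensor, :ok, fun) do
    case container_or_tensor do
      %Nx.Tensor{} = leaf ->
        {fun.(leaf), :ok}

      container ->
        {tree_map(container, fun), :ok}
    end
  end
end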
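To make the non-recursive behaviour concrete, here is a small script that calls the two protocol functions directly on a map. It is only a sketch: the Hex version requirement and the params map are assumptions made for this illustration, not something the article specifies.

```elixir
# Assumption: a recent Nx release from Hex; the version requirement is
# illustrative only.
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical parameters: one nested map of tensors and one bare tensor.
params = %{
  dense: %{kernel: Nx.tensor([[1.0, 2.0], [3.0, 4.0]]), bias: Nx.tensor([0.0, 0.0])},
  scale: Nx.tensor(2.0)
}

# reduce/3 only visits the top-level values. The inner map stored under
# :dense is handed to the callback as-is, not tensor by tensor.
top_level =
  Nx.Container.reduce(params, [], fn
    %Nx.Tensor{} = tensor, acc -> [{:tensor, Nx.shape(tensor)} | acc]
    %{} = nested, acc -> [{:container, Map.keys(nested)} | acc]
  end)

IO.inspect(top_level)

# traverse/3 on the inner map (a flat container of tensors) returns the
# updated container together with the final accumulator.
{doubled, count} =
  Nx.Container.traverse(params.dense, 0, fn tensor, acc ->
    {Nx.multiply(tensor, 2), acc + 1}
  end)

IO.inspect({doubled, count})
```

The first call shows why a recursive helper needs a branch for containers: the callback sees the :dense map itself, so any descent into it has to be written by hand.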

You can also implement more complex functional transformations such as reduce, map_reduce, zip, and so on.
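As one example of such a transformation, a recursive reduce can be sketched on top of Nx.Container.reduce. The module name TreeOps, the function tree_reduce/3, and the sample parameters are hypothetical, and the sketch assumes every leaf is an Nx.Tensor.

```elixir
# Assumption: a recent Nx release from Hex; the version requirement is
# illustrative only.
Mix.install([{:nx, "~> 0.7"}])

defmodule TreeOps do
  # Hypothetical helper, not part of Nx: folds `fun` over every tensor leaf,
  # recursing whenever Nx.Container.reduce/3 hands over a nested container.
  def tree_reduce(container, acc, fun) do
    Nx.Container.reduce(container, acc, fn
      %Nx.Tensor{} = leaf, acc -> fun.(leaf, acc)
      nested, acc -> tree_reduce(nested, acc, fun)
    end)
  end
end

# Hypothetical model parameters mixing a nested map and a nested tuple.
params = %{
  dense_0: %{kernel: Nx.iota({4, 8}), bias: Nx.iota({8})},
  dense_1: {Nx.iota({8, 2}), Nx.iota({2})}
}

# Count the scalar parameters across all layers: 32 + 8 + 16 + 2.
total = TreeOps.tree_reduce(params, 0, fn tensor, sum -> sum + Nx.size(tensor) end)
IO.inspect(total)
# prints 58
```

The tensor clause comes first because a tensor is itself a struct, so the catch-all clause would otherwise try to descend into it.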

Nx supports container implementations for maps and tuples out of the box. Adding support for a custom struct is as easy as using Elixir’s @derive:

defmodule MyData do
  @derive {Nx.Container,
           containers: [:value_1, :value_2],
           keep: [:option]}
  defstruct [:option, :value_1, :value_2]
end

:keep specifies which fields of a struct should be preserved as-is when the struct is used within defn. This is useful for preserving metadata such as atoms. :containers specifies which fields contain tensors (or other containers). By implementing Nx.Container for your struct, you can use custom data structures within defn without any issues:

defmodule Functions do
  import Nx.Defn

  defn function_on_my_data(my_data) do
    operation = get_operation(my_data)
    operation.(my_data.value_1, my_data.value_2)
  end

  # transform/2 runs regular Elixir code while the defn is being built, which
  # makes it possible to read the kept :option field and pick an Nx function.
  defnp get_operation(my_data) do
    transform(my_data, fn %MyData{option: op} -> &apply(Nx, op, [&1, &2]) end)
  end
end

data = struct(MyData, %{option: :add, value_1: 1, value_2: 2})
Functions.function_on_my_data(data) |> IO.inspect()
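Outside of defn, the derived implementation can also be exercised directly, which makes the roles of :containers and :keep visible. The sketch below uses a hypothetical Pair struct rather than MyData. The Hex version requirement is an assumption, and so is the need to turn off protocol consolidation so that a struct defined in a script can derive the protocol after Nx has been compiled.

```elixir
# Assumptions: a recent Nx release from Hex, and protocol consolidation
# disabled so the @derive below takes effect inside a standalone script.
Mix.install([{:nx, "~> 0.7"}], consolidate_protocols: false)

defmodule Pair do
  # Hypothetical struct mirroring MyData: two tensor fields plus one
  # metadata field that should pass through traversal untouched.
  @derive {Nx.Container, containers: [:left, :right], keep: [:label]}
  defstruct [:label, :left, :right]
end

pair = %Pair{label: :example, left: Nx.tensor([1, 2]), right: Nx.tensor([3, 4])}

# The callback is invoked once per field listed under :containers.
{updated, visited} =
  Nx.Container.traverse(pair, 0, fn tensor, acc ->
    {Nx.add(tensor, 1), acc + 1}
  end)

IO.inspect(updated.label)
IO.inspect(visited)
```

In a regular Mix project the struct is compiled together with the application, so the consolidation option should not be needed there.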

Nx.Container is a flexible abstraction that makes it easy to work with nested data structures and to register custom data structures within Nx’s defn. For more information, I recommend reading the Nx documentation and checking out some container implementations in the wild, such as in Axon.

Sean Moriarity and José Valim worked together to create Elixir’s Nx library. Sean has also written a book, Genetic Algorithms in Elixir, with The Pragmatic Bookshelf.

You can save 35% on the ebook version with promo code smgaelixir_35 through July 31, 2022.
