Working with Nx Containers
Nx.Container for Nested Data Structures in Elixir
In numerical computing, a common task is working with collections of tensors. These collections might be tuples, maps, or other, possibly nested, data structures. For example, you might store the parameters of a model such as a neural network in a map, so that it’s easy to identify which parameters belong to which layers. Nx makes working with arbitrarily nested containers easy with the Nx.Container API.
If you’re familiar with the Python machine learning ecosystem, you may know the tf.nest API or JAX’s pytrees. Both tf.nest and pytrees provide abstractions for working with tree-like data structures. In both APIs, tensors are treated as leaf (atomic) values. Functions such as tf.nest.map_structure and JAX’s tree_map work similarly to Elixir’s Enum.map, but they apply a function to each leaf value in a nested structure while preserving the original tree-like structure. For example, if you had the collection:
{'foo': (1, {'bar': 2, 'baz': (3, 4, (5,))})}
Then applying map_structure or tree_map to the nested structure with a function that maps x to x * x would produce a new structure:
{'foo': (1, {'bar': 4, 'baz': (9, 16, (25,))})}
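For comparison, the same transformation can be sketched in plain Elixir, without Nx. The module name NestedMap and its map_structure function are hypothetical helpers written for this illustration, and the sketch only handles maps, tuples, and numbers:

```elixir
defmodule NestedMap do
  # Hypothetical helper (not part of Nx or Elixir's standard library):
  # applies `fun` to every number in a nested structure of maps and tuples,
  # rebuilding the same structure around the results.
  def map_structure(%{} = map, fun) do
    Map.new(map, fn {key, value} -> {key, map_structure(value, fun)} end)
  end

  def map_structure(tuple, fun) when is_tuple(tuple) do
    tuple
    |> Tuple.to_list()
    |> Enum.map(&map_structure(&1, fun))
    |> List.to_tuple()
  end

  # Numbers are the leaves of the tree.
  def map_structure(leaf, fun) when is_number(leaf), do: fun.(leaf)
end

# The Elixir equivalent of the Python collection above.
nested = %{foo: {1, %{bar: 2, baz: {3, 4, {5}}}}}
squared = NestedMap.map_structure(nested, fn x -> x * x end)

IO.inspect(squared)
# => %{foo: {1, %{bar: 4, baz: {9, 16, {25}}}}}
```

The shape of the result matches the input exactly; only the leaf values change.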
Nx offers a similar API with Nx.Container. Nx.Container allows users to traverse arbitrarily nested containers, and to register custom data structures as containers for use with Nx’s numerical definitions. Each Nx container needs to implement two functions, or @derive an implementation from the default Nx.Container protocol. The two functions in the protocol are reduce and traverse, which have the following signatures:
reduce(t(), acc, (Nx.Tensor.t(), acc -> acc)) :: acc
traverse(t(), acc, (Nx.Tensor.t(), acc -> {term(), acc})) :: {t(), acc}
reduce applies a function to the leaf values within the container, carrying and returning an accumulator. traverse does the same, but it also returns the container with its leaf values updated, alongside the updated accumulator. It’s important to note that neither function applies the given function recursively: each visits only the container’s direct children and leaves the handling of nested containers to the caller. Given these two functions, you can implement functionality similar to tree_map or tf.nest.map_structure:
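# A regular function (def rather than defn), since it captures a private helper.
def tree_map(container, fun) do
  {container, :ok} = Nx.Container.traverse(container, :ok, &do_tree_map(&1, &2, fun))
  # Return the rebuilt container; the :ok accumulator is only a placeholder.
  container
end

defp do_tree_map(container_or_tensor, :ok, fun) do
  case container_or_tensor do
    # Tensors are leaves: apply the function directly.
    %Nx.Tensor{} = leaf ->
      {fun.(leaf), :ok}

    # Anything else is treated as a nested container: recurse into it.
    container ->
      {tree_map(container, fun), :ok}
  end
end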
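To see the shallow behaviour of reduce and traverse in isolation, here is a minimal sketch. It assumes Nx is fetched from Hex with Mix.install; the version requirement and the variable names are illustrative only:

```elixir
# Assumption: Nx comes from Hex and the version requirement is illustrative.
Mix.install([{:nx, "~> 0.7"}])

a = Nx.tensor([1, 2, 3])
b = Nx.tensor([4, 5])
c = Nx.tensor(6)

# The outer tuple has two direct children: the tensor `a` and the tuple `{b, c}`.
nested = {a, {b, c}}

# reduce only visits direct children, so the callback runs twice, not three
# times: the inner tuple is handed to the callback as a single value.
visited = Nx.Container.reduce(nested, 0, fn _child, count -> count + 1 end)
IO.inspect(visited)

# In a flat container every child is a tensor, so traverse can update each
# leaf while threading an accumulator (here, a running element count).
flat = %{weights: a, bias: c}

{doubled, total} =
  Nx.Container.traverse(flat, 0, fn tensor, count ->
    {Nx.multiply(tensor, 2), count + Nx.size(tensor)}
  end)

IO.inspect(Nx.to_flat_list(doubled.weights))
IO.inspect(total)
```

Because the inner tuple arrives at the callback as a single child, a recursive helper such as tree_map has to decide for itself whether to descend into it.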
You can also implement more complex functional transformations such as reduce, map_reduce, zip, and so on.
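As one example of such a transformation, a recursive reduce might look like the sketch below. TreeOps and tree_reduce are hypothetical names, not part of Nx, and the Mix.install version requirement is an assumption:

```elixir
# Assumption: Nx comes from Hex and the version requirement is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule TreeOps do
  # Hypothetical helper (not part of Nx): folds `fun` over every leaf of an
  # arbitrarily nested container, depth first.
  def tree_reduce(%Nx.Tensor{} = leaf, acc, fun), do: fun.(leaf, acc)

  # Plain numbers are treated as leaves too, so the recursion always ends.
  def tree_reduce(number, acc, fun) when is_number(number), do: fun.(number, acc)

  # Anything else is assumed to be a container: visit its direct children
  # with Nx.Container.reduce and recurse into each one.
  def tree_reduce(container, acc, fun) do
    Nx.Container.reduce(container, acc, &tree_reduce(&1, &2, fun))
  end
end

# Model-like parameters mixing maps and tuples.
params = %{
  dense_1: %{kernel: Nx.tensor([[1, 2], [3, 4]]), bias: Nx.tensor([0, 0])},
  dense_2: {Nx.tensor([5, 6]), Nx.tensor(7)}
}

# Count the total number of scalar parameters: 4 + 2 + 2 + 1.
param_count = TreeOps.tree_reduce(params, 0, fn leaf, acc -> acc + Nx.size(leaf) end)
IO.inspect(param_count)
# => 9
```

A map_reduce variant would follow the same pattern, with Nx.Container.traverse in place of Nx.Container.reduce.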
Nx supports container implementations for maps and tuples out of the box. Adding support for a custom struct is as easy as using Elixir’s @derive:
defmodule MyData do
  @derive {Nx.Container,
           containers: [:value_1, :value_2],
           keep: [:option]}
  defstruct [:option, :value_1, :value_2]
end
:keep specifies which fields of a struct should be preserved as-is when the struct is used within defn. This is useful for preserving metadata such as atoms. :containers specifies which fields contain tensors (or other containers). By implementing Nx.Container for your struct, you can use custom data structures within defn without any issues:
defmodule Functions do
  import Nx.Defn

  defn function_on_my_data(my_data) do
    operation = get_operation(my_data)
    operation.(my_data.value_1, my_data.value_2)
  end

  defnp get_operation(my_data) do
    transform(my_data, fn %MyData{option: op} -> &apply(Nx, op, [&1, &2]) end)
  end
end

data = struct(MyData, %{option: :add, value_1: 1, value_2: 2})
Functions.function_on_my_data(data) |> IO.inspect()
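To see how :containers and :keep behave outside of defn, the sketch below calls Nx.Container.traverse directly on the struct. It redefines MyData so that it runs on its own. The Mix.install version requirement is an assumption, and consolidate_protocols: false is passed on the assumption that an implementation derived inside a script is otherwise not picked up:

```elixir
# Assumption: Nx comes from Hex and the version requirement is illustrative.
# Protocol consolidation is disabled so that the implementation derived in
# this script is found at runtime.
Mix.install([{:nx, "~> 0.7"}], consolidate_protocols: false)

defmodule MyData do
  @derive {Nx.Container, containers: [:value_1, :value_2], keep: [:option]}
  defstruct [:option, :value_1, :value_2]
end

data = struct(MyData, option: :add, value_1: Nx.tensor(1), value_2: Nx.tensor(2))

# Only the fields listed under :containers are passed to the callback;
# the :option field is carried over untouched.
{updated, visited} =
  Nx.Container.traverse(data, 0, fn tensor, count ->
    {Nx.multiply(tensor, 10), count + 1}
  end)

IO.inspect(visited)
IO.inspect(updated.option)
IO.inspect(Nx.to_number(updated.value_1))
IO.inspect(Nx.to_number(updated.value_2))
```

The callback should run once per field in :containers, while the atom stored in :option passes through unchanged.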
Nx.Container is a flexible abstraction that makes it easy to work with nested data structures and to register custom data structures for use within Nx’s defn. For more information, I recommend reading the Nx documentation and checking out some container implementations in the wild, such as those in Axon.