Dissecting DM Graphcast

Bryan
12 min read · Jun 3, 2024

GraphCast is a machine learning-based weather prediction (MLWP) approach for global medium-range weather forecasting. It can produce an accurate multi-day forecast in under a minute on a single TPU v4 device, and it supports applications including the prediction of tropical cyclone tracks, atmospheric rivers, and extreme temperatures. (Lam et al., 2022, p. 2)

It is built on a graph neural network (GNN) architecture with a total of 36.7 million parameters. The model takes as input the two most recent states of Earth’s weather (the current time and six hours earlier) and predicts the state of the weather six hours ahead. (Lam et al., 2022, p. 2, 4)
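
Because each call predicts a single six-hour step from the two latest states, a longer forecast is obtained by feeding predictions back in as inputs. The loop below is a minimal sketch of that idea, not the package's rollout code: `rollout` and `toy_step` are hypothetical names, a weather state is reduced to a single float, and the toy step function merely stands in for the trained model.

```python
from typing import Callable, List

State = float  # stand-in for a full weather state (hypothetical simplification)


def rollout(step_fn: Callable[[State, State], State],
            previous: State, current: State, num_steps: int) -> List[State]:
    """Applies a 6-hour step function repeatedly, feeding predictions back in."""
    predictions = []
    for _ in range(num_steps):
        nxt = step_fn(previous, current)    # predict t + 6h from (t - 6h, t)
        predictions.append(nxt)
        previous, current = current, nxt    # slide the two-state window forward
    return predictions


def toy_step(previous: State, current: State) -> State:
    """Toy dynamics (linear extrapolation); NOT the GraphCast model."""
    return current + (current - previous)


# 4 steps of 6 hours each cover one day.
forecast = rollout(toy_step, previous=1.0, current=2.0, num_steps=4)
```

With this pattern, the number of steps alone determines the forecast horizon; the model itself never changes.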

There is a lot to study in the GraphCast model. This article focuses on the GraphCast package implemented by DeepMind (https://github.com/google-deepmind/graphcast), with particular emphasis on how its GNNs are constructed. The package is built on JAX and uses a number of other libraries, including Haiku, Jraph, xarray, and Chex.

Most of the material and illustrations here are drawn from the paper “GraphCast: Learning skillful medium-range global weather forecasting”. For more details, please refer to the original paper (https://arxiv.org/abs/2212.12794).

Dataset

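First we examine the training data. We use one sample from the ERA5 archive. ERA5 provides global weather at 0.25° latitude/longitude resolution with static, surface, and atmospheric variables; the sample file loaded below is a coarser version, as its name indicates (res-1.0 for 1.0° resolution, levels-13 for 13 pressure levels).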

from google.cloud import storage
import xarray

# Anonymous access is enough for the public bucket.
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")
gcs_path = "dataset/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc"
with gcs_bucket.blob(gcs_path).open("rb") as f:
    example_batch = xarray.load_dataset(f).compute()

Here are the model_config and task_config settings:

model_config = graphcast.ModelConfig(
    resolution=0,
    mesh_size=4,
    latent_size=32,
    gnn_msg_steps=1,
    hidden_layers=1,
    radius_query_fraction_edge_length=0.6)
task_config = graphcast.TaskConfig(
    input_variables=graphcast.TASK.input_variables,
    target_variables=graphcast.TASK.target_variables,
    forcing_variables=graphcast.TASK.forcing_variables,
    pressure_levels=graphcast.PRESSURE_LEVELS[13],
    input_duration=graphcast.TASK.input_duration,
)

After loading one sample data point, we separate it into inputs, forcings, and targets.

# A single 6-hour lead time: the slice starts and ends at "6h".
train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", "6h"),
    **dataclasses.asdict(task_config))

Inputs

Inputs are the time-dependent weather state: all predicted variables on every atmospheric level, plus the surface variables. The paper uses 37 atmospheric levels at 0.25° latitude/longitude resolution, while the task_config above selects 13 pressure levels. A single weather state is represented on a latitude/longitude grid, where each grid point carries a set of surface and atmospheric variables. An input sample has 2 timestamps because the model needs the current state and the previous one to make a prediction. (Lam et al., 2022, p. 26)

Forcings

The forcing terms consist of time-dependent features that can be computed analytically and do not need to be predicted by GraphCast. They include the total incident solar radiation at the top of the atmosphere (which must be computed for the target lead times), the sine and cosine of the local time of day, and the sine and cosine of the year progress. (Lam et al., 2022, p. 26)
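
To make "computed analytically" concrete, here is a minimal sketch of the clock-like forcings. It is an illustration under simplifying assumptions, not the package's implementation: `progress_features` is a hypothetical helper, the mean year length of 365.25 days is an approximation chosen for the example, and solar radiation is left out.

```python
import math
from datetime import datetime, timezone

SECONDS_PER_DAY = 24 * 60 * 60
DAYS_PER_YEAR = 365.25  # assumption: approximate mean year length


def progress_features(when: datetime, longitude_deg: float) -> dict:
    """Returns sine/cosine encodings of local time of day and of year progress."""
    seconds = when.timestamp()
    # Approximate fraction of the year elapsed (the epoch starts on 1 January).
    year_progress = (seconds / SECONDS_PER_DAY / DAYS_PER_YEAR) % 1.0
    # Fraction of the day elapsed at Greenwich, shifted by longitude to local time.
    day_progress = ((seconds % SECONDS_PER_DAY) / SECONDS_PER_DAY
                    + longitude_deg / 360.0) % 1.0
    return {
        "year_progress_sin": math.sin(2 * math.pi * year_progress),
        "year_progress_cos": math.cos(2 * math.pi * year_progress),
        "day_progress_sin": math.sin(2 * math.pi * day_progress),
        "day_progress_cos": math.cos(2 * math.pi * day_progress),
    }


# 06:00 UTC at 90° E is local noon, i.e. half-way through the local day.
features = progress_features(
    datetime(2022, 1, 1, 6, 0, tzinfo=timezone.utc), longitude_deg=90.0)
```

Encoding each cyclic quantity as a sine/cosine pair keeps the features continuous across midnight and across the new year.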

Inputs and forcings are the main components of the node features in Graphcast GNNs.

Targets

Targets serve as ground truth for comparing predictions, computing loss, and ultimately calculating gradients.

Technical Terms

Graph

In GraphCast, a graph object is defined as follows:

class TypedGraph(NamedTuple):
    context: Context
    nodes: Mapping[str, NodeSet]
    edges: Mapping[EdgeSetKey, EdgeSet]

A typed graph is made of a context, multiple sets of nodes and multiple sets of edges connecting those nodes. Nodes, edges and global context have their own feature arrays.

class Context(NamedTuple):
    n_graph: ArrayLike
    features: ArrayLikeTree

class NodeSet(NamedTuple):
    n_node: ArrayLike
    features: ArrayLikeTree

class EdgeSet(NamedTuple):
    n_edge: ArrayLike
    indices: EdgesIndices
    features: ArrayLikeTree
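
To see how these pieces fit together, the sketch below builds a tiny graph with simplified stand-ins for the classes above. The `Toy*` classes are hypothetical and use plain Python lists instead of arrays; the context is omitted and the edge key is a plain tuple rather than an `EdgeSetKey`.

```python
from typing import Dict, List, NamedTuple, Tuple


class ToyNodeSet(NamedTuple):
    n_node: int
    features: List[List[float]]      # one feature vector per node


class ToyEdgeSet(NamedTuple):
    n_edge: int
    senders: List[int]               # index into the sender node set
    receivers: List[int]             # index into the receiver node set
    features: List[List[float]]      # one feature vector per edge


class ToyTypedGraph(NamedTuple):
    nodes: Dict[str, ToyNodeSet]
    # Key: (edge set name, (sender node set, receiver node set)).
    edges: Dict[Tuple[str, Tuple[str, str]], ToyEdgeSet]


grid = ToyNodeSet(n_node=3, features=[[0.1], [0.2], [0.3]])
mesh = ToyNodeSet(n_node=2, features=[[1.0], [2.0]])
# Three directed edges: grid 0 -> mesh 0, grid 1 -> mesh 0, grid 2 -> mesh 1.
g2m = ToyEdgeSet(n_edge=3, senders=[0, 1, 2], receivers=[0, 0, 1],
                 features=[[0.5], [0.4], [0.7]])
toy_graph = ToyTypedGraph(
    nodes={"grid_nodes": grid, "mesh_nodes": mesh},
    edges={("grid2mesh", ("grid_nodes", "mesh_nodes")): g2m})

# NamedTuples are immutable, so an update builds a new graph via _replace.
updated = toy_graph._replace(
    nodes={**toy_graph.nodes,
           "mesh_nodes": mesh._replace(features=[[9.0], [9.0]])})
```

The `_replace` pattern at the end is the same one the GraphCast code uses whenever it swaps new features into an existing graph structure.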

MLP

A Multilayer Perceptron (MLP) is a class of feedforward neural network composed of multiple layers of nodes, or perceptrons, with each node using a nonlinear activation function. MLPs consist of an input layer, one or more hidden layers, and an output layer. Each layer is fully connected to the next one. They are used to model complex relationships and are capable of learning and representing nonlinear functions.

The layers in an MLP are dense (fully connected) layers. Each layer computes y = g(Wx + b), where W is the weight matrix, b is the bias vector, and g is the activation function. In GraphCast it is defined as

class DeepTypedGraphNet():
    def _networks_builder():
        def build_mlp(name, output_size):
            mlp = hk.nets.MLP(
                output_sizes=[self._mlp_hidden_size] * self._mlp_num_hidden_layers + [output_size],
                name=name + "_mlp",
                activation=self._activation)
            # activation="swish"
            ......

The number of hidden layers and output size are configurable.
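
As a rough illustration of y = g(Wx + b) applied layer by layer, here is a dependency-free sketch of an MLP forward pass with the swish activation. The helper names (`dense`, `mlp_forward`) and the toy weights are assumptions made for the example; like Haiku's `hk.nets.MLP` with default settings, it leaves the final layer without an activation.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # one row of weights per output unit


def swish(x: float) -> float:
    """Swish activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def dense(weights: Matrix, bias: Vector, x: Vector) -> Vector:
    """Computes W x + b for one fully connected layer."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]


def mlp_forward(layers, x: Vector) -> Vector:
    """Runs x through (weights, bias) layers, activating all but the last."""
    for i, (weights, bias) in enumerate(layers):
        x = dense(weights, bias, x)
        if i < len(layers) - 1:          # no activation on the output layer
            x = [swish(v) for v in x]
    return x


# One hidden layer of size 2 followed by an output layer of size 1.
layers = [
    ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]),   # hidden layer: identity weights
    ([[1.0, 1.0]], [0.5]),                    # output layer: sum plus bias
]
output = mlp_forward(layers, [0.0, 0.0])
```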

Graph Networks (GN)

Graph Networks are a type of model designed to handle data structured as graphs. They are powerful for tasks where the relationships between entities are crucial, such as social network analysis and molecular modeling. You can learn more in the article “Basics about Graph Networks”.

Architecture Overview

GraphCast is implemented as a message passing GNN architecture, with an “encode-process-decode” configuration.

Figure 1 (Lam et al., 2022, p. 3)

The encoder first maps the input data, from the original latitude-longitude grid, into learned features on the multi-mesh, using a GNN with directed edges from the grid points to the multi-mesh. The processor then uses a deep GNN to perform learned message-passing on the multi-mesh, allowing efficient propagation of information across space due to the long-range edges. The decoder then maps the final multi-mesh representation back to the latitude-longitude grid using a GNN with directed edges. (Lam et al., 2022, p. 26)

class GraphCast():
    def __call__():
        ......
        grid_node_features = self._inputs_to_grid_node_features(inputs, forcings)
        (latent_mesh_nodes, latent_grid_nodes) = self._run_grid2mesh_gnn(grid_node_features)
        updated_latent_mesh_nodes = self._run_mesh_gnn(latent_mesh_nodes)
        output_grid_nodes = self._run_mesh2grid_gnn(updated_latent_mesh_nodes, latent_grid_nodes)
        return self._grid_node_outputs_to_prediction(output_grid_nodes, targets_template)

Grid Nodes

Each grid node represents a vertical slice of the atmosphere at a given latitude-longitude point. At 0.25° resolution, there is a total of 721 × 1440 = 1,038,240 grid nodes.
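
The node count follows directly from the grid spacing. The snippet below is a small sketch of that arithmetic; `grid_shape` is a hypothetical helper, and it assumes a regular grid that includes both poles in latitude and wraps around in longitude.

```python
def grid_shape(resolution_deg: float) -> tuple:
    """Returns (n_lat, n_lon) for a regular lat/lon grid that includes both poles."""
    n_lat = round(180 / resolution_deg) + 1   # -90 ... +90 inclusive
    n_lon = round(360 / resolution_deg)       # 0 ... 360 exclusive (wraps around)
    return n_lat, n_lon


n_lat, n_lon = grid_shape(0.25)
num_grid_nodes = n_lat * n_lon
```

The same helper gives 181 × 360 = 65,160 nodes for a 1.0° grid.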

Figure 1 (Williamson, 2007, p. 243)

Mesh Nodes

Mesh nodes are placed uniformly around the globe in an R-refined icosahedral mesh M_R. M_0 corresponds to a unit-radius icosahedron (12 nodes and 20 triangular faces) with faces parallel to the poles.

Figure 1 (Lam et al., 2022, p. 3)

The mesh is iteratively refined, M_r → M_{r+1}, by splitting each triangular face into 4 smaller faces, resulting in an extra node in the middle of each edge, and re-projecting the new nodes back onto the unit sphere. Features associated with each mesh node include the cosine of the latitude, and the sine and cosine of the longitude. GraphCast works with a mesh that has been refined R = 6 times, M_6, resulting in 40,962 mesh nodes, each with the 3 input features. (Lam et al., 2022, p. 27)
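
The figure of 40,962 nodes can be checked by counting alone. The sketch below tracks only the number of nodes, edges, and faces through each refinement (it builds no geometry); `refined_icosahedron_counts` is a hypothetical helper written for this article.

```python
def refined_icosahedron_counts(num_refinements: int) -> tuple:
    """Returns (nodes, edges, faces) after splitting every face into 4, R times."""
    nodes, edges, faces = 12, 30, 20           # regular icosahedron, M_0
    for _ in range(num_refinements):
        # Each edge gains a midpoint node; each old edge splits in two and each
        # old face contributes three new interior edges.
        nodes, edges, faces = nodes + edges, 2 * edges + 3 * faces, 4 * faces
    return nodes, edges, faces


counts = [refined_icosahedron_counts(r) for r in range(7)]
```

The node counts follow the closed form 10 · 4^R + 2, which gives 40,962 for R = 6.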

Initiate Grid2Mesh subgraph

This step precomputes structural node and edge features according to the config options. Structural features are those that depend only on the fixed latitudes and longitudes of the nodes.

class GraphCast():
    def _init_grid2mesh_graph():
        ......
        (grid_indices, mesh_indices) = grid_mesh_connectivity.radius_query_indices(......)
        ......
        (senders_node_features, receivers_node_features, edge_features) = (
            model_utils.get_bipartite_graph_spatial_features(......))

        # Grid nodes are the senders and mesh nodes the receivers of grid2mesh edges.
        grid_node_set = typed_graph.NodeSet(n_node=n_grid_node, features=senders_node_features)
        mesh_node_set = typed_graph.NodeSet(n_node=n_mesh_node, features=receivers_node_features)
        edge_set = typed_graph.EdgeSet(
            n_edge=n_edge,
            indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
            features=edge_features)
        nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
        edges = {
            typed_graph.EdgeSetKey("grid2mesh", ("grid_nodes", "mesh_nodes")): edge_set
        }
        grid2mesh_graph = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=np.array([1]), features=()),
            nodes=nodes,
            edges=edges)
        return grid2mesh_graph
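
The connectivity itself comes from a radius query: a grid node is linked to a mesh node when the two are close enough, with the radius set to a fraction (0.6 in the config above) of the mesh edge length. The sketch below illustrates the idea with a brute-force search on the unit sphere. It is not the package's implementation; `lat_lon_to_xyz` and `radius_query` are hypothetical helpers, and the two-node "mesh" is a toy.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float, float]


def lat_lon_to_xyz(lat_deg: float, lon_deg: float) -> Point:
    """Converts latitude/longitude in degrees to a point on the unit sphere."""
    lat, lon = math.radians(lat_deg), math.radians(lon_deg)
    return (math.cos(lat) * math.cos(lon), math.cos(lat) * math.sin(lon),
            math.sin(lat))


def radius_query(grid_xyz: List[Point], mesh_xyz: List[Point],
                 radius: float) -> Tuple[List[int], List[int]]:
    """Brute-force search: pairs every mesh node with grid nodes within radius."""
    grid_indices, mesh_indices = [], []
    for m, mesh_point in enumerate(mesh_xyz):
        for g, grid_point in enumerate(grid_xyz):
            if math.dist(grid_point, mesh_point) <= radius:
                grid_indices.append(g)   # sender (grid node)
                mesh_indices.append(m)   # receiver (mesh node)
    return grid_indices, mesh_indices


mesh_xyz = [lat_lon_to_xyz(0, 0), lat_lon_to_xyz(0, 90)]
grid_xyz = [lat_lon_to_xyz(0, 5), lat_lon_to_xyz(0, 85), lat_lon_to_xyz(0, 180)]
max_edge_length = math.dist(mesh_xyz[0], mesh_xyz[1])   # toy "mesh edge"
senders, receivers = radius_query(grid_xyz, mesh_xyz, 0.6 * max_edge_length)
```

In this toy setup the grid node at 180° longitude is too far from both mesh nodes and receives no edge. A brute-force search is quadratic, so at full resolution a spatial index would be needed instead.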

Convert inputs to grid node features

In the previous section, we extracted inputs and forcings from the sample data point. Now they are stacked together and added to the Grid2Mesh subgraph.

class GraphCast():
    def __call__():
        ......
        grid_node_features = self._inputs_to_grid_node_features(inputs, forcings)
        ......

    def _inputs_to_grid_node_features():
        stacked_inputs = model_utils.dataset_to_stacked(inputs)
        stacked_forcings = model_utils.dataset_to_stacked(forcings)
        stacked_inputs = xarray.concat([stacked_inputs, stacked_forcings], dim="channels")
        grid_xarray_lat_lon_leading = model_utils.lat_lon_to_leading_axes(stacked_inputs)
        return xarray_jax.unwrap(grid_xarray_lat_lon_leading.data).reshape(
            (-1,) + grid_xarray_lat_lon_leading.data.shape[2:])

After stacking inputs and forcings’ channels, we get 183 channels in total (at each grid point).

grid_node_features.shape  # (1038240, 1, 183) = (grid nodes, batch, channels)

How do we get this number? Take the variable “temperature”, for instance: it has 2 timestamps and 13 levels, so it contributes 26 channels. The variable “day_progress_sin”, on the other hand, has 2 timestamps and no levels, so it contributes 2 channels. Combining all the variables gives 178 channels from inputs and 5 channels from forcings, for 183 channels in total.
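
The full tally can be reproduced with a few lines of arithmetic. The variable lists below are an assumption based on my reading of the package's task definition (5 surface, 6 atmospheric, 5 forcing, and 2 static variables); the names serve only to make the counts traceable.

```python
# Assumed variable groups; only the group sizes matter for the count.
surface_vars = ["2m_temperature", "mean_sea_level_pressure",
                "10m_u_component_of_wind", "10m_v_component_of_wind",
                "total_precipitation_6hr"]
atmospheric_vars = ["temperature", "geopotential", "u_component_of_wind",
                    "v_component_of_wind", "vertical_velocity",
                    "specific_humidity"]
forcing_vars = ["toa_incident_solar_radiation", "year_progress_sin",
                "year_progress_cos", "day_progress_sin", "day_progress_cos"]
static_vars = ["geopotential_at_surface", "land_sea_mask"]  # no time axis

num_levels = 13
num_input_steps = 2

# Channels contributed by one input timestamp.
per_step = (len(surface_vars) + len(atmospheric_vars) * num_levels
            + len(forcing_vars))
input_channels = num_input_steps * per_step + len(static_vars)
forcing_channels = len(forcing_vars)       # forcings at the single target time
total_channels = input_channels + forcing_channels
```

Each input timestamp contributes 5 + 6 × 13 + 5 = 88 channels, so two timestamps plus the two static fields give 178, and the target-time forcings add the remaining 5.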

Now we have the feature matrix which will be combined with the structural features in the input graph.

class GraphCast():
    def _run_grid2mesh_gnn():
        ......
        grid_nodes = grid2mesh_graph.nodes["grid_nodes"]
        mesh_nodes = grid2mesh_graph.nodes["mesh_nodes"]
        new_grid_nodes = grid_nodes._replace(
            features=jnp.concatenate([
                grid_node_features,
                _add_batch_second_axis(
                    grid_nodes.features.astype(grid_node_features.dtype),
                    batch_size)
            ], axis=-1))

        dummy_mesh_node_features = jnp.zeros(
            (self._num_mesh_nodes,) + grid_node_features.shape[1:],
            dtype=grid_node_features.dtype)
        new_mesh_nodes = mesh_nodes._replace(
            features=jnp.concatenate([
                dummy_mesh_node_features,
                _add_batch_second_axis(
                    mesh_nodes.features.astype(dummy_mesh_node_features.dtype),
                    batch_size)
            ], axis=-1))

        grid2mesh_edges_key = grid2mesh_graph.edge_key_by_name("grid2mesh")
        edges = grid2mesh_graph.edges[grid2mesh_edges_key]

        new_edges = edges._replace(
            features=_add_batch_second_axis(
                edges.features.astype(dummy_mesh_node_features.dtype), batch_size))

        input_graph = self._grid2mesh_graph_structure._replace(
            edges={grid2mesh_edges_key: new_edges},
            nodes={
                "grid_nodes": new_grid_nodes,
                "mesh_nodes": new_mesh_nodes
            })

Notice that the mesh nodes and grid2mesh edges don’t get any information from inputs/forcings; they only contain structural features (the mesh nodes are padded with zeros in place of the input channels).

Grid2Mesh GNN: encoder

The purpose of the encoder is to prepare data into latent representations for the processor.

Only the Grid2Mesh subgraph (grid nodes, mesh nodes, grid2mesh edges) is involved in this step.

Figure g2m_gnn

First, each of the grid2mesh edges is updated using information from the adjacent nodes.

def GraphNetwork():
    def _apply_graph_net():
        ......
        for edge_set_name, edge_fn in update_edge_fn.items():
            edge_set_key = graph.edge_key_by_name(edge_set_name)
            updated_edges[edge_set_key] = _edge_update(updated_graph, edge_fn, edge_set_key)
        ......

def _edge_update():
    ......
    sent_attributes = tree.tree_map(lambda n: n[senders], sender_nodes.features)
    received_attributes = tree.tree_map(lambda n: n[receivers], receiver_nodes.features)
    ......
    new_features = edge_fn(
        edge_set.features, sent_attributes, received_attributes,
        global_features)
    return edge_set._replace(features=new_features)

In the code above, there is only one type of edge in update_edge_fn: “grid2mesh”. Its edge_fn is:

def build_mlp_with_maybe_layer_norm():
    network = build_mlp(name, output_size)
    if self._use_layer_norm:
        layer_norm = hk.LayerNorm(
            axis=-1, create_scale=True, create_offset=True,
            name=name + "_layer_norm")
        network = hk.Sequential([network, layer_norm])
    return jraph.concatenated_args(network)

Then each of the mesh nodes is updated by aggregating information from all of the edges arriving at that mesh node. Each of the grid nodes is also updated, but with no aggregation, because grid nodes are not receivers of any edges in the Grid2Mesh subgraph.

def GraphNetwork():
    def _apply_graph_net():
        ......
        for node_set_key, node_fn in update_node_fn.items():
            updated_nodes[node_set_key] = _node_update(
                updated_graph, node_fn, node_set_key, aggregate_edges_for_nodes_fn)
        ......

def _node_update():
    ......
    for edge_set_key, edge_set in graph.edges.items():
        ......
        sent_features[edge_set_key.name] = tree.tree_map(
            lambda e: aggregation_fn(e, senders, sum_n_node), edge_set.features)
    ......
    for edge_set_key, edge_set in graph.edges.items():
        ......
        received_features[edge_set_key.name] = tree.tree_map(
            lambda e: aggregation_fn(e, receivers, sum_n_node), edge_set.features)
    ......
    new_features = node_fn(
        node_set.features, sent_features, received_features, global_features)
    ......

In the code above, there are two types of nodes in update_node_fn: “grid_nodes” and “mesh_nodes”. Both of their node_fn are defined the same way as edge_fn: MLP + layer norm.

After all node and edge features have been updated, a residual connection adds each update to its previous value, and the results are reassigned to the node and edge variables. (Lam et al., 2022, p. 28)
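
The whole encoder step (edge update, aggregation at the receivers, node update, residual) can be condensed into a few lines. The following is a simplified sketch rather than the package's code: features are single floats instead of latent vectors, `edge_fn` and `node_fn` are toy functions standing in for the MLP + layer norm, and only the receiver nodes are updated (the sender-side update, which involves no aggregation, is left out).

```python
from typing import Callable, List


def interaction_step(
    sender_feats: List[float],
    receiver_feats: List[float],
    edge_feats: List[float],
    senders: List[int],
    receivers: List[int],
    edge_fn: Callable[[float, float, float], float],
    node_fn: Callable[[float, float], float],
):
    """One edge-then-node update with sum aggregation and residual connections."""
    # 1) Edge update: each edge sees its own features and both endpoint nodes.
    edge_updates = [
        edge_fn(e, sender_feats[s], receiver_feats[r])
        for e, s, r in zip(edge_feats, senders, receivers)
    ]
    # 2) Aggregate the updated edges per receiver node (a segment sum).
    aggregated = [0.0] * len(receiver_feats)
    for update, r in zip(edge_updates, receivers):
        aggregated[r] += update
    # 3) Node update from the node's own features and its aggregated messages.
    node_updates = [node_fn(n, a) for n, a in zip(receiver_feats, aggregated)]
    # 4) Residual connections: add the updates to the previous values.
    new_edges = [e + u for e, u in zip(edge_feats, edge_updates)]
    new_receivers = [n + u for n, u in zip(receiver_feats, node_updates)]
    return new_receivers, new_edges


# Toy graph: grid nodes 0 and 1 send to mesh node 0, grid node 2 to mesh node 1.
new_mesh, new_edges = interaction_step(
    sender_feats=[1.0, 2.0, 3.0],
    receiver_feats=[10.0, 20.0],
    edge_feats=[0.0, 0.0, 0.0],
    senders=[0, 1, 2],
    receivers=[0, 0, 1],
    edge_fn=lambda e, s, r: e + s,       # toy stand-in for the edge MLP
    node_fn=lambda n, agg: agg,          # toy stand-in for the node MLP
)
```

Note that the aggregation uses the raw edge updates; the residual is only added afterwards, when the new values are written back.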

Initiate Mesh subgraph

The structural node features don’t need to be recomputed, because they are already part of the latent state from the previous step. However, the structural features for the Mesh edges must be precomputed, because this is the first time this particular set of edges appears.

class GraphCast():
    def _init_mesh_graph(self) -> typed_graph.TypedGraph:
        merged_mesh = icosahedral_mesh.merge_meshes(self._meshes)
        senders, receivers = icosahedral_mesh.faces_to_edges(merged_mesh.faces)
        ......
        node_features, edge_features = model_utils.get_graph_spatial_features(......)
        ......
        mesh_node_set = typed_graph.NodeSet(
            n_node=n_mesh_node, features=node_features)
        edge_set = typed_graph.EdgeSet(
            n_edge=n_edge,
            indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
            features=edge_features)
        nodes = {"mesh_nodes": mesh_node_set}
        edges = {
            typed_graph.EdgeSetKey("mesh", ("mesh_nodes", "mesh_nodes")): edge_set
        }
        mesh_graph = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=np.array([1]), features=()),
            nodes=nodes,
            edges=edges)

        return mesh_graph
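
The call to `faces_to_edges` turns triangles into sender/receiver index lists. One common way to do this, sketched below as an assumption rather than a copy of the package's function, is to emit the three directed edges around every face; when faces are consistently oriented, each undirected edge is shared by two faces that traverse it in opposite directions, so both directions appear exactly once.

```python
from typing import List, Tuple


def faces_to_directed_edges(faces: List[Tuple[int, int, int]]):
    """Emits the three directed edges a->b, b->c, c->a of every triangle."""
    senders, receivers = [], []
    for a, b, c in faces:
        senders.extend([a, b, c])
        receivers.extend([b, c, a])
    return senders, receivers


# A tetrahedron with consistently oriented faces (toy stand-in for the mesh).
tetra_faces = [(0, 1, 2), (0, 3, 1), (0, 2, 3), (1, 3, 2)]
senders, receivers = faces_to_directed_edges(tetra_faces)
edge_pairs = set(zip(senders, receivers))
```

The tetrahedron has 6 undirected edges, and the result contains 12 directed ones, two per undirected edge.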

Then the latent mesh node features from the previous step (the Grid2Mesh GNN) are added.

class GraphCast():
    def _run_mesh_gnn(self, latent_mesh_nodes: chex.Array) -> chex.Array:
        mesh_graph = self._mesh_graph_structure
        ......
        mesh_edges_key = mesh_graph.edge_key_by_name("mesh")
        edges = mesh_graph.edges[mesh_edges_key]
        ......
        new_edges = edges._replace(
            features=_add_batch_second_axis(
                edges.features.astype(latent_mesh_nodes.dtype), batch_size))
        nodes = mesh_graph.nodes["mesh_nodes"]
        nodes = nodes._replace(features=latent_mesh_nodes)
        input_graph = mesh_graph._replace(
            edges={mesh_edges_key: new_edges}, nodes={"mesh_nodes": nodes})
        ......

Mesh GNN: processor

The processor is a deep GNN that operates on the Mesh subgraph, which contains only the mesh nodes and the mesh edges.

Figure m_gnn

The number of layers is configurable: (Figure m_gnn only illustrates a single layer)

class DeepTypedGraphNet():
    def __init__():
        ......
        self._num_message_passing_steps = num_message_passing_steps
        ......

    def _networks_builder():
        ......
        for step_i in range(self._num_message_passing_steps):
            self._processor_networks.append(
                typed_graph_net.InteractionNetwork(......)
            )
        ......

Here is how a single layer of the Mesh GNN works. First it updates each of the mesh edges using information of the adjacent nodes.

def GraphNetwork():
    def _apply_graph_net():
        ......
        for edge_set_name, edge_fn in update_edge_fn.items():
            edge_set_key = graph.edge_key_by_name(edge_set_name)
            updated_edges[edge_set_key] = _edge_update(updated_graph, edge_fn, edge_set_key)
        ......

def _edge_update(graph, edge_fn, edge_set_key):
    ......
    sent_attributes = tree.tree_map(lambda n: n[senders], sender_nodes.features)
    received_attributes = tree.tree_map(lambda n: n[receivers], receiver_nodes.features)
    ......
    new_features = edge_fn(
        edge_set.features, sent_attributes, received_attributes,
        global_features)
    return edge_set._replace(features=new_features)

In the code above, there is only one type of edge in update_edge_fn: “mesh”. Its edge_fn is again MLP + layer norm.

Then it updates each of the mesh nodes, aggregating information from all of the edges arriving at that mesh node.

def GraphNetwork():
    def _apply_graph_net():
        ......
        for node_set_key, node_fn in update_node_fn.items():
            updated_nodes[node_set_key] = _node_update(
                updated_graph, node_fn, node_set_key, aggregate_edges_for_nodes_fn)
        ......

def _node_update():
    ......
    for edge_set_key, edge_set in graph.edges.items():
        ......
        sent_features[edge_set_key.name] = tree.tree_map(
            lambda e: aggregation_fn(e, senders, sum_n_node), edge_set.features)
    ......
    for edge_set_key, edge_set in graph.edges.items():
        ......
        received_features[edge_set_key.name] = tree.tree_map(
            lambda e: aggregation_fn(e, receivers, sum_n_node), edge_set.features)
    ......
    new_features = node_fn(
        node_set.features, sent_features, received_features, global_features)
    ......

In the code above, there is only one type of node in update_node_fn: “mesh_nodes”. Its node_fn is again MLP + layer norm.

After updating both, the representations are updated with a residual connection and reassigned to the input variables. (Lam et al., 2022, p. 29)

Initiate Mesh2Grid subgraph

The edges in this graph are created according to which mesh triangle contains each grid node. The structural features of the nodes don’t need to be recomputed, because they are already part of the latent state from the previous step. However, the structural features for the Mesh2Grid edges must be precomputed, because this is the first time this particular set of edges appears.

def _init_mesh2grid_graph(self) -> typed_graph.TypedGraph:
    (grid_indices,
     mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices(......)
    ......

    (senders_node_features, receivers_node_features, edge_features) = (
        model_utils.get_bipartite_graph_spatial_features(......))
    ......
    # Here the mesh nodes are the senders and the grid nodes the receivers.
    grid_node_set = typed_graph.NodeSet(
        n_node=n_grid_node, features=receivers_node_features)
    mesh_node_set = typed_graph.NodeSet(
        n_node=n_mesh_node, features=senders_node_features)
    edge_set = typed_graph.EdgeSet(
        n_edge=n_edge,
        indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
        features=edge_features)
    nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
    edges = {
        typed_graph.EdgeSetKey("mesh2grid", ("mesh_nodes", "grid_nodes")): edge_set
    }
    mesh2grid_graph = typed_graph.TypedGraph(
        context=typed_graph.Context(n_graph=np.array([1]), features=()),
        nodes=nodes, edges=edges)
    return mesh2grid_graph
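
Once the containing triangle of every grid node is known, building the edge lists is mechanical: each grid node receives one edge from each of the three vertices of its triangle. The sketch below assumes the point-in-triangle lookup has already been done; `mesh2grid_edges` and the `containing_face` list are hypothetical and only illustrate the bookkeeping.

```python
from typing import List, Tuple


def mesh2grid_edges(containing_face: List[int],
                    faces: List[Tuple[int, int, int]]):
    """Links each grid node to the 3 vertices of the mesh face containing it."""
    mesh_indices, grid_indices = [], []
    for grid_index, face_index in enumerate(containing_face):
        for mesh_index in faces[face_index]:
            mesh_indices.append(mesh_index)   # sender (mesh node)
            grid_indices.append(grid_index)   # receiver (grid node)
    return mesh_indices, grid_indices


faces = [(0, 1, 2), (0, 2, 3)]
# Hypothetical lookup result: grid nodes 0 and 1 lie in face 0, node 2 in face 1.
containing_face = [0, 0, 1]
mesh_senders, grid_receivers = mesh2grid_edges(containing_face, faces)
```

Every grid node ends up with exactly three incoming edges, so a grid of 1,038,240 nodes would yield 3 × 1,038,240 = 3,114,720 Mesh2Grid edges.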

Then the features of latent_grid_nodes from the previous step (the Grid2Mesh GNN) are added to the Mesh2Grid subgraph.

def _run_mesh2grid_gnn():
    ......
    mesh2grid_graph = self._mesh2grid_graph_structure
    ......
    mesh_nodes = mesh2grid_graph.nodes["mesh_nodes"]
    grid_nodes = mesh2grid_graph.nodes["grid_nodes"]
    new_mesh_nodes = mesh_nodes._replace(features=updated_latent_mesh_nodes)
    new_grid_nodes = grid_nodes._replace(features=latent_grid_nodes)
    mesh2grid_key = mesh2grid_graph.edge_key_by_name("mesh2grid")
    edges = mesh2grid_graph.edges[mesh2grid_key]
    new_edges = edges._replace(
        features=_add_batch_second_axis(
            edges.features.astype(latent_grid_nodes.dtype), batch_size))
    input_graph = mesh2grid_graph._replace(
        edges={mesh2grid_key: new_edges},
        nodes={
            "mesh_nodes": new_mesh_nodes,
            "grid_nodes": new_grid_nodes
        })
    ......

Mesh2Grid GNN: decoder

The role of the decoder is to bring back information to the grid, and extract an output.

Figure m2g_gnn

The Mesh2Grid GNN performs a single message-passing step over the Mesh2Grid subgraph (grid nodes, mesh nodes, mesh2grid edges).

The GNN first updates each of the mesh2grid edges using information from the adjacent nodes.

def GraphNetwork():
    def _apply_graph_net():
        ......
        for edge_set_name, edge_fn in update_edge_fn.items():
            edge_set_key = graph.edge_key_by_name(edge_set_name)
            updated_edges[edge_set_key] = _edge_update(updated_graph, edge_fn, edge_set_key)
        ......

def _edge_update():
    ......
    sent_attributes = tree.tree_map(lambda n: n[senders], sender_nodes.features)
    received_attributes = tree.tree_map(lambda n: n[receivers], receiver_nodes.features)
    ......
    new_features = edge_fn(
        edge_set.features, sent_attributes, received_attributes,
        global_features)
    return edge_set._replace(features=new_features)

In the code above, there is only one type of edge in update_edge_fn: “mesh2grid”, and its edge_fn is MLP + layer norm.

Then it updates each of the grid nodes, aggregating information from all of the edges arriving at that grid node. (The mesh nodes are not updated, as they play no role from this point on.)

def GraphNetwork():
    def _apply_graph_net():
        ......
        for node_set_key, node_fn in update_node_fn.items():
            updated_nodes[node_set_key] = _node_update(
                updated_graph, node_fn, node_set_key, aggregate_edges_for_nodes_fn)
        ......

def _node_update():
    ......
    for edge_set_key, edge_set in graph.edges.items():
        ......
        sent_features[edge_set_key.name] = tree.tree_map(
            lambda e: aggregation_fn(e, senders, sum_n_node), edge_set.features)
    ......
    for edge_set_key, edge_set in graph.edges.items():
        ......
        received_features[edge_set_key.name] = tree.tree_map(
            lambda e: aggregation_fn(e, receivers, sum_n_node), edge_set.features)
    ......
    new_features = node_fn(
        node_set.features, sent_features, received_features, global_features)
    ......

Here again a residual connection is added, and the variables are reassigned, this time only for the grid nodes.

Finally, another MLP produces the prediction for each grid node; its output contains all predicted variables for that grid node. (Lam et al., 2022, p. 30)

class GraphCast():
    def __init__():
        ......
        self._mesh2grid_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            node_output_size=dict(grid_nodes=num_outputs),
            ......
        )

class DeepTypedGraphNet():
    def _networks_builder():
        ......
        output_kwargs = dict(
            ......
            embed_node_fn=_build_update_fns_for_node_types(
                build_mlp, graph_template, "decoder_nodes_", self._node_output_size)
            if self._node_output_size else None
        )
        self._output_network = typed_graph_net.GraphMapFeatures(**output_kwargs)

Convert grid node features to prediction

Finally the channels in the grid node features are restored to proper variables.

class GraphCast():
    def _grid_node_outputs_to_prediction():
        ......
        grid_shape = (self._grid_lat.shape[0], self._grid_lon.shape[0])
        grid_outputs_lat_lon_leading = grid_node_outputs.reshape(
            grid_shape + grid_node_outputs.shape[1:])
        dims = ("lat", "lon", "batch", "channels")
        grid_xarray_lat_lon_leading = xarray_jax.DataArray(
            data=grid_outputs_lat_lon_leading, dims=dims)
        grid_xarray = model_utils.restore_leading_axes(grid_xarray_lat_lon_leading)
        return model_utils.stacked_to_dataset(grid_xarray.variable, targets_template)
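
Conceptually, the last call is the inverse of the channel stacking done on the inputs: a flat channel axis is cut back into named variables. The sketch below shows that idea for a single grid node; `unstack_channels` and the channel layout are hypothetical, and the actual ordering of channels is determined by the package, not by this example.

```python
from typing import Dict, List, Tuple


def unstack_channels(channels: List[float],
                     layout: List[Tuple[str, int]]) -> Dict[str, List[float]]:
    """Splits a flat channel vector into named variables, in layout order."""
    variables, start = {}, 0
    for name, size in layout:
        variables[name] = channels[start:start + size]
        start += size
    if start != len(channels):
        raise ValueError("layout does not cover all channels")
    return variables


# Hypothetical layout: one surface variable and one variable on 3 levels.
layout = [("2m_temperature", 1), ("temperature", 3)]
prediction = unstack_channels([280.0, 250.0, 260.0, 270.0], layout)
```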

The prediction array (figure below) has exactly the same structure as the targets.

Summary

In the previous sections we explored the construction and logic of the three GNNs (encoder, processor, decoder) in the GraphCast model, detailing how graphs flow through these networks and how messages are passed between nodes and edges. This article is only an overview; many aspects of the model remain to be studied.

To use the GraphCast model in an application, the following code initializes the parameters and runs a forward pass to make predictions.

@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    predictor = graphcast.GraphCast(model_config, task_config)
    return predictor(inputs, targets_template, forcings)

params, state = run_forward.init(
    rng=jax.random.PRNGKey(0),
    model_config=model_config,
    task_config=task_config,
    inputs=train_inputs,
    targets_template=train_targets * np.nan,
    forcings=train_forcings)

predictions, state = run_forward.apply(
    rng=jax.random.PRNGKey(0),
    params=params,
    state=state,
    model_config=model_config,
    task_config=task_config,
    inputs=train_inputs,
    targets_template=train_targets * np.nan,
    forcings=train_forcings)
