Satellite image super-resolution with SR3

Soumya Kanta Dash
Published in GeoAI
Dec 26, 2023

High-resolution satellite imagery is often desirable for interpretation, feature extraction, analysis, visualization, and similar tasks, but it is quite expensive to procure. Advances in computer vision and deep learning have enabled significant progress in super-resolution, which aims to algorithmically increase image resolution. This blog demonstrates how the ‘SR3’ backbone in the SuperResolution class, available in the arcgis.learn module of the ArcGIS API for Python, can be used to increase satellite image resolution, adding texture and detail to upscaled low-resolution satellite imagery.

SR3 adapts a denoising diffusion probabilistic model (DDPM) to conditionally generate images and performs super-resolution through a denoising process. Made popular by text-to-image models like Stable Diffusion, DALL·E 2, and Midjourney, diffusion models excel at generating images by modelling complex data distributions probabilistically, employing a hierarchical denoising process, and enabling fine-grained control through conditioning.

In this blog, we will focus on using a modified version of the SR3 model, originally developed by Saharia et al. (2022) and available as a backbone in the SuperResolution class, for the task of increasing the spatial resolution of satellite imagery.

SR3 workflow in arcgis.learn

Recent advances in generative modelling have introduced diffusion models, which have demonstrated better performance than earlier approaches. `SR3`, or `Super-Resolution via Repeated Refinement`, adapts the denoising diffusion probabilistic model for conditional image generation and performs super-resolution through a stochastic denoising process. During inferencing, a UNet model trained to denoise at various noise levels is used to iteratively refine a noisy output. Learning by denoising consists of two processes, each of which is a Markov chain:

1. The forward process, or noising — the forward Markovian diffusion process gradually adds Gaussian noise to a high-resolution (target) image over T iterations.

2. The reverse process, or denoising — the reverse inference process iteratively denoises the target image conditioned on a source (low-resolution) image. The process is shown in the figure below.

Figure 1. Diffusion models smoothly perturb data by adding noise, then reverse this process to generate new data from noise. Each denoising step in the reverse process typically requires estimating the score function.

We learn the reverse chain using a neural denoising model, a UNet, that takes a source image and a noisy target image as input and aims to recover, or estimate, the noiseless target image. The noisy target image is drawn from the distribution of noisy images generated at different steps of the forward diffusion process.
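
To make the forward process concrete, below is a minimal, illustrative sketch (not the arcgis.learn implementation) of noising a clean target image directly to timestep t, assuming a linear beta schedule in PyTorch:

import torch

T = 1500                                         # number of diffusion timesteps
betas = torch.linspace(1e-05, 1e-02, T)          # linear noise schedule
alphas_cumprod = torch.cumprod(1.0 - betas, 0)   # cumulative product of (1 - beta)

def add_noise(x0, t):
    # Noise the clean target image x0 to timestep t in a single step.
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t]
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
    return x_t, noise  # the UNet learns to predict this noise, conditioned on the low-resolution source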

Refer to Image super-resolution via iterative refinement for details on the model architecture.

Required imports

from arcgis.learn import prepare_data, SuperResolution

Export and prepare training data

The Export Training Data for Deep Learning tool, with the `Export Tiles` metadata format, is used to export training data for the model. There are two cases:

  • If a pair of high-resolution (hr) imagery and its corresponding low-resolution (lr) imagery is available, the hr imagery is provided as the input imagery and the lr imagery as the additional raster.
  • If only high-resolution (hr) imagery is available, it is provided as the input imagery, and low-resolution samples are generated during `prepare_data` based on the provided downsampling factor.

Refer to this sample notebook for more details on data export for super-resolution.

data = prepare_data(path, batch_size=8, downsampling_factor=4) 

Visualize training samples

data.show_batch(rows=3)

Train SR3 model

The following sections describe how to train the model.

Load SR3 model architecture

Initialize the SR3 model object as shown below.

model = SuperResolution(data, backbone="SR3",
                        norm_groups=16,
                        res_blocks=3,
                        n_timestep=1500)

The following model parameters can be passed using keyword arguments:

  • inner_channel — Optional integer. Dimension of the first UNet layer. Default set to 64.
  • norm_groups — Optional integer. Number of groups for group normalization. Default set to 32.
  • channel_mults — Optional list. Depth multipliers for subsequent resolutions in the UNet. Default set to [1, 2, 4, 4, 8, 8].
  • attn_res — Optional integer. Number of attentions in residual blocks. Default set to 16.
  • res_blocks — Optional integer. Number of residual blocks. Default set to 3.
  • dropout — Optional float. Dropout rate. Default set to 0.
  • schedule — Optional string. Type of noise schedule. Available options are ‘linear’, ‘warmup10’, ‘warmup50’, ‘const’, ‘jsd’, and ‘cosine’. Default set to ‘linear’.
  • n_timestep — Optional integer. Number of diffusion timesteps. Default set to 1500.
  • linear_start — Optional float. Schedule start value. Default set to 1e-05.
  • linear_end — Optional float. Schedule end value. Default set to 1e-02.

The default values are set for 64-to-256 upscaling tasks, as provided in the paper, but the keyword arguments may need adjustment based on the dataset and the results obtained. If an out-of-memory issue occurs, decrease the number of residual blocks. Optimal results are generally obtained when the model is trained for at least 200–300 epochs.
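
For reference, the keyword arguments above could be passed together as in the following hedged sketch; the values are only illustrative, not tuned recommendations:

model = SuperResolution(data, backbone="SR3",
                        inner_channel=64,                  # dimension of the first UNet layer
                        norm_groups=16,
                        channel_mults=[1, 2, 4, 4, 8, 8],
                        attn_res=16,
                        res_blocks=3,
                        dropout=0,
                        schedule="linear",
                        n_timestep=1500,
                        linear_start=1e-05,
                        linear_end=1e-02)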

Find optimal learning rate

model.lr_find()

It is observed that the learning rate has a great impact on model performance. Learning rates ranging from `1e-05` to `1e-04` generally work for a variety of datasets. The learning rate suggested by the learning rate finder can also be used.

Fit the model

To train the model, use the fit() method, which takes the following arguments:

  • epochs — Number of cycles of training on the data.
  • lr — Learning rate to be used for training the model.
  • tensorboard — Monitors model performance while training; helps in deciding optimal hyperparameters.

model.fit(300, lr=2.5e-05, tensorboard=True)

Here, the SR3 model is trained for 300 epochs. The initial 10 epochs are shown in the figure above.

Visualize results

The show_results() method can be used to visualize the results of the trained model. It accepts the following arguments (a short example follows the list):

  • sampling_type — Type of sampling. Two types can be used: ‘ddim’ and ‘ddpm’. ‘ddim’ is set as the default, as it generates results in far fewer timesteps than ‘ddpm’.
  • schedule — Defaults to the schedule the model was trained on.
  • n_timestep — Defaults to 200. It can be increased or decreased based on the quality of the generations; higher values may take significantly longer to generate.
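
For example, a hedged call overriding the defaults could look like this (the values are illustrative):

model.show_results(sampling_type='ddim', n_timestep=200)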

As you can see, the model has learnt to increase the resolution of the imagery by adding realistic details and texture!

Save the model

As metrics and losses look satisfactory, we can save the trained model as a deep learning package (.dlpk format) for large-scale inferencing. The deep learning package format is the standard format used to deploy deep learning models on the ArcGIS platform.

Use the save() method to save the trained model. By default, it will be saved to the model’s subfolder in the training data folder.

model.save('sr3_trained_model')

Model inferencing on a larger extent

The `Classify Pixels Using Deep Learning` tool in ArcGIS Pro is used for inferencing on low-resolution imagery using the saved model.

During inference, the model uses the following model arguments (a hedged scripting example follows the list):

  • sampling_type — Type of sampling. Two types can be used: ‘ddim’ and ‘ddpm’. ‘ddim’ is set as the default, as it generates results in far fewer timesteps than ‘ddpm’.
  • schedule — Defaults to the schedule the model was trained on.
  • n_timestep — Defaults to 200. It can be increased or decreased based on the quality of the generations; higher values may take significantly longer to generate.
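
As a hedged sketch, this step can also be scripted with arcpy, assuming the Image Analyst extension is available; all paths and argument values below are hypothetical, and the tool can equally be run from the Geoprocessing pane in ArcGIS Pro:

import arcpy
arcpy.CheckOutExtension("ImageAnalyst")

# Apply the saved .dlpk to a low-resolution raster; 'arguments' carries the
# model arguments described above.
out_raster = arcpy.ia.ClassifyPixelsUsingDeepLearning(
    in_raster=r"C:\data\low_res_imagery.tif",
    in_model_definition=r"C:\models\sr3_trained_model\sr3_trained_model.dlpk",
    arguments="sampling_type ddim;n_timestep 200")
out_raster.save(r"C:\data\sr3_super_resolved.tif")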

Shown below is the predicted raster obtained by applying the trained super-resolution model to low-resolution imagery:

Conclusion

In this post, you have seen how to train the SR3 model available in the arcgis.learn module of the `ArcGIS API for Python` to increase the spatial resolution of satellite imagery.

References

See the following to learn more:

  1. Saharia, Chitwan, Jonathan Ho, William Chan, Tim Salimans, David J. Fleet, and Mohammad Norouzi. “Image super-resolution via iterative refinement.” IEEE Transactions on Pattern Analysis and Machine Intelligence 45, no. 4 (2022): 4713–4726.
  2. Data preparation methods
