Enable LoRA weights with Stable Diffusion Controlnet Pipeline

Authors: Zhen Zhao (Fiona), Kunda Xu

OpenVINO™ toolkit
Oct 25, 2023

Low-Rank Adaptation (LoRA) is a technique for efficiently fine-tuning diffusion models and Large Language Models (LLMs). Instead of updating all weights, LoRA freezes the original weights and trains a small low-rank update for selected layers. For Stable Diffusion fine-tuning, LoRA can be applied to the cross-attention layers that relate the image representations to the prompts that describe them. You can refer to the Hugging Face diffusers documentation to understand the basic concept and method for model fine-tuning: https://huggingface.co/docs/diffusers/training/lora
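To make the low-rank idea concrete, here is a minimal NumPy sketch of a LoRA update on a single weight matrix. The shapes (a 320×768 projection, rank 4) and the helper names are illustrative assumptions, not values or functions from the demo. The frozen weight W0 stays untouched; only the small factors lora_down (rank×in) and lora_up (out×rank) are trained, and the merged weight is W0 + alpha * lora_up @ lora_down.

```python
import numpy as np


def lora_delta(lora_up: np.ndarray, lora_down: np.ndarray, alpha: float) -> np.ndarray:
    """Return the low-rank weight update alpha * (lora_up @ lora_down)."""
    return alpha * (lora_up @ lora_down)


def merge_lora(w0: np.ndarray, lora_up: np.ndarray, lora_down: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    """Merge a LoRA update into a frozen weight: W = W0 + alpha * up @ down."""
    return w0 + lora_delta(lora_up, lora_down, alpha)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    out_features, in_features, rank = 320, 768, 4  # illustrative shapes only
    w0 = rng.standard_normal((out_features, in_features)).astype(np.float32)
    lora_down = rng.standard_normal((rank, in_features)).astype(np.float32)
    lora_up = rng.standard_normal((out_features, rank)).astype(np.float32)

    w = merge_lora(w0, lora_up, lora_down, alpha=0.75)
    print("merged shape:", w.shape)
    print("rank of the update:", np.linalg.matrix_rank(lora_delta(lora_up, lora_down, 0.75)))
    # full params: 245760, LoRA params: 4352
    print("full params:", w0.size, "LoRA params:", lora_down.size + lora_up.size)
```

The point of the parameter count is that the trainable LoRA factors are a small fraction of the full matrix, which is why LoRA files are compact enough to swap styles cheaply.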

In this blog, we introduce how to build a Stable Diffusion + ControlNet pipeline with OpenVINO™ optimization, and how to enable LoRA weights for the Unet model of Stable Diffusion to generate images in different styles. The demo source code is based on: https://github.com/FionaZZ92/OpenVINO_sample/tree/master/SD_controlnet

Stable Diffusion ControlNet Pipeline

Step 1: Environment preparation

First, prepare your development environment as shown below. We recommend downloading the models from Hugging Face in advance for a better runtime experience. In this case, we use the ControlNet model for the Canny edge image task.

$ mkdir ControlNet && cd ControlNet
$ wget https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth

$ conda create -n SD python==3.10
$ conda activate SD

$ pip install opencv-contrib-python
$ pip install -q "diffusers>=0.14.0" "git+https://github.com/huggingface/accelerate.git" controlnet-aux gradio
$ pip install openvino openvino-dev onnx
$ pip install torch==1.13.1 #important

$ git lfs install
$ git clone https://huggingface.co/lllyasviel/sd-controlnet-canny
$ git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
$ git clone https://huggingface.co/openai/clip-vit-large-patch14

$ wget https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/vermeer_512x512.png

* Please note that diffusers uses `torch.nn.functional.scaled_dot_product_attention` when the installed torch version is >= 2.0, and ONNX export does not support converting the `aten::scaled_dot_product_attention` op. To avoid errors during model conversion with `torch.onnx.export`, make sure you are using torch==1.13.1.
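A quick pre-flight check can catch the wrong torch version before a long conversion run. The sketch below only parses the major version; the function name is hypothetical and not part of the demo, and it assumes the >= 2.0 boundary described in the note above.

```python
def torch_ok_for_onnx_export(torch_version: str) -> bool:
    """Return True for torch < 2.0, i.e. before diffusers switches to
    scaled_dot_product_attention, which torch.onnx.export could not convert here."""
    major = int(torch_version.split(".")[0])  # "1.13.1+cpu" -> 1
    return major < 2


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed")
    else:
        if torch_ok_for_onnx_export(torch.__version__):
            print(f"torch {torch.__version__}: OK for ONNX export")
        else:
            print(f"torch {torch.__version__} >= 2.0; install torch==1.13.1 before running get_model.py")
```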

Step 2: Model Conversion

The demo provides two programs. To convert the models to OpenVINO™ IR, use `get_model.py`. You can check the options of this script with:

$ python get_model.py -h
usage: get_model.py [-h] -b BATCH -sd SD_WEIGHTS [-lt LORA_TYPE] [-lw LORA_WEIGHTS]

Options:
-h, --help Show this help message and exit.
-b BATCH, --batch BATCH
Required. batch_size for solving single/multiple prompt->image generation.
-sd SD_WEIGHTS, --sd_weights SD_WEIGHTS
Specify the path of the stable diffusion model
-lt LORA_TYPE, --lora_type LORA_TYPE
Specify the type of lora weights, you can choose "safetensors" or "bin"
-lw LORA_WEIGHTS, --lora_weights LORA_WEIGHTS
Add lora weights to Stable diffusion.

In this case, we use a batch size greater than one to generate multiple images. Image generation applications commonly have two notions of batch:

  1. `batch_size`: Specifies the number of input prompts (or negative prompts). This is used to generate N images from N prompts.
  2. `num_images_per_prompt`: Specify the number of images that each prompt generates. This method is used to generate M images with 1 prompt.

Thus, you can combine these two attributes in diffusers to generate N*M images from N prompts, incrementing the random seed for each additional image of the same prompt. For example, with a base seed of 42 and N=2, M=2, the four images are generated as follows:

  • Prompt 1, image 1: prompt_list[0], seed=42
  • Prompt 1, image 2: prompt_list[0], seed=43
  • Prompt 2, image 1: prompt_list[1], seed=42
  • Prompt 2, image 2: prompt_list[1], seed=43
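The N×M seed scheme listed above can be written as a small helper. This is only an illustration of the scheme as described here; `generation_schedule` is a hypothetical name, and the demo or diffusers may seed their random generators differently internally.

```python
def generation_schedule(prompts, num_images_per_prompt, base_seed):
    """List (prompt, seed) pairs: each prompt gets seeds base_seed .. base_seed + M - 1."""
    schedule = []
    for prompt in prompts:                        # N prompts (batch_size)
        for m in range(num_images_per_prompt):    # M images per prompt
            schedule.append((prompt, base_seed + m))
    return schedule


if __name__ == "__main__":
    for prompt, seed in generation_schedule(["prompt_list[0]", "prompt_list[1]"], 2, 42):
        print(prompt, seed)
```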

Here, we use N=2, M=1 as a quick demonstration, so we pass `--batch 2` (`-b 2`). By default, the script generates a static-shape model. If you want to use other values of N and M, please specify `--dynamic`.

$ python get_model.py -b 2 -sd stable-diffusion-v1-5/
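A static-shape model fixes the batch dimension at export time. As a rough sketch, assuming the usual Stable Diffusion v1.5 layout (4 latent channels and a VAE downsampling factor of 8), a batch of 2 at 512×512 corresponds to latents of shape (2, 4, 64, 64). The helper below is hypothetical; the `guidance` flag reflects the assumption that, when a pipeline applies classifier-free guidance, the Unet sees the latents duplicated (conditional and unconditional), doubling its batch.

```python
def latent_shape(batch_size, height=512, width=512, latent_channels=4,
                 vae_scale_factor=8, guidance=False):
    """Return the (batch, channels, h, w) latent shape the Unet would see (SD v1.5 assumptions)."""
    if height % vae_scale_factor or width % vae_scale_factor:
        raise ValueError("height and width must be multiples of vae_scale_factor")
    unet_batch = batch_size * 2 if guidance else batch_size  # CFG duplicates the latents
    return (unet_batch, latent_channels, height // vae_scale_factor, width // vae_scale_factor)


if __name__ == "__main__":
    print("latents for -b 2:", latent_shape(2))
    print("Unet input with guidance:", latent_shape(2, guidance=True))
```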

Please check your current directory and make sure the following models have been generated. The remaining ONNX files can be deleted to save space.

  • controlnet-canny.<xml|bin>
  • text_encoder.<xml|bin>
  • unet_controlnet.<xml|bin>
  • vae_decoder.<xml|bin>
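Before running the pipeline, a short check can confirm that all four IR models (each an .xml topology file plus a .bin weights file) are present. The function name is hypothetical, and it assumes the IR files sit directly in the given directory, as in the list above.

```python
from pathlib import Path

# Model names as listed above; each needs an .xml (topology) and a .bin (weights) file.
REQUIRED_MODELS = ("controlnet-canny", "text_encoder", "unet_controlnet", "vae_decoder")


def missing_ir_files(directory="."):
    """Return the names of required IR files that are not found in `directory`."""
    missing = []
    for name in REQUIRED_MODELS:
        for suffix in (".xml", ".bin"):
            path = Path(directory) / f"{name}{suffix}"
            if not path.is_file():
                missing.append(path.name)
    return missing


if __name__ == "__main__":
    missing = missing_ir_files(".")
    print("All IR models found." if not missing else f"Missing: {', '.join(missing)}")
```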

* If ONNX or IR models already exist in your local path, the script will not regenerate them. If you updated the PyTorch model or want to generate a model with a different shape, please remember to delete the existing ONNX and IR models first.

Step 3: Runtime pipeline test

The provided demo program `run_pipe.py` manually builds the StableDiffusionControlNet pipeline, following the source of `diffusers.StableDiffusionControlNetPipeline`:

https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet.py

The difference is that we simplify the pipeline to the inference of four models with the OpenVINO™ Runtime API, so that model inference can be accelerated on Intel® CPU and GPU platforms.
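As a rough sketch of this structure, the helper below reads the four IR files from Step 2 and compiles each one for a single device with the OpenVINO™ Runtime `Core.read_model()` and `Core.compile_model()` calls. The helper name and the dictionary layout are assumptions, not the actual structure of `run_pipe.py`; the `core` object is passed in so the function can be exercised without real models, and the `openvino.runtime` import path assumes an OpenVINO 2023.x release.

```python
from pathlib import Path

MODEL_NAMES = ("text_encoder", "controlnet-canny", "unet_controlnet", "vae_decoder")


def compile_pipeline_models(core, model_dir=".", device="CPU"):
    """Read each IR model and compile it for `device`; returns {name: compiled_model}."""
    compiled = {}
    for name in MODEL_NAMES:
        model = core.read_model(str(Path(model_dir) / f"{name}.xml"))
        compiled[name] = core.compile_model(model, device)  # e.g. "CPU" or "GPU"
    return compiled


if __name__ == "__main__":
    try:
        from openvino.runtime import Core
    except ImportError:
        print("OpenVINO is not installed; skipping compilation.")
    else:
        if all((Path(".") / f"{n}.xml").is_file() for n in MODEL_NAMES):
            models = compile_pipeline_models(Core(), ".", "CPU")
            print("Compiled:", ", ".join(models))
        else:
            print("IR models not found; run get_model.py first.")
```

Compiling every model once up front keeps the denoising loop itself free of compilation cost; the loop then only calls the compiled models.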

By default, the pipeline runs 20 denoising steps with an image size of 512×512 and seed 42, and the input image and prompt are for “Girl with Pearl Earring”. You can adjust these attributes to customize your pipeline for testing.

$ python run_pipe.py

With batch_size=2, the generated images look like the following:

Generated images with original Stable Diffusion v1.5 + canny ControlNet

Enable LoRA weights for Stable Diffusion

LoRA weights commonly come in two formats: `pytorch_lora_weights.bin` and safetensors. Here, we introduce a method for each of them.

The main idea of enabling LoRA weights is to merge them into the original Unet model of Stable Diffusion and then export an IR model of the Unet that retains the LoRA weights.

There are various LoRA models on https://civitai.com/tag/lora. Here we chose some public models on Hugging Face as examples; you can replace them with your own.

Step 4–1: Enable LoRA by pytorch_lora_weights.bin

This step adds LoRA weights to the Unet model of Stable Diffusion with the `pipe.unet.load_attn_procs(…)` function, which loads the LoRA weights into the attention layers of the Unet.

$ git clone https://huggingface.co/TheUpperCaseGuy/finetune-lora-stable-diffusion
$ rm unet_controlnet.* unet_controlnet/unet_controlnet.onnx
$ python get_model.py -b 2 -sd stable-diffusion-v1-5/ -lt bin -lw finetune-lora-stable-diffusion/

* Remember to delete the existing Unet model to generate the new IR with LoRA weights.
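If you switch LoRA weights often, the cleanup can be scripted. This hypothetical helper mirrors the `rm unet_controlnet.* unet_controlnet/unet_controlnet.onnx` command above, assuming that file layout; it deletes only the IR pair and the ONNX file, so any other files in that folder are left alone.

```python
from pathlib import Path


def remove_stale_model(name="unet_controlnet", directory="."):
    """Delete <name>.xml, <name>.bin and <name>/<name>.onnx so get_model.py regenerates them.

    Returns the list of paths that were actually removed.
    """
    base = Path(directory)
    candidates = [base / f"{name}.xml", base / f"{name}.bin", base / name / f"{name}.onnx"]
    removed = []
    for path in candidates:
        if path.is_file():
            path.unlink()
            removed.append(path)
    return removed
```

Usage: call `remove_stale_model("unet_controlnet")` before re-running `get_model.py` with new `-lt`/`-lw` options.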

Then, run a pipeline inference program to check the results.

$ python run_pipe.py

With the LoRA weights appended, the Stable Diffusion + ControlNet pipeline generates an image like the one below:

Stable Diffusion v1.5+ LoRA bin weights + canny ControlNet

Step 4–2: Enable LoRA by safetensors typed weights

This step adds LoRA weights to the Stable Diffusion Unet model with `diffusers/scripts/convert_lora_safetensor_to_diffusers.py`. This Diffusers script generates a new Stable Diffusion model with the safetensors LoRA weights merged in. With this method, you need to point the weights path to the newly generated Stable Diffusion model with LoRA. You can adjust the value of the `alpha` option to change the merging ratio in `W = W0 + alpha * deltaW` for the attention layers.

$ git clone https://huggingface.co/ntc-ai/fluffy-stable-diffusion-1.5-lora-trained-without-data
$ git clone https://github.com/huggingface/diffusers.git && cd diffusers
$ python scripts/convert_lora_safetensor_to_diffusers.py --base_model_path ../stable-diffusion-v1-5/ --checkpoint_path ../fluffy-stable-diffusion-1.5-lora-trained-without-data/fluffySingleWordConcept_v10.safetensors --dump_path ../stable-diffusion-v1-5-fluffy-lora --alpha=1.5
$ cd .. && rm unet_controlnet.* unet_controlnet/unet_controlnet.onnx
$ python get_model.py -b 2 -sd stable-diffusion-v1-5-fluffy-lora/ -lt safetensors

Then, run a pipeline inference program to check the results.

$ python run_pipe.py

With the LoRA weights appended, the SD + ControlNet pipeline generates an image like the one below:

Stable Diffusion v1.5 + LoRA safetensors weights + canny ControlNet

Step 4–3: Enable runtime LoRA merging by MatcherPass

This step adds LoRA weights at runtime, before the Unet or text_encoder model is compiled. This is helpful for client applications that switch between different LoRA weights to change the image style while reusing the same Unet/text_encoder model structure.

This method extracts the LoRA weights from the safetensors file, finds the corresponding weights in the Unet model, and inserts the LoRA weight bias. The common way to add LoRA weights is:

W = W0 + W_bias, where W_bias = alpha * torch.mm(lora_up, lora_down)
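A minimal sketch of the extraction step, assuming the common kohya-style key layout (`lora_unet_<layer>.lora_down.weight` and `lora_unet_<layer>.lora_up.weight`, with `lora_te_` for the text encoder) and a state dict already loaded into NumPy arrays, for example with `safetensors.numpy.load_file`. It produces `{"name", "value"}` entries like the `lora_dict_list` that the `InsertLoRA` pass in this step matches on. The function name is hypothetical, the layer-name mapping is an assumption, and conv-type LoRA layers (4-D weights) are skipped rather than handled.

```python
import numpy as np


def build_lora_dict_list(state_dict, alpha=1.0, prefix="lora_unet_"):
    """Compute W_bias = alpha * lora_up @ lora_down for every 2-D LoRA pair under `prefix`."""
    lora_dict_list = []
    for key, down in state_dict.items():
        if not key.startswith(prefix) or not key.endswith(".lora_down.weight"):
            continue
        layer = key[: -len(".lora_down.weight")]
        up = state_dict.get(layer + ".lora_up.weight")
        if up is None or down.ndim != 2 or up.ndim != 2:
            continue  # skip incomplete pairs and conv-style (4-D) LoRA weights
        w_bias = alpha * (up.astype(np.float32) @ down.astype(np.float32))
        # Assumed to match the Unet constant's friendly name after InsertLoRA's
        # '.' -> '_' and '_weight' -> '' rewrites
        lora_dict_list.append({"name": layer[len(prefix):], "value": w_bias})
    return lora_dict_list


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:  # usage: python this_script.py path/to/lora.safetensors
        from safetensors.numpy import load_file

        entries = build_lora_dict_list(load_file(sys.argv[1]), alpha=1.5)
        print(f"{len(entries)} Unet layers with LoRA weights")
```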

The idea is to insert an Add operation on the Unet attention weights with OpenVINO™ `opset10.add(W0, W_bias)`. The original attention weights in the Unet model are stored in `Const` ops, and the common processing path is `Const->Convert->MatMul->…`. To add the LoRA weights, we insert the precomputed LoRA weight bias so that the path becomes `Const->Convert->Add->MatMul->…`. To do this, we use `openvino.runtime.passes.MatcherPass`, whose callback() function inserts `opset10.add()` at each matched node.

Model transformation method of adding LoRA weights

The transformation first inserts the opset Add() operations. Then, when the model is compiled for the device, constant folding merges each inserted Add, together with its constant inputs, into the weight consumed by the following MatMul, which optimizes runtime inference. This makes it an effective way to merge LoRA weights into an original model.
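Why folding is safe: matrix multiplication is linear, so feeding an input through the folded weight W0 + W_bias gives the same result as running the frozen layer plus a separate LoRA branch, while the folded form needs one MatMul instead of three. The NumPy sketch below checks this numerically; the function names and sizes are illustrative, and weights use the (out, in) layout with a transposed MatMul.

```python
import numpy as np


def lora_branch_forward(x, w0, lora_up, lora_down, alpha):
    """Unmerged: frozen layer plus a separate low-rank branch (three MatMuls)."""
    return x @ w0.T + alpha * ((x @ lora_down.T) @ lora_up.T)


def folded_forward(x, w0, lora_up, lora_down, alpha):
    """Merged: W0 + W_bias is computed once (what constant folding yields), then one MatMul."""
    w_folded = w0 + alpha * (lora_up @ lora_down)
    return x @ w_folded.T


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch, in_f, out_f, rank = 4, 320, 320, 4  # illustrative sizes
    x = rng.standard_normal((batch, in_f))
    w0 = rng.standard_normal((out_f, in_f)) * 0.05
    up = rng.standard_normal((out_f, rank)) * 0.01
    down = rng.standard_normal((rank, in_f)) * 0.01
    same = np.allclose(lora_branch_forward(x, w0, up, down, 1.5),
                       folded_forward(x, w0, up, down, 1.5))
    print("outputs match:", same)
```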

In the implementation source code, the MatcherPass is defined as the class `InsertLoRA(MatcherPass)`:

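from openvino.runtime import Type
from openvino.runtime.passes import Matcher, MatcherPass, WrapType
import openvino.runtime.opset10 as ops

class InsertLoRA(MatcherPass):
    def __init__(self, lora_dict_list):
        MatcherPass.__init__(self)
        self.model_changed = False

        # Match every Convert node (e.g. the Convert that follows a weight Const)
        param = WrapType("opset10.Convert")

        def callback(matcher: Matcher) -> bool:
            root = matcher.get_match_root()
            root_output = matcher.get_match_value()
            for y in lora_dict_list:
                # Map the friendly name to the LoRA key format, e.g. "a.b.weight" -> "a_b"
                if root.get_friendly_name().replace('.', '_').replace('_weight', '') == y["name"]:
                    consumers = root_output.get_target_inputs()
                    lora_weights = ops.constant(y["value"], Type.f32, name=y["name"])
                    add_lora = ops.add(root, lora_weights, auto_broadcast='numpy')
                    # Re-route the original consumers (e.g. MatMul) to read from the Add
                    for consumer in consumers:
                        consumer.replace_source_output(add_lora.output(0))

                    # Use new operation for additional matching
                    self.register_new_node(add_lora)
                    # Root node wasn't replaced or changed
                    return False
            # No LoRA entry for this Convert node: leave the graph untouched
            return False

        self.register_matcher(Matcher(param, "InsertLoRA"), callback)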

The `InsertLoRA(MatcherPass)` pass is registered with `manager.register_pass(InsertLoRA(lora_dict_list))` and invoked with `manager.run_passes(ov_unet)`. After this runtime MatcherPass operation, the graph can be compiled with the device plugin and is ready for inference.

Run the pipeline inference program to check the results, which are the same as in Step 4–2.

$ python run_pipe.py -lp fluffy-stable-diffusion-1.5-lora-trained-without-data/fluffySingleWordConcept_v10.safetensors -a 1.5

With the LoRA weights appended at runtime, the Stable Diffusion + ControlNet pipeline generates an image like the one below:

Stable Diffusion v1.5+runtime LoRA safetensors weights + ControlNet

Step 4–4: Enable multiple LoRA weights

There are several ways to add multiple LoRA weights; we list two of them here. Assume you have two sets of LoRA weights, LoRA A and LoRA B. You can simply follow Step 4–3 and run the MatcherPass a second time, inserting LoRA B between the original Unet Convert layer and the Add layer of LoRA A. This is easy to implement, but it does not perform well.

Method 1: Loop InsertLoRA() twice

Consider the logic of the MatcherPass function: it has to visit every layer of the Convert type and then check whether the weight Constant connected to each Convert layer was fine-tuned, i.e. whether it has an entry in the LoRA weights file. Most of the cost of enabling LoRA comes from this InsertLoRA() traversal, so the better approach is to invoke InsertLoRA() only once while appending the weights from all LoRA files.

Method 2: Append all LoRA weights together to insert

With this method, the cost of appending two or more sets of LoRA weights is almost the same as adding one.
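Method 2 can be sketched as a preprocessing step: build one `{"name", "value"}` list per LoRA file (each already scaled by its own alpha, as with the `-a` and `-a2` options of `run_pipe.py`), sum the biases that target the same layer, and pass the single combined list to `InsertLoRA`. The helper name is hypothetical, and it assumes all LoRA files were trained for the same base model, so matching layer names have matching shapes.

```python
import numpy as np


def combine_lora_dict_lists(*lora_dict_lists):
    """Sum the W_bias of every LoRA file per layer name so InsertLoRA runs only once."""
    combined = {}
    for lora_dict_list in lora_dict_lists:
        for entry in lora_dict_list:
            name = entry["name"]
            if name in combined:
                combined[name] = combined[name] + entry["value"]  # new array, inputs untouched
            else:
                combined[name] = np.array(entry["value"], dtype=np.float32)  # copy
    return [{"name": name, "value": value} for name, value in combined.items()]
```

Usage: `manager.register_pass(InsertLoRA(combine_lora_dict_lists(list_a, list_b)))`, where `list_a` and `list_b` are the per-file lists. Because W0 + W_bias_A + W_bias_B equals W0 + (W_bias_A + W_bias_B), the combined result matches applying the two LoRA files one after the other.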

Now, let’s replace the Stable Diffusion base model with dreamlike-anime-1.0 to generate anime-style images. I picked two LoRA weights for SD 1.5 from https://civitai.com/tag/lora.

You will probably need some prompt engineering to get a useful prompt, such as the following:

  • prompt: “1girl, cute, beautiful face, portrait, cloudy mountain, outdoors, trees, rock, river, (soul card:1.2), highly intricate details, realistic light, trending on cgsociety, neon details, ultra-realistic details, global illumination, shadows, octane render, 8k, ultra sharp”
  • Negative prompt: “3D, cartoon, low-res, bad anatomy, bad hands, text, error”
  • Seed: 0
  • num_steps: 30
  • canny low_threshold: 100

$ python run_pipe.py -lp soulcard.safetensors -a 0.7 -lp2 epi_noiseoffset2.safetensors -a2 0.7

This generates an anime-style girl with the typical soul card border, as shown below:

SD dreamlike-anime-1.0+canny_Controlnet+soulcard+noiseoffset

Additional Resources

Download OpenVINO™

OpenVINO™ Documentation

OpenVINO™ Notebooks

Provide Feedback & Report Issues

Notices & Disclaimers

Intel technologies may require enabled hardware, software, or service activation.

No product or component can be secure.

Your costs and results may vary.

Intel does not control or audit third-party data. You should consult other sources to evaluate accuracy.
Intel disclaims all express and implied warranties, including without limitation, the implied warranties of merchantability, fitness for a particular purpose, and non-infringement, as well as any warranty arising from course of performance, course of dealing, or usage in trade.

No license (express or implied, by estoppel or otherwise) to any intellectual property rights is granted by this document.

© Intel Corporation. Intel, the Intel logo, and other Intel marks are trademarks of Intel Corporation or its subsidiaries. Other names and brands may be claimed as the property of others.
