How to train a Pix2Pix model and generate images on the web with ml5.js
In this post, I will go through the process of training a generative image model using Spell and then porting the model to ml5.js. This is how I did it for my project during a class at NYU ITP this semester. I made a simple interactive web app using my own newly trained Pix2Pix model.
What is Pix2Pix?
Pix2Pix, or Image-to-Image Translation, learns how to convert images of one type into new images of another type. Pix2Pix uses a kind of generative algorithm called a conditional generative adversarial network (cGAN), where the generation process is "conditioned" on the input image. The original paper and code were published by Phillip Isola et al.
Fortunately, I got in-depth explanations of the model and its technical details from two classes, Neural Aesthetic and Machine Learning for the Web, taught by Gene Kogan and Yining Shi, respectively. Here are good resources from those classes that contain more detailed descriptions of how Pix2Pix works.
You can also try some well-known interactive demos and read a great post, with a detailed explanation of how the encoding-decoding process works, on this page made by Christopher Hesse. I used his Pix2Pix TensorFlow implementation and ran it inside a Spell run.
What I built
First, I created an interactive human picture generator based on the ml5.js Pix2Pix example. Here is a link to my demo:
I trained a model on pairs of images: photos of people and color-segmented versions of those photos. From these, Pix2Pix learned how to convert color-segmented images into output images that look like pictures of humans. You can use the same approach to train a model on any kind of related image pairs you like.
I also made another web app called #HUMANSOFNOWHERE. It uses a webcam stream as the input source for Pix2Pix generation, and it also generates text with an LSTM model running in ml5.js. The models were trained on the images and text scraped from the existing Instagram account Humans of New York. The basic idea of this project was to make an autonomous Instagram bot that generates images and captions and then posts them, so that the account would be like a reflection of humansofny, one of New York's iconic Instagram accounts.
This one is not yet completed, but I am sharing a link to its alpha version and the Instagram account here.
Humans of Nowhere (@__humansofnw__) on Instagram
How I built it
I will only cover the process for the first Pix2Pix example, not the other one that is combined with an LSTM. If you want to train and use an LSTM with ml5.js, there is a video tutorial that is a good place to start.
1) Data preparation
I needed a couple of hundred paired images, each a photo of a person and its person segmentation, to train my own Pix2Pix model. Since the model doesn't seem to require extremely large image datasets compared to other generative models, I decided to scrape all the images from the Humans of New York Instagram account. I used the Instagram Scraper Python script for this and ended up with 4,200 images in total.
The hardest part of data preparation was refining the images. To train Pix2Pix properly, I first needed square 256x256-pixel images, and then I had to combine each one with its same-sized person-segmented image into a single 512x256 image, like this:
I looked at several ways to do this efficiently and ended up with a quick-and-dirty approach. First, I chose, almost at random, a few hundred square images from the scraped set that covered a variety of poses, genders, and body shapes. Then I fed the chosen images into the person segmentation process.
For the person segmentation, I wrote a script using the tensorflow.js person segmentation model (if possible, I recommend using the body-pix model instead). The code snippet below handles the paired-image processing. After processing, I downloaded 240 paired images by hand.
There are other ways to get person-segmented images. For instance, a Mask R-CNN model could likely do the job.
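As a rough illustration of the pairing step (not the script used for this project), the sketch below joins two same-sized RGBA pixel buffers into one image that is twice as wide. The function name `joinSideBySide` is hypothetical, and the layout (photo on the left, segmentation on the right) is an assumption that matches training in the B-to-A direction, where the right half is the input and the left half is the target.

```javascript
// Join two same-sized RGBA pixel buffers side by side (left | right).
// Both inputs are flat arrays laid out like ImageData.data:
// 4 bytes per pixel, rows stored from top to bottom.
// Assumption: left = photo (side A), right = segmentation (side B).
function joinSideBySide(left, right, width, height) {
  const bytesPerRow = width * 4;
  if (left.length !== bytesPerRow * height || right.length !== left.length) {
    throw new Error('both images must be width x height RGBA buffers');
  }
  const out = new Uint8ClampedArray(bytesPerRow * 2 * height);
  for (let y = 0; y < height; y++) {
    const srcStart = y * bytesPerRow;
    const dstStart = y * bytesPerRow * 2;
    // Left half of the output row comes from the first image.
    out.set(left.slice(srcStart, srcStart + bytesPerRow), dstStart);
    // Right half of the output row comes from the second image.
    out.set(right.slice(srcStart, srcStart + bytesPerRow), dstStart + bytesPerRow);
  }
  return out;
}
```

In a browser, the two inputs could come from `getImageData(0, 0, 256, 256).data` on two canvases, and the result could be wrapped with `new ImageData(out, 512, 256)` before being drawn and saved.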
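Whichever segmentation model is used, the step that turns its output into a color-segmented image is simple. The sketch below is an illustration rather than the script used for this project. It assumes the model yields a flat per-pixel mask in which a non-zero value marks a person pixel; the function name `maskToColorImage` and both default colors are hypothetical choices.

```javascript
// Convert a per-pixel person mask into a flat RGBA buffer.
// Assumptions: mask[i] is non-zero for person pixels and 0 for background;
// the default colors are arbitrary and only need to be used consistently.
function maskToColorImage(mask, personColor = [255, 0, 0], backgroundColor = [255, 255, 255]) {
  const out = new Uint8ClampedArray(mask.length * 4);
  for (let i = 0; i < mask.length; i++) {
    const [r, g, b] = mask[i] ? personColor : backgroundColor;
    out[i * 4] = r;
    out[i * 4 + 1] = g;
    out[i * 4 + 2] = b;
    out[i * 4 + 3] = 255; // fully opaque
  }
  return out;
}
```

The same colors need to be used at generation time, since the trained model only knows the colors it saw in the training pairs.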
2) Train the model
Once the image pairs had been collected, I needed to train my model. To do so, I used GPU-accelerated computing through Spell. If you are not familiar with Spell, check this introductory video by Daniel Shiffman *Choo Choo* or the Learn with Spell page. Yining also has a repo that walks through the exact same process, which can give you a better sense of Spell.
Once Spell was set up on my local machine, I opened a terminal and typed spell login to log in to Spell. Then I cloned the Pix2Pix TensorFlow implementation repo and moved into its directory.
git clone https://github.com/affinelayer/pix2pix-tensorflow.git
cd pix2pix-tensorflow
Before starting the training, I uploaded the prepared image dataset to the remote Spell resources folder. In my case, I named the upload folder 'data'.
spell upload DIRECTORY_PATH_WITH_DATASET [--name new-name]
If the upload succeeds, you can find a folder with the name you set in the Spell Resources dashboard.
Everything looked good, so I ran the command to train.
spell run --machine-type V100 --mount uploads/data:/data/ --framework tensorflow 'python pix2pix.py --mode train --input_dir /data/ --output_dir ckpt --which_direction BtoA --max_epochs 200'
Training took about half an hour, but the running time will vary with the machine type and the hyperparameters you set.
From the root folder of pix2pix-tensorflow, I then exported the model so it could be used on the web with ml5.js.
# create folder in pix2pix-tensorflow repo folder
# copy the result from spell run
spell cp runs/YOUR_RUN_NUMBER/ckpt
# export the model in pix2pix-tensorflow root folder
python pix2pix.py --mode export --output_dir export/ --checkpoint ckpt/ --which_direction BtoA
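# port the model to tensorflow.js
# (the ../export path implies this is run from a subfolder of the repo; in the upstream repo that is server/)
python3 tools/export-checkpoint.py --checkpoint ../export --output_file static/models/MY_MODEL_BtoA.pict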
These commands produced the exported model file in .pict format, which I then used with ml5.js.
3) Use the model
In my ml5.js script, I simply changed the path to point to my new model.
pix2pix = ml5.pix2pix('model/MY_OWN_MODEL.pict', modelLoaded);
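For context, here is a minimal sketch of how the loaded model might then be used to generate an image from a canvas. It assumes an ml5.js version in which `pix2pix.transfer(canvas, callback)` calls back with `(error, result)` and `result.src` holds the generated image; the wrapper `createGenerator` and its `onImage` callback are hypothetical names, and the ml5 object is passed in as a parameter only so the sketch does not depend on a global.

```javascript
// Wrap model loading and generation in one small helper.
// Assumptions: ml5lib is the global `ml5` object in the browser, and
// transfer() reports its result as (error, result) with result.src set.
function createGenerator(ml5lib, modelPath, onImage) {
  let ready = false;
  const pix2pix = ml5lib.pix2pix(modelPath, () => {
    ready = true; // the model file has finished loading
  });

  // Returns false if the model is not loaded yet, true otherwise.
  return function generate(canvasElement) {
    if (!ready) return false;
    pix2pix.transfer(canvasElement, (err, result) => {
      if (err) {
        console.error(err);
        return;
      }
      if (result && result.src) onImage(result.src);
    });
    return true;
  };
}
```

In a page, this could be called as `const generate = createGenerator(ml5, 'model/MY_OWN_MODEL.pict', src => { outputImg.src = src; })`, where `outputImg` is an `<img>` element of your own, followed by `generate(canvas)` whenever the input drawing changes.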
If you want to play with the example I made, here is the source. This approach and code for training Pix2Pix should work with any well-structured pairs of images you want to use. You can see some interesting results created by this model on the #HUMANSOFNOWHERE Instagram account.