Federated Learning of a Recurrent Neural Network for text classification, with Raspberry Pis working as remote workers

Naufil Muhammad
Secure and Private AI Writing Challenge
6 min read · Aug 16, 2019

This article is an extended version of the official OpenMined blog post.

In this tutorial, you will learn how to train a Recurrent Neural Network (RNN) in a federated way, using two Raspberry Pis as remote workers, to classify a person's surname by its most likely language of origin.

We will train two Recurrent Neural Networks residing on two remote workers (Raspberry Pis). We will use a dataset of 20,000 surnames from 18 different languages to predict which language a name belongs to, based on the name’s spelling.

A character-level RNN treats a word as a series of characters. For each character it outputs a prediction and a “hidden state”, and it feeds the previous hidden state into the next step. We take the final prediction as the output, i.e. the class the word belongs to. Training therefore proceeds sequentially, character by character, with the hidden state carried from one step to the next.
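To make the character-by-character idea concrete, here is a minimal pure-Python sketch. It is not the network from the PySyft notebook: the `TinyRNN` class, its random untrained weights, the hidden size of 16, and the `ALL_LETTERS` alphabet are all assumptions chosen for illustration. It only shows how a surname becomes a sequence of one-hot vectors and how the hidden state is passed along until the final prediction.

```python
import math
import random
import string

# Assumption: ASCII letters plus a few punctuation marks form the alphabet.
ALL_LETTERS = string.ascii_letters + " .,;'"
N_LETTERS = len(ALL_LETTERS)


def letter_to_one_hot(letter):
    """Return a one-hot list with 1.0 at the letter's index (ValueError if unknown)."""
    vec = [0.0] * N_LETTERS
    vec[ALL_LETTERS.index(letter)] = 1.0
    return vec


def name_to_sequence(name):
    """Turn a surname into a list of one-hot vectors, one per character."""
    return [letter_to_one_hot(ch) for ch in name]


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


class TinyRNN:
    """Toy recurrent cell with random weights; illustrative only, not trained."""

    def __init__(self, input_size, hidden_size, output_size, seed=0):
        rng = random.Random(seed)

        def matrix(rows, cols):
            return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

        self.hidden_size = hidden_size
        self.w_hidden = matrix(hidden_size, input_size + hidden_size)
        self.w_output = matrix(output_size, hidden_size)

    def step(self, x, hidden):
        """Process one character: return (class probabilities, new hidden state)."""
        combined = x + hidden  # concatenate current input with previous hidden state
        new_hidden = [
            math.tanh(sum(w * v for w, v in zip(row, combined)))
            for row in self.w_hidden
        ]
        logits = [sum(w * v for w, v in zip(row, new_hidden)) for row in self.w_output]
        return softmax(logits), new_hidden

    def classify(self, name):
        """Feed the name one character at a time and keep only the final prediction."""
        hidden = [0.0] * self.hidden_size
        probs = None
        for x in name_to_sequence(name):
            probs, hidden = self.step(x, hidden)
        return probs


rnn = TinyRNN(input_size=N_LETTERS, hidden_size=16, output_size=18)
probs = rnn.classify("Naufil")
print(len(probs), "class probabilities for the 18 languages")
```

Because the weights are random, the probabilities are meaningless until the network is trained; the point is only the data flow.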

From 10,000 Feet Above

We will set up two Raspberry Pis as websocket servers: flash the latest Raspbian Buster image, create a virtual environment, and install the dependencies. Then we will set up our central server, which will be our laptop: install Anaconda, create a virtual environment, install the dependencies, and run Jupyter Notebook.

Setting up Raspberry Pis

Download Raspbian Buster (the latest Raspbian image)

It comes with Python 3.7.3 pre-installed. Download it from here

You can choose either NOOBS or Raspbian; I recommend Raspbian.

Choose your favorite image burner

I personally use Etcher. Download it from here

After installing, it looks like this

Get an SD card reader, insert your SD card (32 GB or bigger) into it, and plug it into your laptop or desktop. Open Etcher and select the Raspbian Buster image you just downloaded. Select the SD card you want to burn Raspbian onto and click Flash. Flashing takes around 5 minutes. Then eject the card reader, take out the SD card, and insert it into your Raspberry Pi. Ta-da, congrats, you have an economical desktop computer :D

My friends have written great, detailed, step-by-step beginner articles on how to install Raspbian Buster on a Raspberry Pi. Check them out here and here.

Increase Swap Space

Type in terminal:

sudo nano /etc/dphys-swapfile

The default value in Raspbian is:

CONF_SWAPSIZE=100

We will need to change this to:

CONF_SWAPSIZE=2048

and save, then reboot so that the new swap size takes effect.
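If you would rather script the change than edit the file by hand, the same substitution can be sketched in Python. The helper name `set_swap_size` and the sample text below are my own assumptions for illustration; the function works on a string, so writing the result back to /etc/dphys-swapfile (with root permissions) is left to you.

```python
import re


def set_swap_size(config_text, size_mb):
    """Return config_text with the active CONF_SWAPSIZE line set to size_mb."""
    pattern = re.compile(r"^CONF_SWAPSIZE=.*$", flags=re.MULTILINE)
    new_line = f"CONF_SWAPSIZE={size_mb}"
    if pattern.search(config_text):
        # Replace the existing setting and leave every other line untouched.
        return pattern.sub(new_line, config_text, count=1)
    # No active setting found: append one at the end.
    separator = "" if (not config_text or config_text.endswith("\n")) else "\n"
    return config_text + separator + new_line + "\n"


# Illustrative sample text, not a copy of the real file.
sample = "# swap settings\nCONF_SWAPSIZE=100\n"
print(set_swap_size(sample, 2048))
```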

Create a Virtual Environment

Before using the Raspberry Pi as a remote worker, we need to create a virtual environment. It is always a good idea to create separate virtual environments for different projects so that dependencies don’t conflict with each other.

Press Ctrl + Alt + T to open the terminal.

pip3 install virtualenv 
cd ~
python3 -m venv federated_learning
source federated_learning/bin/activate

Now you are inside a virtual environment named federated_learning.
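If you want to double-check that the environment is really active, a small sketch like the following can help. The function name `in_virtualenv` is my own; it relies on the fact that inside an environment created with python3 -m venv, `sys.prefix` differs from `sys.base_prefix`.

```python
import sys


def in_virtualenv():
    """Return True when the running interpreter belongs to a virtual environment."""
    base_prefix = getattr(sys, "base_prefix", sys.prefix)
    # venv changes sys.prefix; older virtualenv releases set sys.real_prefix instead.
    return sys.prefix != base_prefix or hasattr(sys, "real_prefix")


print("inside a virtual environment:", in_virtualenv())
print("interpreter prefix:", sys.prefix)
```

Run it with the python from the activated environment; the prefix should point into the federated_learning directory.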

Install dependencies

Download the .whl file from here into the same directory. Then type the following commands in the terminal:

sudo apt-get install libopenblas-dev libblas-dev m4 cmake cython python3-dev python3-yaml python3-setuptools
pip install torch-1.0.0a0+8322165-cp37-cp37m-linux_armv7l.whl
pip install torchvision==0.2.2.post3
pip install syft --no-dependencies
pip install Flask numpy wheel flask_socketio msgpack tblib websocket_client websockets zstd
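Since syft is installed with --no-dependencies, a missing package only shows up later as an import error. The following sketch checks whether the modules can be found before you go further. The import names in `REQUIRED` are my assumption about how these pip packages are imported (for example, websocket_client is imported as `websocket`), so adjust the list if yours differ.

```python
import importlib.util

# Assumed import names for the packages installed above.
REQUIRED = [
    "torch", "torchvision", "syft", "flask", "numpy", "flask_socketio",
    "msgpack", "tblib", "websocket", "websockets", "zstd",
]


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED)
if missing:
    print("Not found:", ", ".join(missing))
else:
    print("All required modules were found.")
```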

Clone the official PySyft repository by typing:

git clone https://github.com/OpenMined/PySyft
cd PySyft/examples/tutorials/advanced/websockets-example-MNIST
sudo apt-get install gedit
gedit run_websocket_server.py

Press Ctrl + F to open the search box. Look for “localhost” and change it to the IP address of your Raspberry Pi.

Hover your mouse over the Wi-Fi icon; it will show the IP address.
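If the Pi has no desktop attached, the address can also be looked up from Python. This is only a sketch: the function name `local_ip` is mine, and 8.8.8.8 is just an arbitrary public address used to let the operating system pick the outgoing network interface. No data is sent, because a UDP connect only selects a route.

```python
import socket


def local_ip():
    """Return the IPv4 address of the interface used for outgoing traffic, or None."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # Connecting a UDP socket sends no packets; it only chooses a local address.
        sock.connect(("8.8.8.8", 80))
        return sock.getsockname()[0]
    except OSError:
        # No usable network route (for example, Wi-Fi not connected yet).
        return None
    finally:
        sock.close()


print("This machine's IP address:", local_ip())
```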

Before

from multiprocessing import Process
import syft as sy
from syft.workers import WebsocketServerWorker
import torch
import argparse
import os

hook = sy.TorchHook(torch)


def start_proc(participant, kwargs):  # pragma: no cover
    """ helper function for spinning up a websocket participant """

    def target():
        server = participant(**kwargs)
        server.start()

    p = Process(target=target)
    p.start()
    return p


parser = argparse.ArgumentParser(description="Run websocket server worker.")

parser.add_argument(
    "--port", "-p", type=int, help="port number of the websocket server worker, e.g. --port 8777"
)

parser.add_argument("--host", type=str, default="localhost", help="host for the connection")

parser.add_argument(
    "--id", type=str, help="name (id) of the websocket server worker, e.g. --id alice"
)

parser.add_argument(
    "--verbose",
    "-v",
    action="store_true",
    help="if set, websocket server worker will be started in verbose mode",
)

args = parser.parse_args()

kwargs = {
    "id": args.id,
    "host": args.host,
    "port": args.port,
    "hook": hook,
    "verbose": args.verbose,
}


if os.name != "nt":
    server = start_proc(WebsocketServerWorker, kwargs)
else:
    server = WebsocketServerWorker(**kwargs)
    server.start()

After

from multiprocessing import Process
import syft as sy
from syft.workers import WebsocketServerWorker
import torch
import argparse
import os

hook = sy.TorchHook(torch)


def start_proc(participant, kwargs):  # pragma: no cover
    """ helper function for spinning up a websocket participant """

    def target():
        server = participant(**kwargs)
        server.start()

    p = Process(target=target)
    p.start()
    return p


parser = argparse.ArgumentParser(description="Run websocket server worker.")

parser.add_argument(
    "--port", "-p", type=int, help="port number of the websocket server worker, e.g. --port 8777"
)

parser.add_argument("--host", type=str, default="192.168.8.104", help="host for the connection")

parser.add_argument(
    "--id", type=str, help="name (id) of the websocket server worker, e.g. --id alice"
)

parser.add_argument(
    "--verbose",
    "-v",
    action="store_true",
    help="if set, websocket server worker will be started in verbose mode",
)

args = parser.parse_args()

kwargs = {
    "id": args.id,
    "host": args.host,
    "port": args.port,
    "hook": hook,
    "verbose": args.verbose,
}


if os.name != "nt":
    server = start_proc(WebsocketServerWorker, kwargs)
else:
    server = WebsocketServerWorker(**kwargs)
    server.start()

Save it.

Type in terminal:

python run_websocket_server.py --id alice --port 8777
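A side note on the edit we just made: the script already defines a --host argument, so changing the default in the file is only one way to set the address. The sketch below rebuilds just the argument parser from the script (without syft, and with the help texts dropped) to show how a --host value passed on the command line overrides the default. The `build_parser` wrapper is my own, added so the parser can be tried in isolation.

```python
import argparse


def build_parser():
    """Recreate the argument parser of run_websocket_server.py (help texts omitted)."""
    parser = argparse.ArgumentParser(description="Run websocket server worker.")
    parser.add_argument("--port", "-p", type=int)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--id", type=str)
    parser.add_argument("--verbose", "-v", action="store_true")
    return parser


# Simulate: python run_websocket_server.py --id alice --port 8777 --host 192.168.8.104
args = build_parser().parse_args(
    ["--id", "alice", "--port", "8777", "--host", "192.168.8.104"]
)
print(args.id, args.host, args.port)
```

So with an unmodified script, adding --host followed by the Pi's IP address to the command above should have the same effect as editing the default.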

Set up the same things on the other Raspberry Pi, using that Pi's own IP address in the script. At the end, modify the last command like this:

python run_websocket_server.py --id bob --port 8778
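Before moving on to the laptop, it can save time to confirm that both workers accept connections. The following is a rough sketch using only the standard library; the function names are mine, and bob's address in `WORKERS` is a made-up placeholder, so replace both addresses with those of your own Pis. It only checks that a TCP connection can be opened, not that the websocket server behaves correctly.

```python
import socket

# alice's address matches the example script above; bob's is a hypothetical placeholder.
WORKERS = [
    {"id": "alice", "host": "192.168.8.104", "port": 8777},
    {"id": "bob", "host": "192.168.8.105", "port": 8778},
]


def is_reachable(host, port, timeout=3.0):
    """Return True if a TCP connection to host:port can be opened within timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


def check_workers(workers, timeout=3.0):
    """Map each worker id to whether its websocket port accepts connections."""
    return {
        worker["id"]: is_reachable(worker["host"], worker["port"], timeout)
        for worker in workers
    }
```

Calling check_workers(WORKERS) from the laptop returns a dictionary with True for every worker that is listening. A False usually means the server script is not running, the IP address is wrong, or the two machines are on different networks.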

Setting up Laptop

Your laptop will act as the central server. Download Anaconda for Python 3.7 onto your laptop from here

Once installed, search for Anaconda Prompt. It looks just like a terminal. Type in the following commands:

conda update conda
conda create -n federated_learning
conda activate federated_learning
conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
conda install pip
pip install syft
git clone https://github.com/OpenMined/PySyft

Reduce the Timeout Interval

Go to Anaconda3\envs\federated_learning\Lib\site-packages\syft\workers, open the websocket_client.py file, and remove one 9 from the TIMEOUT_INTERVAL value so that the timeout interval is reduced.

Time to train the RNN

cd PySyft\examples\tutorials\advanced
jupyter notebook "Federated Recurrent Neural Network.ipynb"

Jupyter Notebook will open in your browser. Press Ctrl + F and search for TorchHook to find the right cell. Change this cell so that it points to your Raspberry Pis. Of course, your IP addresses will differ from mine.

Now go back to the start and execute each cell by pressing Shift + Enter, except for the first cell.

One of the cells will take approximately 10 minutes to complete.

The training cell will take approximately 4 hours to train the neural network.

At the end, your notebook will look like mine. I am sharing it so that you can compare :D Here is the GitHub repo.

Congrats, you have made it! Leave a comment below if you made it or if you have any questions or suggestions.
