Training Vision Transformers from Scratch for Malware Classification

An image is worth 16x16 words; what is a malware worth? Maybe a malware is worth a 66x66 image.

# 1. Background

Task & dataset description

Recently, a malware classification track was launched in the 2021 iFLYTEK A.I. Developer Challenge. The competition provides known malware data and asks competitors to predict the class (family) to which each malware sample belongs. This is a multi-class problem with 9 malware classes, labeled 0 to 8.

The competition data consist of a training set and a test set, with over 100,000 samples in total and 70 fields, where id is the unique identifier of each sample and label is the malware category the sample belongs to. To ensure fairness, 50,000 samples were selected as the training set and 8,000 as the test set, and some fields were anonymized. The feature fields mainly describe the asm file: "line_count_asm" is the number of lines in the asm file, "size_asm" is its size, and the remaining asm feature fields are prefixed with "asm_commands", each of which can be understood as an opcode in asm. Unlike the 2015 Microsoft Malware Classification Challenge, this dataset only provides opcode word frequencies and file-size information, i.e. only static features are available for analysis and modeling. The evaluation metric is accuracy, as in sklearn.metrics.accuracy_score.
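To make the field layout concrete, here is a hypothetical pandas sketch: the specific opcode column names are made up for illustration, but the "asm_commands" prefix convention is as described above.

```python
import pandas as pd

# Toy stand-in for the competition table (column names after the prefix are invented).
df = pd.DataFrame({
    "id": [0, 1],
    "label": [3, 7],
    "line_count_asm": [1200, 800],
    "size_asm": [40960, 20480],
    "asm_commands_mov": [150, 90],
    "asm_commands_jmp": [30, 12],
})

# Select the opcode word-frequency columns by their shared prefix.
opcode_cols = [c for c in df.columns if c.startswith("asm_commands")]
print(opcode_cols)  # ['asm_commands_mov', 'asm_commands_jmp']
```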

Here is a quick look at the class distribution.

Distributions of classes in the training dataset

It is obvious that classes 0, 1, and 2 have far fewer samples than the others.


Nowadays, neural-network methods have reached a level that may exceed the limits of earlier machine learning approaches, and most image-based malware classification techniques[1] are implemented with convolutional neural networks (CNNs): they cleverly transfer the malware classification problem to an image classification problem. However, Vision Transformer (ViT)[2], which extends the Transformer architecture from natural language processing to computer vision, has gradually attained state-of-the-art results on many computer vision benchmarks and has come to be regarded as an alternative to existing CNN architectures.

Motivated by the visual similarity between malware samples of the same family and the success of ViT on vision tasks, we propose MalwareViT, a file-agnostic deep learning approach that applies Vision Transformers to malware classification: the co-occurrence matrix obtained from the opcode frequencies extracted from asm files is treated as an image to efficiently group malicious software into families.

In the following, I will walk through this malware classification method, which turns opcode frequencies into images and applies ViT.

# 2. Malware Image Generation

In ViT, an image is worth 16x16 words; similarly, in MalwareViT a malware is worth a 66x66 image.

The processing steps are as follows. After calculating the word frequencies of the 66 opcodes obtained by decompiling the malicious binary, we sort them in ascending order of total frequency, normalize them to the interval [0, 255], treat them as pixel values, and lay the vector out once horizontally and once vertically to form a two-dimensional array. Since some studies[3] have shown that rarer opcodes are better at distinguishing malware, I apply an "inverse frequency" operation, i.e. 255/(freq+1), so that the smaller the opcode frequency, the larger the grayscale value. The value at each intersection of a row and a column is then taken as the maximum of the two, yielding the co-occurrence matrix. Finally, we save these matrices as images of size 66x66.
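The steps above can be sketched in NumPy. This is an illustrative reimplementation, not the exact competition pipeline; in particular, applying the inverse-frequency transform directly to the raw counts (before any rescaling) is an assumption here.

```python
import numpy as np

def malware_image(opcode_freqs):
    """Build a 66x66 grayscale co-occurrence image from 66 opcode frequencies.
    Assumes the opcode columns are already sorted in ascending order of total
    frequency, as described in the text."""
    v = np.asarray(opcode_freqs, dtype=np.float64)
    # "Inverse frequency": smaller opcode frequency -> larger grayscale value.
    v = 255.0 / (v + 1.0)
    # Lay the vector out as one row and one column; each cell of the
    # co-occurrence matrix takes the max of its row and column values.
    m = np.maximum(v[:, None], v[None, :])
    return m.astype(np.uint8)

img = malware_image(np.arange(66))  # toy frequency vector
print(img.shape)  # (66, 66)
```

By construction the matrix is symmetric, and an opcode with zero frequency produces a bright (255) row and column.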

Feature processing is one of the most important steps: constructing features that suit the convolutional or attention mechanism has a significant impact on deep learning models. Here, local opcode patches and global location-distribution information are used to construct the image features; the generated images are shown below.

malware images and labels

# 3. Overview of ViT

ViT model overview

The picture above (from the ViT paper[2]) shows how the Vision Transformer works.

In the paper, the authors propose an approach that attends not to individual pixels but to small patches of the image. Each patch of the input image is flattened using a linear projection matrix, and a position embedding is added to it. This is necessary because the Transformer processes all inputs without regard to order, so the location information helps the model weight attention correctly. An additional class token is prepended to the input (position 0 in the figure) as a placeholder for the class to be predicted in the classification task; it can be used to aggregate global information.
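The token layout can be sketched in NumPy, shapes only: in the real model the class token and position embeddings are learned parameters, and the dimensions here are arbitrary.

```python
import numpy as np

num_patches, dim = 36, 64
patch_embeddings = np.random.randn(num_patches, dim)  # flattened, projected patches
cls_token = np.zeros((1, dim))                        # learnable in practice
pos_embedding = np.random.randn(num_patches + 1, dim)  # one slot per token, incl. class token

# Class token goes to position 0, then every token gets its position embedding.
tokens = np.concatenate([cls_token, patch_embeddings], axis=0) + pos_embedding
print(tokens.shape)  # (37, 64)
```

Note that the implementation below follows the Keras example and pools all patch tokens instead of reading out a class token, which is a common simplification.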

This code is based on the Image classification with Vision Transformer example on the Keras website.

# 4. Implement MalwareViT

Import package

# environment: Colab Tensorflow 2.5.0
# !pip install -U tensorflow-addons
import tensorflow as tf
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import random
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

Prepare the data

Use ImageDataGenerator to load the images from folders organized by class

path = './dataset'
train_path = path + '/train'
img_width, img_height = 66, 66
img_size = (img_width, img_height)

def img_data_gen(imgs_path, img_size, batch_size, rescale, shuffle=False):
    return ImageDataGenerator(rescale=rescale).flow_from_directory(
        imgs_path, target_size=img_size, batch_size=batch_size,
        class_mode='categorical', shuffle=shuffle)

train_gen = img_data_gen(imgs_path=train_path, img_size=img_size,
                         batch_size=50000, rescale=1. / 255, shuffle=True)

imgs, labels = next(train_gen)

Split the data into training and validation sets with a test_size of 0.25

SEED = 2021  # any fixed seed, for reproducibility
X_train, X_val, y_train, y_val = train_test_split(imgs, labels, test_size=0.25, stratify=labels, random_state=SEED)

Configure the hyperparameters

num_classes = labels.shape[1]
input_shape = imgs.shape[1:]

learning_rate = 0.0005
weight_decay = 0.0001
batch_size = 256
num_epochs = 150
patience = 30  # stop if val_loss does not improve for this many epochs
image_size = 66  # We'll resize input images to this size
patch_size = 11  # Size of the patches to be extracted from the input images
num_patches = (image_size // patch_size) ** 2
# Here input_shape=(66,66), patch size=(11,11) -> 36 patches
projection_dim = 36
num_heads = 6
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers (last entry must equal projection_dim for the skip connection)
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier

ViT Modeling

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size,
        })
        return config

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        # The projection and embedding layers are rebuilt in __init__,
        # so only the constructor arguments go into the config.
        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection_dim': self.projection.units,
        })
        return config

def create_vit_classifier():
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    # augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(inputs)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Create a [batch_size, projection_dim] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    logits = layers.Dense(num_classes)(features)
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

Visualize patches

Image size: 66 X 66
Patch size: 11 X 11
Patches per image: 36
Elements per patch: 363
original image of size 66x66
36 patches of size 11x11
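The numbers above can be reproduced with a small NumPy sketch of non-overlapping patch splitting, which mirrors what tf.image.extract_patches does inside the Patches layer (the images are loaded as RGB, hence 11x11x3 = 363 elements per patch):

```python
import numpy as np

def split_into_patches(image, patch_size):
    """Split an HxWxC image into non-overlapping square patches,
    row by row, matching the order used by tf.image.extract_patches."""
    h, w = image.shape[:2]
    patches = [image[i:i + patch_size, j:j + patch_size]
               for i in range(0, h, patch_size)
               for j in range(0, w, patch_size)]
    return np.stack(patches)

img = np.arange(66 * 66 * 3).reshape(66, 66, 3)  # dummy 66x66 RGB image
patches = split_into_patches(img, 11)
print(patches.shape)    # (36, 11, 11, 3)
print(patches[0].size)  # 363
```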

Train, evaluate and predict

def run_experiment(model):
    optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.CategoricalAccuracy(name="accuracy"),
            keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    model_name = "keras_trained_MalwareViT.h5"
    log_dir = os.path.join(os.getcwd(), 'logs')
    ck_path = os.path.join(log_dir, model_name)
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    mc = keras.callbacks.ModelCheckpoint(ck_path, monitor='val_loss', save_best_only=True, save_weights_only=True)
    es = keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience, verbose=0)  # stop when val_loss has not improved for `patience` epochs
    # tb = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0)
    callbacks = [es, mc]

    history = model.fit(
        X_train, y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_data=(X_val, y_val),
        # validation_split=0.1,
        # shuffle=True,
        callbacks=callbacks,
    )
    # To see history keys for visualization: print(history.history.keys())

    return history

vit_classifier = create_vit_classifier()

# show and save model structure
# vit_classifier.summary()
# keras.utils.plot_model(vit_classifier, show_shapes=True)

# train
history = run_experiment(vit_classifier)

Loss and accuracy curve


# load the best weights saved by ModelCheckpoint
vit_classifier.load_weights(os.path.join(os.getcwd(), 'logs', 'keras_trained_MalwareViT.h5'))

_, accuracy, top_5_accuracy = vit_classifier.evaluate(X_val, y_val)
print(f"Validation accuracy: {round(accuracy * 100, 2)}%")
print(f"Validation top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")


sub_path = "./dataset/test"
sub_gen = img_data_gen(imgs_path=sub_path,img_size=img_size,batch_size=1,rescale=1. / 255)
sub_pred = vit_classifier.predict(sub_gen)
sub_pred_class = np.argmax(sub_pred, axis = 1)


df_sub = pd.read_csv("./dataset/sample_submit.csv")
df_sub['label'] = sub_pred_class
df_sub.to_csv("submit.csv", index=False)  # write the submission file (output filename is arbitrary)

# 5. Improvement suggestions

Since the ViT implementation here does not use pretraining, if you want higher accuracy, try training for more epochs, using deeper layers, changing the input image size or patch size, and increasing the projection dimension; also consider adjusting the learning rate, switching optimizers, tuning weight decay, and other training strategies.

Because the images are relatively simple, overly complex models tend to overfit. This dataset and the 2015 Microsoft Malware Classification Challenge dataset are likely homologous, so one could also pre-train on the large Microsoft dataset and then fine-tune on this one, which may work well.

The paper[4] has shown that fine-tuning pre-trained models of several classical CNN architectures on ImageNet also gives excellent results; it uses color images together with image-processing-based data augmentation to improve model robustness, which outperforms grayscale images. (ResNet50 works best among the grayscale-based pre-trained models compared in the paper's experiments. I subsequently tried pre-training schemes based on several classical CNN architectures, and they really are good; MobileNetV2 trains faster and better, but all of them overfit very easily.)

From the paper[4], we also got the inspiration that we may be able to improve model robustness by introducing adversarial training based on code-level obfuscation techniques, such as dead-code insertion, code transposition, register reassignment, and instruction substitution. For the tabular data, this can be simulated by randomly increasing the overall opcode word frequencies or randomly increasing the number of JMPs and CALLs; for the image data, by increasing the grayscale values, changing the brightness, overall scaling, and so on.
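A minimal sketch of what such augmentations might look like; the scaling ranges and the JMP/CALL column indices are arbitrary assumptions, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def augment_freqs(freqs, jmp_idx, call_idx):
    """Tabular-side augmentation: scale all opcode frequencies slightly,
    and bump JMP/CALL counts to mimic dead-code insertion."""
    f = np.asarray(freqs, dtype=np.float64).copy()
    f *= rng.uniform(0.9, 1.1)        # overall frequency scaling
    f[jmp_idx] += rng.integers(0, 5)  # simulate inserted JMPs
    f[call_idx] += rng.integers(0, 5) # simulate extra CALLs
    return f

def augment_image(img):
    """Image-side counterpart: brightness scaling and shift, clipped to [0, 255]."""
    out = img.astype(np.float64) * rng.uniform(0.9, 1.1) + rng.uniform(-10, 10)
    return np.clip(out, 0, 255).astype(np.uint8)
```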

Beyond that, one can use tree models on the tabular data and neural networks on the image data, then fuse the models. This classification problem can also be treated as sequence classification: concatenate the word-frequency features column by column into a sequence, or, following NLP ideas, reconstruct pseudo asm documents from the word frequencies, embed them, and handle it as a text classification task.
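For the NLP direction, reconstructing a pseudo asm "document" from the word frequencies could look like this toy sketch (the opcode names and counts here are made up; real ordering information is lost, so this is only a bag-of-words reconstruction):

```python
opcodes = ["mov", "push", "call", "jmp"]  # toy opcode vocabulary
freqs = [3, 2, 1, 2]                      # word frequencies for one sample

# Repeat each opcode according to its frequency to form a pseudo document.
doc = " ".join(op for op, f in zip(opcodes, freqs) for _ in range(f))
print(doc)  # "mov mov mov push push call jmp jmp"
```

The resulting strings could then be tokenized and embedded like any text-classification corpus.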

Here is my GitHub repository; these are just some rough ideas to spark better ones, and a ⭐ is welcome. If you have any questions, please feel free to reach out.


[1] Nataraj L, Karthikeyan S, Jacob G, et al. Malware Images: Visualization and Automatic Classification. ACM, 2011.

[2] Dosovitskiy A, Beyer L, Kolesnikov A, et al. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. 2020.

[3] Bilar D. Opcodes as predictor for malware. International Journal of Electronic Security and Digital Forensics, 2007, 1(2):156–168.

[4] Vasan D, Alazab M, Wassan S, et al. IMCFN: Image-based Malware Classification using Fine-tuned Convolutional Neural Network Architecture. Computer Networks, 2020, 171:107138.



Ricky Xu

A data science learner. Ever walking, never settle.