How to Enhance Your Unity Projects with AI Using Unity Sentis

Satya Dev
Published in Antaeus AR · 6 min read · Mar 22, 2024

In the world of extended reality (XR) development, the integration of artificial intelligence (AI) models directly into Unity projects opens up a new realm of possibilities. Unity Sentis, a powerful tool for running AI models through the Unity runtime, has significantly broadened the scope for standalone device applications. In this article, I share my experience with Unity Sentis, focusing on the importation of AI models in the ONNX (Open Neural Network Exchange) format and applying them to create interactive, AI-driven gameplay. Here’s a guide based on my journey, from installation to creating a digit recognition application using the famous MNIST (Modified National Institute of Standards and Technology database) model.

Unity Sentis samples: https://github.com/Unity-Technologies/sentis-samples

Getting Started with Unity Sentis

Unity Sentis solves a unique problem: neural networks are traditionally developed in Python frameworks and cannot run natively in Unity, but Sentis lets you execute them directly within your games or XR experiences.

Getting started is straightforward: open the Package Manager and search for ‘com.unity.sentis’. The package also ships with importable samples, which are incredibly useful for learning specific techniques such as model encryption, asynchronous output reading, and model execution.

[Image: the ‘com.unity.sentis’ package in the Package Manager, shown with its importable samples]

Importing the MNIST Model

For this project, I used the MNIST model to recognize hand-drawn digits. After downloading the model, I placed it in a new Models folder within the Assets directory, where Unity Sentis optimizes it for runtime performance. The next step was downloading a texture of a handwritten digit for initial testing; later on, we create our own inputs by drawing directly in Unity.

[Image: the mnist-12 model asset and a handwritten digit texture used for testing]

MNIST model link: https://github.com/onnx/models/tree/main/validated/vision/classification/mnist
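Before wiring the model into gameplay, it can be worth confirming what the imported asset actually expects. Here is a small diagnostic sketch of my own (the InspectModel class is hypothetical, not part of the article's project) that logs the model's inputs and outputs; for mnist-12 you should see a single 1x1x28x28 grayscale input and a ten-value output.

using UnityEngine;
using Unity.Sentis;

// Hypothetical diagnostic component: logs the loaded model's input and
// output names and shapes so you can verify the expected tensor layout.
public class InspectModel : MonoBehaviour
{
    public ModelAsset modelAsset;

    void Start()
    {
        Model model = ModelLoader.Load(modelAsset);
        foreach (var input in model.inputs)
            Debug.Log($"Input: {input.name}, shape: {input.shape}");
        foreach (var output in model.outputs)
            Debug.Log($"Output: {output}");
    }
}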

Creating a Digit Classifier Script

The core of the application is the ‘ClassifyHandwrittenDigit’ script. It references our model and texture, processes the input, and exposes the results. The script declares fields for the model, the texture, and the output array, and performs a bit of setup for model loading and execution. Crucially, I added a softmax layer to the end of the model within the script to transform its raw output into probabilities, making the results easier to interpret.

using UnityEngine;
using Unity.Sentis;
using Unity.Sentis.Layers;

public class ClassifyHandwrittenDigit : MonoBehaviour
{
    public Texture2D inputTexture;
    public ModelAsset modelAsset;
    Model runtimeModel;
    IWorker worker;
    public float[] results;

    void Start()
    {
        // Create the runtime model
        runtimeModel = ModelLoader.Load(modelAsset);

        // Add a softmax layer to the end of the model so it outputs probabilities
        string softmaxOutputName = "Softmax_Output";
        runtimeModel.AddLayer(new Softmax(softmaxOutputName, runtimeModel.outputs[0]));
        runtimeModel.outputs[0] = softmaxOutputName;

        // Convert the input texture to a single-channel 28x28 tensor
        using Tensor inputTensor = TextureConverter.ToTensor(inputTexture, width: 28, height: 28, channels: 1);

        // Create an inference engine (worker)
        worker = WorkerFactory.CreateWorker(BackendType.GPUCompute, runtimeModel);

        // Run the model with the input data
        worker.Execute(inputTensor);

        // Get the result
        using TensorFloat outputTensor = worker.PeekOutput() as TensorFloat;

        // Move the tensor data to the CPU before reading it
        outputTensor.MakeReadable();
        results = outputTensor.ToReadOnlyArray();
    }

    void OnDisable()
    {
        // Tell the GPU we're finished with the memory the engine used
        worker.Dispose();
    }
}
[Image: the ClassifyHandwrittenDigit script recognizing the number 7]
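The results array holds ten probabilities, one for each digit class from 0 to 9, but the script as written never picks a winner. As a minimal follow-up sketch (my addition, not part of the original script), the predicted digit is simply the index of the largest probability; this could run at the end of Start():

// Find the index of the highest probability: that index is the predicted digit
int predictedDigit = 0;
float bestScore = float.MinValue;
for (int i = 0; i < results.Length; i++)
{
    if (results[i] > bestScore)
    {
        bestScore = results[i];
        predictedDigit = i;
    }
}
Debug.Log($"Predicted digit: {predictedDigit} (probability {bestScore:F3})");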

Implementing Finger Drawing in XR

To make the application interactive, I implemented a feature that lets users draw digits directly onto a texture with their fingertips, using the Meta XR SDK for hand tracking. This involved creating a second script that captures input from the OVR hand and feeds the drawn texture to the ‘ClassifyHandwrittenDigit’ script. Watching the model classify these drawings in real time was a fascinating demonstration of both its accuracy and the power of Unity Sentis.

[Images: the Meta XR All-in-One SDK package; the FingerDrawing and ClassifyHandwrittenDigit components; the right-hand index-finger Transform; and the Canvas Raw Image used as the drawing board]

The FingerDrawing script:

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.UI;

public class FingerDrawing : MonoBehaviour
{
    [SerializeField] private RawImage displayImage;
    [SerializeField] private ClassifyHandwrittenDigit classifier;
    [SerializeField] private Transform fingerTipMarkerTransform;
    [SerializeField] private float delayToSend = 1f;
    [SerializeField] private float distanceToCanvas;

    private bool hasDrawn = false;
    private float lastDrawTime;
    private Camera mainCamera;
    private Texture2D drawingTexture;
    private Coroutine checkForSendCoroutine;

    private void Start()
    {
        // MNIST expects 28x28 input, so we draw on a 28x28 texture
        drawingTexture = new Texture2D(28, 28, TextureFormat.RGBA32, false);
        displayImage.texture = drawingTexture;
        mainCamera = Camera.main;
        ClearTexture();
    }

    public void ClearTexture()
    {
        // Reset the canvas to all-black pixels
        Color[] clearColors = new Color[drawingTexture.width * drawingTexture.height];
        for (int i = 0; i < clearColors.Length; i++)
            clearColors[i] = Color.black;
        drawingTexture.SetPixels(clearColors);
        drawingTexture.Apply();
    }

    private void Update()
    {
        // The user is drawing while the fingertip is close enough to the canvas
        bool isDrawing = Vector3.Distance(fingerTipMarkerTransform.position, displayImage.transform.position) < distanceToCanvas;
        if (isDrawing)
        {
            if (checkForSendCoroutine != null)
            {
                StopCoroutine(checkForSendCoroutine);
                checkForSendCoroutine = null;
            }
            Draw(fingerTipMarkerTransform.position);
            // Record the stroke so we know when to send the finished drawing
            hasDrawn = true;
            lastDrawTime = Time.time;
        }
        else if (hasDrawn && Time.time - lastDrawTime > delayToSend && checkForSendCoroutine == null)
        {
            checkForSendCoroutine = StartCoroutine(CheckForSend());
        }
    }

    private void Draw(Vector3 fingerTipPos)
    {
        // Project the fingertip onto the canvas and normalize to [0,1] texture coordinates
        Vector2 screenPoint = mainCamera.WorldToScreenPoint(fingerTipPos);
        RectTransformUtility.ScreenPointToLocalPointInRectangle(displayImage.rectTransform, screenPoint, mainCamera, out Vector2 localPoint);
        Vector2 normalizedPoint = Rect.PointToNormalized(displayImage.rectTransform.rect, localPoint);
        AddPixels(normalizedPoint);
    }

    private void AddPixels(Vector2 normalizedPoint)
    {
        // Convert normalized coordinates to pixel coordinates and paint a white pixel
        int texX = (int)(normalizedPoint.x * drawingTexture.width);
        int texY = (int)(normalizedPoint.y * drawingTexture.height);
        if (texX >= 0 && texX < drawingTexture.width && texY >= 0 && texY < drawingTexture.height)
        {
            drawingTexture.SetPixel(texX, texY, Color.white);
            drawingTexture.Apply();
        }
    }

    private IEnumerator CheckForSend()
    {
        // After a pause in drawing, send the texture to the classifier and reset
        yield return new WaitForSeconds(delayToSend);
        classifier.ExecuteModel(drawingTexture);
        hasDrawn = false;
        checkForSendCoroutine = null;
    }
}
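One practical caveat: on a 28x28 canvas, a single-pixel stroke is much thinner than the pen strokes MNIST was trained on, which can hurt recognition. Below is a sketch of a thicker-brush variant of AddPixels (the name AddPixelsThick and the brushRadius parameter are my own, not from the article) that paints a small filled disc around the hit point instead:

// Hypothetical thicker-brush variant: paints a filled disc of radius
// brushRadius (in pixels) around the touched texel instead of a single pixel.
private void AddPixelsThick(Vector2 normalizedPoint, int brushRadius = 1)
{
    int cx = (int)(normalizedPoint.x * drawingTexture.width);
    int cy = (int)(normalizedPoint.y * drawingTexture.height);
    for (int x = cx - brushRadius; x <= cx + brushRadius; x++)
    {
        for (int y = cy - brushRadius; y <= cy + brushRadius; y++)
        {
            bool inBounds = x >= 0 && x < drawingTexture.width && y >= 0 && y < drawingTexture.height;
            bool inBrush = (x - cx) * (x - cx) + (y - cy) * (y - cy) <= brushRadius * brushRadius;
            if (inBounds && inBrush)
                drawingTexture.SetPixel(x, y, Color.white);
        }
    }
    drawingTexture.Apply();
}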

The updated ClassifyHandwrittenDigit script:

using UnityEngine;
using Unity.Sentis;
using Unity.Sentis.Layers;

public class ClassifyHandwrittenDigit : MonoBehaviour
{
    public ModelAsset modelAsset;
    Model runtimeModel;
    IWorker worker;
    public float[] results;
    private TensorFloat inputTensor;
    [SerializeField] private FingerDrawing fingerDrawing;

    void Start()
    {
        // Load the model and append a softmax layer, as before
        runtimeModel = ModelLoader.Load(modelAsset);
        string softmaxOutputName = "Softmax_Output";
        runtimeModel.AddLayer(new Softmax(softmaxOutputName, runtimeModel.outputs[0]));
        runtimeModel.outputs[0] = softmaxOutputName;
        worker = WorkerFactory.CreateWorker(BackendType.GPUCompute, runtimeModel);
    }

    public void ExecuteModel(Texture2D inputTexture)
    {
        // Dispose of the previous input tensor before creating a new one
        inputTensor?.Dispose();
        inputTensor = TextureConverter.ToTensor(inputTexture, width: 28, height: 28, channels: 1);
        worker.Execute(inputTensor);

        // Read the probabilities back from the GPU
        TensorFloat outputTensor = worker.PeekOutput() as TensorFloat;
        outputTensor.MakeReadable();
        results = outputTensor.ToReadOnlyArray();
        outputTensor.Dispose();

        // Wipe the canvas so the user can draw the next digit
        fingerDrawing.ClearTexture();
    }

    void OnDisable()
    {
        inputTensor?.Dispose();
        worker.Dispose();
    }
}
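The scripts above hard-code BackendType.GPUCompute, which works on Quest-class standalone headsets. If you want the same code to run on platforms without compute shader support, a defensive variant of the worker creation (my sketch, not from the article) can fall back to the CPU backend:

// Pick the GPU compute backend when the platform supports compute shaders,
// otherwise fall back to the slower but universal CPU backend.
BackendType backend = SystemInfo.supportsComputeShaders
    ? BackendType.GPUCompute
    : BackendType.CPU;
worker = WorkerFactory.CreateWorker(backend, runtimeModel);

This is a drop-in replacement for the CreateWorker line in Start().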

Tips for Successful Implementation

  1. Optimize Your Model: Before importing your model into Unity, ensure it’s optimized for the platform. Unity Sentis does some optimization, but further tweaks might be needed depending on your specific use case.
  2. Test With Different Inputs: While developing, test your model with a variety of inputs to ensure it performs well across different scenarios.
  3. Utilize Unity Samples: The sample projects provided with Unity Sentis are a goldmine of information and can significantly speed up your learning process.
  4. Experiment With Different Models: Don’t limit yourself to the MNIST model. Unity Sentis opens up a world of possibilities, so try different models to see what best fits your project.

Conclusion

Integrating AI into Unity projects using Unity Sentis is not only possible but also remarkably straightforward with the right approach. The process of creating a digit recognition application provided invaluable insights into the power of AI in game development and interactive applications. If you’re looking to enhance your Unity projects with AI, I highly recommend giving Unity Sentis a try.

Let me know your thoughts and what you’d like to see next. See you in the next tutorial!
