In-Browser Inference with a Customized TensorFlow Lite Model and MediaPipe

Rubens Zimbres
Google Cloud - Community
14 min read · Jul 24, 2024

For me, developing this tutorial was not trivial, because MediaPipe only supports some models and is a new library. Google AI Edge was announced at Google I/O 2024, and MediaPipe is in preview (https://ai.google.dev/edge#announcements), so online resources are extremely scarce.

Here, I used a customized ResNet50 in PyTorch for in-browser inference, converting it to a TensorFlow Lite (.tflite) model with Google’s ai-edge-torch library. As the library is quite new, it still has many open issues. You won’t have to go through all this struggle, because I am going to give you only what works. But there is A LOT of code, and A LOT of troubleshooting.

In-browser inference with TensorFlow Lite refers to executing machine learning models directly within a web browser using a .tflite model. TensorFlow Lite is a lightweight version of TensorFlow designed for mobile and embedded devices. This approach leverages WebAssembly (Wasm) and JavaScript APIs to run the models efficiently in the browser. As the browser downloads the model from a public repository (here, a Cloud Storage bucket), I also quantized it to get a smaller model and speed up inference.

Key Aspects of In-Browser Inference with TensorFlow Lite

  • TensorFlow Lite for Web: TensorFlow Lite provides support for running models on the web through the TensorFlow.js library. It allows models trained in TensorFlow to be converted to TensorFlow Lite format and then deployed to web applications.
  • WebAssembly (Wasm): TensorFlow Lite uses WebAssembly, a binary instruction format that enables high-performance execution of code on web pages. Wasm allows TensorFlow Lite models to run with near-native performance within the browser.
  • JavaScript APIs: TensorFlow Lite leverages JavaScript APIs to integrate with web applications. Developers can load models, preprocess data, perform inference, and post-process results using JavaScript code.

Advantages of In-Browser Inference with TensorFlow Lite

  • Low Latency: By running inference directly in the browser, there’s no need to send data to a remote server and wait for a response, which significantly reduces latency and cost. This is crucial for real-time applications like interactive web apps, games, and augmented reality.
  • Enhanced Privacy: Since the data is processed locally on the user’s device, it enhances privacy and security by avoiding the need to transmit sensitive and private data over the internet.
  • Offline Capabilities: In-browser inference enables applications to function offline or in environments with limited connectivity. The models and necessary resources can be pre-loaded, allowing the app to operate without an active internet connection.
  • Cross-Platform Compatibility: Web applications inherently run across different platforms (desktop, mobile, tablets) and operating systems (Windows, macOS, Linux, Android, iOS) without requiring platform-specific code changes.
  • Ease of Deployment: Deploying machine learning models in a web application can be simpler than in mobile or desktop applications: there is no need for containers, Flask or OpenAPI. Users get the latest model updates automatically by visiting the web page, eliminating the need for manual app updates.
  • Resource Utilization: Modern web browsers are highly optimized and can efficiently manage resources, allowing even complex models to run smoothly. WebAssembly further enhances performance by providing a low-level execution environment.

Use Cases

  • Interactive Demos and Educational Tools: Running ML models in-browser can create interactive demos and educational tools that are easily accessible and provide instant feedback.
  • Client-Side Text, Image and Video Processing: Applications that involve processing text, images or videos, such as LLMs, object detection, face recognition, or style transfer, can benefit from in-browser inference for real-time performance.
  • Enhanced User Experience: Features like speech recognition, language translation, and personalized content recommendations can be integrated into web apps to improve user experience.

As I said, it was a rough path to make the unsupported model work. Here are some of the obstacles I found:

  1. The torchvision ResNet50 model receives input in the format (batch, channels, height, width), e.g. (1,3,224,224). However, MediaPipe expects the input in the format (batch, height, width, channels), e.g. (1,224,224,3).
  2. If you simply add a layer on top of the converted model to adapt the input, it generates a second TensorFlow graph, and MediaPipe does not accept more than one graph.
  3. MediaPipe accepts float32 and int8, but the converted PyTorch model does not work with int8. If you use float32, you have to normalize the neural network’s output.
  4. Do not add another signature to the .tflite model: MediaPipe does not accept two signatures. Keep only the default serving signature (serving_default).
  5. MediaPipe has model compatibility requirements:
  • image input of size [batch x height x width x channels]. Here I added a permute layer to the PyTorch model.
  • batch inference is not supported (batch is required to be 1).
  • only RGB inputs are supported (channels is required to be 3).
  • if the type is kTfLiteFloat32, NormalizationOptions must be attached to the metadata for input normalization. Here I also normalized the PyTorch model output. The kTfLite prefix indicates that the constant belongs to TensorFlow Lite; Float32 refers to the data type used to represent the numbers.
  • Output tensor with N classes and either 2 or 4 dimensions, i.e. [1 x N] or [1 x 1 x 1 x N]. Here I’m working with [1,1000].
  • Recommended: a label map as an AssociatedFile with type TENSOR_AXIS_LABELS, containing one label per line. See the example label file. The example label file has 1001 labels, while our model has a [1,1000] output, so this must be adjusted.

  6. When adding the metadata, note that MediaPipe’s supported image classification models work with 1001 classes, but the PyTorch ResNet50 model works with 1000 classes. This must be fixed both in resnet_labels.txt and in the Python notebook that adds the metadata.
  7. The MediaPipe CodePen splits the demo into separate HTML, CSS and JS files and does not explicitly tell you how to combine them into a single HTML file, which is what I wanted. For many of you this must be trivial, but not for me.
  8. When you are working on the HTML file in VSCode, all you get is HTML debugger info. If there is a problem with the .tflite graph, the signature or anything else, the VSCode HTML debugger is useless. I found out that I had to upload the .tflite model to MediaPipe Studio to effectively debug the TensorFlow error.
  9. The ai-edge-torch library runs in a TensorFlow 2.17.0 environment. However, the Python notebook for metadata creation only runs successfully in a TensorFlow 2.13.0 environment. So you need to set up both of these Anaconda environments.
  10. After the .tflite model was validated and working, the HTML page still didn’t load the model from the bucket. I had to set up a CORS (Cross-Origin Resource Sharing) policy on the bucket, plus other bucket settings.
  11. When everything finally worked, the inference time of the customized model was less than 2 seconds, while the inference time of the supported models is in the order of milliseconds.
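Several of these obstacles (items 1, 5 and 6) boil down to shape and label-count rules. As a rough pre-flight check before uploading to MediaPipe Studio, they can be encoded in a small Python function. This is a minimal sketch: `check_mediapipe_classifier` is a hypothetical helper written for illustration, not part of MediaPipe, and it only covers the rules listed above.

```python
def check_mediapipe_classifier(input_shape, output_shape, num_labels):
    """Return a list of problems; an empty list means the shapes and the
    label count satisfy the rules listed above (hypothetical helper)."""
    problems = []
    in_shape, out_shape = tuple(input_shape), tuple(output_shape)

    # Input must be [batch x height x width x channels], batch 1, RGB.
    if len(in_shape) != 4:
        problems.append(f"input must be 4-D (NHWC), got {in_shape}")
    else:
        if in_shape[0] != 1:
            problems.append(f"batch must be 1, got {in_shape[0]}")
        if in_shape[3] != 3:
            problems.append(f"channels must be 3 (RGB), got {in_shape[3]}")

    # Output must be [1 x N] or [1 x 1 x 1 x N].
    if out_shape[:-1] not in ((1,), (1, 1, 1)):
        problems.append(f"output must be [1, N] or [1, 1, 1, N], got {out_shape}")

    # The label file needs exactly one line per output class.
    if out_shape and num_labels != out_shape[-1]:
        problems.append(
            f"label file has {num_labels} labels, model outputs {out_shape[-1]} classes"
        )
    return problems


print(check_mediapipe_classifier((1, 224, 224, 3), (1, 1000), 1000))  # []
print(check_mediapipe_classifier((1, 3, 224, 224), (1, 1000), 1001))  # two problems
```

With the tf.lite.Interpreter used later in this tutorial, the shapes could come from `input_details[0]['shape']` and `output_details[0]['shape']`, and the label count from the number of lines in resnet_labels.txt.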

THE .tflite MODEL AND CODE

Creating a .tflite model from a PyTorch model is quite simple with the ai-edge-torch library. First, we create an Anaconda environment with Python 3.10 and install the necessary libraries:

conda create -n TF27 python=3.10
conda activate TF27
pip install ai-edge-torch tensorflow==2.17.0

Then, go to https://pytorch.org/get-started/locally/ and install PyTorch in your Anaconda environment. For Ubuntu 22.04 with a GPU (CUDA 11.8), the command is:

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
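Because the conversion and the metadata steps need different TensorFlow versions (obstacle 9), it is easy to run a script in the wrong environment. A small guard at the top of each script can catch this early. This is an optional sketch; `require_version` and `version_matches` are hypothetical helpers built on the standard `importlib.metadata` module.

```python
from importlib.metadata import PackageNotFoundError, version


def version_matches(installed, expected_prefix):
    """True if `installed` starts with the release components of
    `expected_prefix`, e.g. "2.17.0" matches "2.17" but not "2.1"."""
    wanted = expected_prefix.split(".")
    return installed.split(".")[: len(wanted)] == wanted


def require_version(package, expected_prefix):
    """Raise RuntimeError unless `package` is installed with a matching version."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        raise RuntimeError(f"{package} is not installed in this environment") from None
    if not version_matches(installed, expected_prefix):
        raise RuntimeError(
            f"{package} {installed} found, but {expected_prefix}.x is required; "
            "activate the matching conda environment"
        )
    return installed


# Usage: at the top of the conversion script (TF27 environment)
#   require_version("tensorflow", "2.17")
# and at the top of the metadata script (TF13 environment)
#   require_version("tensorflow", "2.13")
```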

Now we import the libraries:

import torch
import torchvision
import ai_edge_torch
from PIL import Image
import torchvision.transforms as transforms
import tensorflow as tf
import numpy as np
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch._export import capture_pre_autograd_graph
from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig
import json
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

Next, we load the ResNet50 model in PyTorch with its pretrained ImageNet weights, in evaluation mode:

resnet50 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1).eval()

As I said, the PyTorch model expects input in the (1,3,224,224) layout, while MediaPipe provides (1,224,224,3). So we add a layer that permutes the input, plus a layer that normalizes the output:

class PermuteInput(nn.Module):
    def __init__(self):
        super(PermuteInput, self).__init__()

    def forward(self, x):
        # Permute from (batch, height, width, channels) to (batch, channels, height, width)
        return x.permute(0, 3, 1, 2)

class HandleOutput(nn.Module):
    def __init__(self):
        super(HandleOutput, self).__init__()

    def forward(self, x):
        # L2-normalize the 1000 logits of each sample (F.normalize defaults: p=2, dim=1)
        return F.normalize(x)

# Wrap the model in a Sequential container: PermuteInput runs before ResNet50
# (so the wrapped model accepts NHWC input) and HandleOutput runs after it
resnet50_with_reshape = nn.Sequential(
    PermuteInput(),
    resnet50,
    HandleOutput()
)

# Print the modified model architecture
print(resnet50_with_reshape)
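To make the two wrapper layers concrete without installing PyTorch, here is a pure-Python sketch on nested lists. `permute_nhwc_to_nchw` reproduces the index mapping of `x.permute(0, 3, 1, 2)` and `l2_normalize` mirrors `F.normalize` applied to one row; both function names are made up for illustration.

```python
import math


def permute_nhwc_to_nchw(x):
    """Reorder a nested list from [N][H][W][C] to [N][C][H][W]:
    out[b][k][i][j] == x[b][i][j][k], like x.permute(0, 3, 1, 2)."""
    n, h, w, c = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    return [[[[x[b][i][j][k] for j in range(w)] for i in range(h)]
             for k in range(c)] for b in range(n)]


def l2_normalize(row, eps=1e-12):
    """Divide a vector by its Euclidean norm (clamped to eps), like
    F.normalize(x) with its defaults p=2, dim=1, eps=1e-12."""
    norm = max(math.sqrt(sum(v * v for v in row)), eps)
    return [v / norm for v in row]


# A 1 x 2 x 2 x 3 "image": each pixel is [R, G, B].
image = [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]
chw = permute_nhwc_to_nchw(image)
print(chw[0][0])  # red channel as a 2 x 2 plane: [[1, 4], [7, 10]]
print(l2_normalize([3.0, 4.0]))  # [0.6, 0.8]
```

Because the L2 norm is a positive number, dividing by it does not change the order of the scores, so the predicted class is the same as with the raw logits. The values, however, now lie in [-1, 1].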

The ai-edge-torch library requires the model to be in evaluation mode (model.eval()) to successfully convert it into a .tflite model. Then we test the output:

edge_model = resnet50_with_reshape.eval()
sample_input = (torch.rand((1, 224, 224, 3), dtype=torch.float32),)
edge_model(*sample_input)

And finally convert to .tflite:

edge_model = ai_edge_torch.convert(edge_model.eval(), sample_input)
edge_model.export("/home/user/resnet50.tflite")

Now we quantize the model, drastically reducing its size so that it loads faster in our HTML page:

pt2e_quantizer = PT2EQuantizer().set_global(
get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True))

pt2e_torch_model = capture_pre_autograd_graph(resnet50_with_reshape.eval(),sample_input)
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)

# Run the prepared model with sample input data to ensure that internal observers are populated with correct values
pt2e_torch_model(*sample_input)

# Convert the prepared model to a quantized model
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

# Convert to an ai_edge_torch model
pt2e_drq_model = ai_edge_torch.convert(pt2e_torch_model, sample_input, quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer))
pt2e_drq_model.export("/home/user/resnet50_quantized.tflite")
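The quantizer above is configured with `is_per_channel=True`: each output channel of a weight tensor gets its own scale. The following sketch shows the idea of symmetric per-channel int8 quantization in plain Python. It is a simplified illustration, not the exact algorithm used by `PT2EQuantizer` or the TFLite converter, and the function names are hypothetical.

```python
def quantize_per_channel(weights, num_bits=8):
    """weights: list of output channels, each a list of floats.
    Returns (int_values, scales) with one symmetric scale per channel."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8, symmetric range [-127, 127]
    q_weights, scales = [], []
    for channel in weights:
        max_abs = max(abs(v) for v in channel)
        scale = max_abs / qmax if max_abs > 0 else 1.0  # avoid division by zero
        q = [max(-qmax, min(qmax, round(v / scale))) for v in channel]
        q_weights.append(q)
        scales.append(scale)
    return q_weights, scales


def dequantize_per_channel(q_weights, scales):
    """Recover approximate float weights: value = integer * channel scale."""
    return [[q * s for q in channel] for channel, s in zip(q_weights, scales)]


# The second channel has much smaller values, so it gets a much smaller scale.
w = [[0.5, -1.0, 0.25], [0.01, 0.02, -0.03]]
q, s = quantize_per_channel(w)
print(q)
print(dequantize_per_channel(q, s))
```

With a single scale for the whole tensor, the small values of the second channel would be squeezed into only a few integer levels; per-channel scales avoid that. Storing a weight as int8 takes 1 byte instead of 4 for float32, so for ResNet50 (about 25.6 million parameters) the weights shrink from roughly 100 MB to roughly 25 MB, before overhead such as the scales.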

Once you have the quantized model, you can upload it to https://netron.app/ to inspect its specs and architecture.

If you want to check the output of classification, use this code locally:

########################################## INFERENCE ############################################

# Load the TFLite model and allocate tensors
interpreter = tf.lite.Interpreter(model_path="/home/ai-edge/resnet50.tflite")
interpreter.allocate_tensors()

# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Get input shape
input_shape = input_details[0]['shape']

# Load and preprocess the image
image = Image.open('/home/ai-edge/car.jpeg')

preprocess = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(), # Converts to float32 and scales to [0, 1]
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

img_tensor = preprocess(image) # img_tensor is now a FloatTensor with shape (3, 224, 224)
img_tensor = img_tensor.unsqueeze(0).permute(0,2,3,1) # Shape: (1, 224, 224, 3), the NHWC layout the model now expects

# Convert img_tensor to a numpy array with the dtype the interpreter expects (float32)
img_numpy = img_tensor.numpy()
img_numpy = img_numpy.astype(np.float32)

# Set the tensor to the interpreter
interpreter.set_tensor(input_details[0]['index'], img_numpy)

# Run inference
interpreter.invoke()

# Get the output tensor
output_data = interpreter.get_tensor(output_details[0]['index'])

# Print the output
print("Output:", output_data)

# Apply softmax to the output data
probabilities = tf.nn.softmax(output_data[0])
# Find the index of the highest probability
predicted_class = np.argmax(probabilities)

# Load class labels - get this from Github
with open('/home/imagenet_class_index.json') as f:
    class_idx = json.load(f)

# Now `class_idx` is a dictionary of class names
print(class_idx[str(predicted_class)][1])
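The same imagenet_class_index.json can be used to generate the resnet_labels.txt file needed for the metadata (obstacles 5 and 6): one label per line, 1000 lines, where line i is the name of class i. This sketch assumes the usual layout of that JSON file ({"0": ["n01440764", "tench"], ...}); `write_label_file` is a hypothetical helper.

```python
import json


def write_label_file(class_index_path, labels_path, expected=1000):
    """Write one label per line, in class-index order, and return the count."""
    with open(class_index_path) as f:
        class_idx = json.load(f)  # {"0": ["n01440764", "tench"], ...}
    # Read keys "0", "1", ... in numeric order so that line i matches output index i.
    labels = [class_idx[str(i)][1] for i in range(len(class_idx))]
    if len(labels) != expected:
        raise ValueError(f"expected {expected} labels, got {len(labels)}")
    with open(labels_path, "w") as f:
        f.write("\n".join(labels) + "\n")
    return len(labels)


# Usage (paths as in this tutorial):
# write_label_file('/home/imagenet_class_index.json', 'resnet_labels.txt')
```

Keeping the check on the number of labels makes a mismatch such as 1001 versus 1000 fail loudly here instead of inside MediaPipe Studio.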

If you try to load this model on MediaPipe Studio, you will get an error, as it does not contain metadata yet.

METADATA CREATION FOR .tflite MODEL

Let’s add metadata. Create and activate an Anaconda environment with TensorFlow 2.13.0, then run the Python code below:

conda create -n TF13 python=3.10
conda activate TF13
pip install tensorflow==2.13.0 tflite-support

from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb
import os

model_file="resnet50_quantized.tflite"
os.chdir("/home/user/model_with_metadata")

"""Creates the metadata for an image classifier."""

# Creates model info.
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "Resnet50 image classifier"
model_meta.description = ("Identify the most prominent object in the "
"image from a set of 1,000 categories such as "
"trees, animals, food, vehicles, person etc.")
model_meta.version = "v1"
model_meta.author = "Rubens Zimbres"
model_meta.license = ("Apache License. Version 2.0 "
"http://www.apache.org/licenses/LICENSE-2.0.")

# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()

input_meta.name = "Image"
input_meta.description = (
"Input image to be classified. The expected image is {0} x {1}, with "
"three channels (red, green, and blue) per pixel. Each value in the "
"tensor is a single byte between 0 and 255.".format(224, 224))
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
_metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
_metadata_fb.ContentProperties.ImageProperties)
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
_metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
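input_stats = _metadata_fb.StatsT()
input_stats.max = [1]
input_stats.min = [0]
# Note: StatsT only serializes max and min; the three attributes below are not
# part of the metadata schema and are ignored when the metadata is packed
input_stats.width = [224]
input_stats.height = [224]
input_stats.num_classes = [1000]
input_meta.stats = input_stats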

# Creates output info.
output_meta = _metadata_fb.TensorMetadataT()
output_meta.name = "probability"
output_meta.description = "Probabilities of the 1000 labels respectively."
output_meta.content = _metadata_fb.ContentT()
output_meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()
output_meta.content.contentPropertiesType = (
_metadata_fb.ContentProperties.FeatureProperties)
output_stats = _metadata_fb.StatsT()
output_stats.max = [1.0]
output_stats.min = [0.0]
output_meta.stats = output_stats
label_file = _metadata_fb.AssociatedFileT()
label_file.name = os.path.basename("resnet_labels.txt")
label_file.description = "Labels for objects that the model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
output_meta.associatedFiles = [label_file]

# Creates subgraph info.
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [output_meta]
model_meta.subgraphMetadata = [subgraph]

b = flatbuffers.Builder(0)
b.Finish(
model_meta.Pack(b),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()

populator = _metadata.MetadataPopulator.with_model_file(model_file)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(["resnet_labels.txt"])
populator.populate()

## OPTIONAL: check metadata in JSON file

displayer = _metadata.MetadataDisplayer.with_model_file('/home/user/resnet50_quantized.tflite')
export_json_file = "/home/user/metadata.json"
json_file = displayer.get_metadata_json()
# Optional: write out the metadata as a json file
with open(export_json_file, "w") as f:
    f.write(json_file)
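The NormalizationOptions in the metadata tell MediaPipe to map each input value x to (x - mean) / std before feeding the model. With mean = std = 127.5, pixel values in [0, 255] become [-1, 1]. The torchvision preprocessing used in the local inference check is different: ToTensor() divides by 255 and Normalize() then applies the ImageNet mean and std per channel. The plain-Python sketch below (the function name is hypothetical) shows both, plus the per-channel values on the 0 to 255 scale that would correspond to the torchvision preprocessing.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # values from transforms.Normalize above
IMAGENET_STD = [0.229, 0.224, 0.225]


def apply_normalization(pixel_rgb, mean, std):
    """Map each channel value x to (x - mean) / std; a single mean or std
    value is used for all three channels."""
    mean = mean * 3 if len(mean) == 1 else mean
    std = std * 3 if len(std) == 1 else std
    return [(x - m) / s for x, m, s in zip(pixel_rgb, mean, std)]


# Metadata used above: mean = std = 127.5, so [0, 255] maps to [-1, 1].
print(apply_normalization([0, 127.5, 255], [127.5], [127.5]))  # [-1.0, 0.0, 1.0]

# torchvision computes (x / 255 - m) / s = (x - 255 * m) / (255 * s), so the
# equivalent values on the 0-255 scale are:
imagenet_mean_255 = [255 * m for m in IMAGENET_MEAN]  # about 123.7, 116.3, 103.5
imagenet_std_255 = [255 * s for s in IMAGENET_STD]    # about 58.4, 57.1, 57.4
```

The metadata schema allows either one value or one value per channel for mean and std, so these per-channel values are one option to experiment with if browser and local predictions disagree; they have not been tested in this tutorial.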

If you write out the JSON file, you can inspect all the metadata defined above.

The metadata writer notebook rewrites the resnet50_quantized.tflite file, adding resnet_labels.txt and the other metadata. You can check this by running:

unzip resnet50_quantized.tflite

Archive: /home/user/resnet50_quantized.tflite
extracting: resnet_labels.txt
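Since the metadata populator packs the associated files into the model as a zip archive, the same check can be scripted with Python’s standard zipfile module, which also reads archives that have other data (here, the model itself) in front of them. A minimal sketch; `list_associated_files` is a hypothetical helper.

```python
import zipfile


def list_associated_files(tflite_path):
    """Return the names of the files packed into a .tflite model,
    or an empty list if the model has no zipped associated files."""
    if not zipfile.is_zipfile(tflite_path):
        return []
    with zipfile.ZipFile(tflite_path) as archive:
        return archive.namelist()


# Usage:
# print(list_associated_files("/home/user/resnet50_quantized.tflite"))
```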

Now we can upload the model to MediaPipe Studio. Choose “CPU” and upload any image you want. I set the maximum number of results to 7.

The customized model predicted “Egyptian cat” (the best choice), just like EfficientNet, Google’s default model.
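The “max results” setting (7 in MediaPipe Studio above, and `maxResults: 1` in the HTML page below) keeps only the highest-scoring categories. As an illustration of the idea, not of MediaPipe’s internal code, a top-k selection over the model scores can be sketched in Python as follows; `top_k` and its optional `score_threshold` parameter are hypothetical.

```python
import heapq


def top_k(scores, labels, k=7, score_threshold=None):
    """Return up to k (label, score) pairs, highest score first."""
    candidates = range(len(scores))
    if score_threshold is not None:
        # Drop categories scoring below the threshold before ranking.
        candidates = [i for i in candidates if scores[i] >= score_threshold]
    best = heapq.nlargest(k, candidates, key=scores.__getitem__)
    return [(labels[i], scores[i]) for i in best]


labels = ["tench", "goldfish", "shark", "cat"]
scores = [0.1, 0.4, 0.3, 0.2]
print(top_k(scores, labels, k=2))  # [('goldfish', 0.4), ('shark', 0.3)]
```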

FINAL TASK: THE CHALLENGE

At first glance, this last step is quite simple, as MediaPipe provides the HTML, CSS and JS files at https://codepen.io/mediapipe-preview/pen/BaVZejK. However, it took me some time to make it work because of the bucket setup.

Use the following code:

<!DOCTYPE html>
<html lang="en">
<head>
<style>

body {
font-family: roboto;
margin: 2em;
color: #3d3d3d;
--mdc-theme-primary: #007f8b;
--mdc-theme-on-primary: #f1f3f4;
}

h1 {
color: #007f8b;
}

h2 {
clear: both;
}

video {
clear: both;
display: block;
}

section {
opacity: 1;
transition: opacity 500ms ease-in-out;
}

.mdc-button.mdc-button--raised.removed {
display: none;
}

.removed {
display: none;
}

.invisible {
opacity: 0.2;
}

.videoView,
.classifyOnClick {
position: relative;
float: left;
width: 48%;
margin: 2% 1%;
cursor: pointer;
}

.videoView p,
.classifyOnClick p {
padding: 5px;
background-color: #007f8b;
color: #fff;
z-index: 2;
margin: 0;
}

.highlighter {
background: rgba(0, 255, 0, 0.25);
border: 1px dashed #fff;
z-index: 1;
position: absolute;
}

.classifyOnClick {
z-index: 0;
font-size: calc(8px + 1.2vw);
}

.classifyOnClick img {
width: 100%;
}

.webcamPredictions {
padding-top: 5px;
padding-bottom: 5px;
background-color: #007f8b;
color: #fff;
border: 1px dashed rgba(255, 255, 255, 0.7);
z-index: 2;
margin: 0;
width: 100%;
font-size: calc(8px + 1.2vw);
}


</style>
</head>
<body>

<link href="https://unpkg.com/material-components-web@latest/dist/material-components-web.min.css" rel="stylesheet">


<h1>Classifying images using the MediaPipe Image Classifier Task</h1>

<section id="demos" class="invisible">
<h2>Demo: Classify Images</h2>
<p><b>Click on an image below</b> to see its classification.</p>
<div class="classifyOnClick">
<img src="https://assets.codepen.io/9177687/dog_flickr_publicdomain.jpeg" width="100%" crossorigin="anonymous" title="Click to get classification!" />
<p class="classification removed">
</p>
</div>
<div class="classifyOnClick">
<img src="https://assets.codepen.io/9177687/cat_flickr_publicdomain.jpeg" width="100%" crossorigin="anonymous" title="Click to get classification!" />
<p class="classification removed">
</p>
</div>

<h2>Demo: Webcam continuous classification</h2>
<p>Hold some objects up close to your webcam to get real-time classification. For best results, avoid having too many objects visible to the camera.<br>Click <b>enable webcam</b> below and grant access to the webcam if prompted.</p>

<div class="webcam">
<button id="webcamButton" class="mdc-button mdc-button--raised">
<span class="mdc-button__ripple"></span>
<span class="mdc-button__label">ENABLE WEBCAM</span>
</button>
<video id="webcam" autoplay playsinline></video>
<p id="webcamPredictions" class="webcamPredictions removed"></p>
</div>
</section>

<script src="https://unpkg.com/material-components-web@latest/dist/material-components-web.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"> </script>
<script type="module">

import {
ImageClassifier,
FilesetResolver
} from "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision";

// import { mediapipetasksVision }from "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@0.10.14/+esm"
// Get DOM elements
const video = document.getElementById("webcam");
const webcamPredictions = document.getElementById("webcamPredictions");
const demosSection = document.getElementById("demos") ;
let enableWebcamButton;
let webcamRunning = false;
const videoHeight = "360px";
const videoWidth = "480px";

const imageContainers = document.getElementsByClassName(
"classifyOnClick"
);
let runningMode = "IMAGE";

// Add click event listeners for the img elements.
for (let i = 0; i < imageContainers.length; i++) {
imageContainers[i].children[0].addEventListener("click", handleClick);
}

// Track imageClassifier object and load status.
let imageClassifier;

/**
* Create an ImageClassifier from the given options.
* You can replace the model with a custom one.
*/
const createImageClassifier = async () => {
const vision = await FilesetResolver.forVisionTasks(
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm"
);
imageClassifier = await ImageClassifier.createFromOptions(vision, {
baseOptions: {
modelAssetPath: `https://storage.googleapis.com/xxxxxxxxxxxxx/resnet50_quantized.tflite`
},
maxResults: 1,
runningMode: runningMode
});

// Show demo section now model is ready to use.
demosSection.classList.remove("invisible");
};
createImageClassifier();

/**
* Demo 1: Classify images on click and display results.
*/
async function handleClick(event) {
// Do not classify if imageClassifier hasn't loaded
if (imageClassifier === undefined) {
return;
}
// if video mode is initialized, set runningMode to image
if (runningMode === "VIDEO") {
runningMode = "IMAGE";
await imageClassifier.setOptions({ runningMode: "IMAGE" });
}

// imageClassifier.classify() returns the classification result directly (not a promise).
// Use the result to print out the prediction.
const classificationResult = imageClassifier.classify(event.target);
// Write the predictions to a new paragraph element and add it to the DOM.
const classifications = classificationResult.classifications;

const p = event.target.parentNode.childNodes[3];
p.className = "classification";
p.innerText =
"Classification: " +
classifications[0].categories[0].categoryName +
"\n Confidence: " +
Math.round(parseFloat(classifications[0].categories[0].score) * 100) +
"%";
}

/********************************************************************
// Demo 2: Continuously grab image from webcam stream and classify it.
********************************************************************/

// Check if webcam access is supported.
function hasGetUserMedia() {
return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

// Get classification from the webcam
async function predictWebcam() {
// Do not classify if imageClassifier hasn't loaded
if (imageClassifier === undefined) {
return;
}
// if image mode is initialized, switch the classifier to VIDEO runningMode
if (runningMode === "IMAGE") {
runningMode = "VIDEO";
await imageClassifier.setOptions({ runningMode: "VIDEO" });
}
const startTimeMs = performance.now();
const classificationResult = imageClassifier.classifyForVideo(
video,
startTimeMs
);
video.style.height = videoHeight;
video.style.width = videoWidth;
webcamPredictions.style.width = videoWidth;
const classifications = classificationResult.classifications;
webcamPredictions.className = "webcamPredictions";
webcamPredictions.innerText =
"Classification: " +
classifications[0].categories[0].categoryName +
"\n Confidence: " +
Math.round(parseFloat(classifications[0].categories[0].score) * 100) +
"%";
// Call this function again to keep predicting when the browser is ready.
if (webcamRunning === true) {
window.requestAnimationFrame(predictWebcam);
}
}

// Enable the live webcam view and start classification.
async function enableCam(event) {
if (imageClassifier === undefined) {
return;
}

if (webcamRunning === true) {
webcamRunning = false;
enableWebcamButton.innerText = "ENABLE PREDICTIONS";
} else {
webcamRunning = true;
enableWebcamButton.innerText = "DISABLE PREDICTIONS";
}

// getUsermedia parameters.
const constraints = {
video: true
};

// Activate the webcam stream.
video.srcObject = await navigator.mediaDevices.getUserMedia(constraints);
video.addEventListener("loadeddata", predictWebcam);
}

// If webcam supported, add event listener to button.
if (hasGetUserMedia()) {
enableWebcamButton = document.getElementById("webcamButton");
enableWebcamButton.addEventListener("click", enableCam);
} else {
console.warn("getUserMedia() is not supported by your browser");
}

</script>

</body>
</html>

To use our customized model, all we have to do is replace the modelAssetPath value with the URL of the customized model uploaded to a Google Cloud Storage bucket. Use a Regional bucket, DO NOT enforce public access prevention on it, choose Fine-grained access control, and give AllUsers read access (Viewer/Read) to the bucket and to the object:

Bucket:

[Figure: Bucket permissions configuration]

Object:

[Figure: Object permissions configuration]

modelAssetPath: `https://storage.googleapis.com/xxxxxx/resnet50_quantized.tflite`
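Public objects in Cloud Storage are served at https://storage.googleapis.com/BUCKET_NAME/OBJECT_NAME, which is the form of the modelAssetPath above. A small helper (hypothetical, for illustration; "my-bucket" is a placeholder) builds that URL and escapes characters such as spaces in the object name:

```python
from urllib.parse import quote


def public_gcs_url(bucket, object_name):
    """Build the public URL of a Cloud Storage object. "/" is kept unescaped
    so that "folder/model.tflite" stays a path."""
    return f"https://storage.googleapis.com/{bucket}/{quote(object_name, safe='/')}"


print(public_gcs_url("my-bucket", "resnet50_quantized.tflite"))
```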

However, following these steps was not enough: I got a CORS (Cross-Origin Resource Sharing) error in the Chrome console, locally, which prevented the HTML page from fetching the model. As a result, WebGL could not open and run the graph.

Thankfully, a good soul on Stack Overflow, Noam Gaash, had a solution: enable cross-origin support on the Google Cloud Storage bucket. First, create a file cors_file.json:

[
{
"origin": ["https://your-website.com"],
"method": ["GET", "POST"],
"responseHeader": ["Content-Type", "Authorization"],
"maxAgeSeconds": 86400
}
]

Then you run:

gcloud storage buckets update gs://your-bucket-with-.tflite --cors-file=cors_file.json
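If you prefer to keep the CORS policy in Python, the same cors_file.json can be written with the json module before running the gcloud command. The sketch also includes `is_allowed`, a deliberately simplified check (exact origin match or "*", plus the method list) that helps spot the most common mistake: the page’s origin not being listed. It is not Cloud Storage’s actual CORS evaluation, both function names are hypothetical, and "https://your-website.com" remains a placeholder.

```python
import json

cors_config = [
    {
        "origin": ["https://your-website.com"],  # placeholder, as above
        "method": ["GET", "POST"],
        "responseHeader": ["Content-Type", "Authorization"],
        "maxAgeSeconds": 86400,  # browsers may cache the preflight for 24 hours
    }
]


def write_cors_file(path, config=cors_config):
    """Write the CORS rules in the JSON format expected by --cors-file."""
    with open(path, "w") as f:
        json.dump(config, f, indent=2)


def is_allowed(config, origin, method):
    """Simplified check: does any rule list this origin (or "*") and method?"""
    for rule in config:
        origins = rule.get("origin", [])
        if (origin in origins or "*" in origins) and method.upper() in rule.get("method", []):
            return True
    return False


write_cors_file("cors_file.json")
print(is_allowed(cors_config, "https://your-website.com", "GET"))
```

An origin includes the scheme, host and port, so http://localhost:8000 and https://your-website.com are different origins; every origin the page is served from needs to be listed.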

This allows the web page to download the .tflite file and run the classification in the browser.

[Figure: Browser’s console]

Acknowledgements

Google ML Developer Programs and the Google Cloud Champion Innovators Program supported this work by providing Google Cloud credits.

🔗 https://developers.google.com/machine-learning

🔗 https://cloud.google.com/innovators/champions?hl=en

Written by Rubens Zimbres

I’m a Senior Data Scientist and Google Developer Expert in ML and GCP. I love studying NLP algos and Cloud Infra. CompTIA Security +. PhD. www.rubenszimbres.phd
