Exploring using YOLOv8n with PyTorch Mobile (Android)

Stephen Cow Chau
12 min read · May 15, 2023

Background

Given my previous experience using YOLOv8 on Android by converting it through ONNX export => Tensorflow => Tensorflow Lite (plus the realization that I had to ditch the tflite support library's object detection API and use the plain interpreter), there are a few points that make me wonder whether directly using PyTorch Mobile would be a better option.

The points of consideration are:

  1. I see that a Tensorflow Lite model runs with XNNPACK on Android to speed up CPU inference, but GPU acceleration is not supported (because some operations in the YOLOv8 implementation are not compatible with the Tensorflow Lite GPU operation set; one can check this with tensorflow.lite.experimental.Analyzer.analyze(…) with gpu_compatibility=True, a short sketch follows this list). NNAPI also crashed in my previous work, so maybe some additional conversion needs to be done.
  2. The conversion process is like surgery (stitching in non-max suppression), and when the core model is updated, there may be small bits and pieces of that surgery code that need to be updated as well. Also, a lot of the time the ONNX export with the main stitching operation works perfectly, but the downstream conversion messes things up (e.g. the arrangement of the outputs, that is bounding boxes, classes, scores and number of detections, mismatches between the ONNX and Tensorflow/Lite models).
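
As a quick reference for point 1 above, here is a minimal sketch of that GPU compatibility check (the model file name is illustrative; the analyzer lives under tf.lite.experimental in recent Tensorflow releases):

import tensorflow as tf

# report which ops in the TFLite model are NOT supported by the GPU delegate
tf.lite.experimental.Analyzer.analyze(
    model_path="yolov8n_float32.tflite",  # illustrative path
    gpu_compatibility=True,
)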

Start of exploration

According to the PyTorch Mobile page, preparing a model is as simple as the workflow in the following diagram: script/trace the model, optimize it for mobile, and save it.
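
A minimal sketch of that workflow (the module and file name here are illustrative stand-ins):

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # stand-in for any torch.nn.Module
scripted = torch.jit.script(model)                     # or torch.jit.trace(model, example_input)
optimized = optimize_for_mobile(scripted)              # mobile-specific graph optimizations
optimized.save("model_mobile.pt")                      # load on Android with Module.load(...)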

Attempt 1 — Convert directly to TorchScript and run the model

The YOLOv8 implementation provides a handy export function; it’s as easy as the following:

model = YOLO("/path/to/model.pt", task="detect")
model.export(format="torchscript")
# the torchscript model would be at the same folder as your model.pt

It’s important to check the output of the model, so I ran:

torch_script_model = torch.jit.load(torchscript_model_path)
with torch.inference_mode():
    result = torch_script_model(input_tensor)

And the result is a tuple whose first element is a tensor with shape [1, 6, 8400]. The shape looks familiar: it is the pre-NMS (non-max suppression) tensor, where 6 = 4 box coordinates + 2 class scores (this is a two-class model) and 8400 is the number of candidate boxes.

Back to our YOLO models, when we run:

model = YOLO("/path/to/model.pt", task="detect")
result = model(input_tensor)

and the result is a nicely structured object like the following:

{ '_keys': ('boxes', 'masks', 'probs'),
'boxes': ultralytics.yolo.engine.results.Boxes
type: torch.Tensor
shape: torch.Size([9, 6])
dtype: torch.float32
tensor([[ 2.88526e+02, 1.18719e+02, 4.01752e+02, 2.11924e+02, 9.42204e-01, 0.00000e+00],
[ 2.98364e+02, 1.71492e+02, 3.33414e+02, 2.08354e+02, 8.26868e-01, 1.00000e+00],
[ 3.57114e+02, 1.59409e+02, 4.00666e+02, 2.02675e+02, 8.01369e-01, 1.00000e+00],
# ... (omitted)
]),
'masks': None,
'names': {0: 'container', 1: 'bottle'},
'orig_img': tensor([[[[0.13333, 0.12941, 0.11765, ..., 0.09804, 0.09020, 0.09020],
[0.12941, 0.12549, 0.11765, ..., 0.07451, 0.06275, 0.05490],
[0.12941, 0.12157, 0.11765, ..., 0.05882, 0.05098, 0.04314],
...,
# (omitted)
[0.00000, 0.00000, 0.00000, ..., 0.00000, 0.00000, 0.00000],
[0.00000, 0.00000, 0.00000, ..., 0.00000, 0.00000, 0.00000],
[0.00000, 0.00000, 0.00000, ..., 0.00000, 0.00000, 0.00000]]]]),
'orig_shape': torch.Size([1, 3]),
'path': None,
'probs': None,
'speed': {'inference': 94.71845626831055, 'postprocess': 0.8051395416259766, 'preprocess': 1.5766620635986328}}
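
For comparison with the raw TorchScript output, the detections above can also be read off the Results object directly. A sketch (the attribute names follow the ultralytics version I was using and may differ in other releases):

res = result[0]         # first (and only) image in the batch
print(res.boxes.xyxy)   # [N, 4] box corners
print(res.boxes.conf)   # [N] confidence scores
print(res.boxes.cls)    # [N] class indices
print(res.names)        # {0: 'container', 1: 'bottle'}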

The reason for the difference is that the YOLO class performs a bunch of work (post-processing) after the core model inference, while the exported TorchScript model does NOT include that post-processing, which includes the NMS.

If we run the following on our TorchScript model, we get back the NMS-processed result:

from ultralytics.yolo.utils.ops import non_max_suppression

# just as we did previously
torch_script_model = torch.jit.load(torchscript_model_path)
with torch.inference_mode():
    result = torch_script_model(input_tensor)

nms_result = non_max_suppression(result)
nms_result[0]

# result as follows:
# tensor([[ 2.88526e+02, 1.18719e+02, 4.01752e+02, 2.11924e+02, 9.42204e-01, 0.00000e+00],
# [ 2.98364e+02, 1.71492e+02, 3.33414e+02, 2.08354e+02, 8.26868e-01, 1.00000e+00],
# [ 3.57114e+02, 1.59409e+02, 4.00666e+02, 2.02675e+02, 8.01369e-01, 1.00000e+00],
# [ 3.70031e+02, 2.37801e+02, 4.13906e+02, 2.77529e+02, 7.05622e-01, 1.00000e+00],
# [ 3.22101e+02, 1.62586e+02, 4.58572e+02, 2.97309e+02, 6.52932e-01, 0.00000e+00],
# [ 4.20349e+02, 1.17764e+01, 5.71449e+02, 5.01500e+01, 6.01493e-01, 0.00000e+00],
# [ 1.76297e+02, -3.57257e+00, 2.83910e+02, 3.98838e+01, 3.84342e-01, 0.00000e+00],
# [ 5.12204e+02, 1.41641e+01, 5.40370e+02, 4.03212e+01, 3.52462e-01, 1.00000e+00]])
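
For reference, non_max_suppression also accepts confidence and IoU thresholds (the same values I end up hard-coding later); a sketch, assuming the signature of the ultralytics version I was using:

nms_result = non_max_suppression(result, conf_thres=0.25, iou_thres=0.45)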

Attempt 2 — How about scripting the YOLO object itself with TorchScript, instead of going through the export provided by the YOLO class?

model = YOLO("/path/to/model.pt", task="detect")
torch.jit.script(model)

we get a not-supported error…

It is pretty obvious that this approach would not work, as the YOLO class does not inherit from torch.nn.Module, so we cannot assume it can be traced/scripted with TorchScript like that (otherwise, why would the YOLO class provide an export function?).

Attempt 3 — Adding non_max_suppression function call into a wrapper model and then convert to TorchScript

We can wrap the core model inside the YOLO class into another torch.nn.Module like the following:

import torch
import torch.nn as nn
from ultralytics.yolo.utils.ops import non_max_suppression


class WrapperModel_wNMS(nn.Module):
    def __init__(self, model: YOLO):
        super().__init__()
        # model is the YOLO class object itself
        self.model = model.predictor.model

    def forward(self, input_tensor: torch.Tensor):
        x = self.model(input_tensor)
        x = non_max_suppression(x)
        return x


w_model_nms = WrapperModel_wNMS(model)
# test the model
with torch.inference_mode():
    w_result_nms = w_model_nms(input_tensor.to(model.device))
    # w_result_nms[0] is the result after NMS, with shape [N, 6]
scripted_wrapped_model_nms = torch.jit.script(w_model_nms)

Again, we get an error from torch.jit.script:

Intermission — A decision to make

At this point, there are two options:

  1. Keep pushing and move most of the non_max_suppression operations into the TorchScript model
  2. Just use the TorchScript model exported with the [1, 6, 8400] result tensor and process it on the Android side (for me, in Kotlin)

I went down the first path, but after the whole exercise, I think path 2 might have saved some trouble (on the Python side, at the cost of adding some downstream work in the Android development).

Attempt 4 — Let’s extract some operations from non_max_suppression() and include them in the TorchScript

I chose path 1 because I had seen people mention that torchvision.ops.nms would be available in PyTorch on Android (which turned out not to be the case).

Given that wrong assumption, I decided to extract the steps just before torchvision.ops.nms inside non_max_suppression(), so, referring to the implementation, I copied the code up to the torchvision.ops.nms call.

The code is as below, with the following change:

  1. In the forward function, instead of passing in the NMS parameters like the original non_max_suppression(), I just hard-code them in the code
import time

import numpy as np


# we need this function as it's being called (it's in the same file in ultralytics)
def xywh2xyxy(x):
    """
    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
    Returns:
        y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y


# this is the wrapper class
class WrapperModel(nn.Module):
    def __init__(self, model: YOLO):
        super().__init__()
        # model is the YOLO class object itself
        self.model = model.predictor.model

    def forward(self, input_tensor: torch.Tensor, conf_thres=0.25):
        multi_label = False
        max_time_img = 0.05
        labels = ()
        classes = None
        max_nms = 30000
        max_wh = 7680
        agnostic = False
        # conf_thres = 0

        prediction = self.model(input_tensor)[0]

        # Copied from /ultralytics/yolo/utils/ops.py non_max_suppression
        bs = prediction.shape[0]  # batch size
        nc = prediction.shape[1] - 4  # number of classes
        nm = prediction.shape[1] - nc - 4
        mi = 4 + nc  # mask start index
        xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

        # Settings
        # min_wh = 2  # (pixels) minimum box width and height
        time_limit = 0.5 + max_time_img * bs  # seconds to quit after
        redundant = True  # require redundant detections
        multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
        merge = False  # use merge-NMS

        t = time.time()
        output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
        for xi, x in enumerate(prediction):  # image index, image inference
            print(f"prediction: {xi}")
            # Apply constraints
            # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
            x = x.transpose(0, -1)[xc[xi]]  # confidence

            # Cat apriori labels if autolabelling
            if labels and len(labels[xi]):
                lb = labels[xi]
                v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
                v[:, :4] = lb[:, 1:5]  # box
                v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
                x = torch.cat((x, v), 0)

            # If none remain process next image
            print(x.shape)
            if not x.shape[0]:
                continue

            # Detections matrix nx6 (xyxy, conf, cls)
            box, cls, mask = x.split((4, nc, nm), 1)
            box = xywh2xyxy(box)  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
            if multi_label:
                i, j = (cls > conf_thres).nonzero(as_tuple=False).T
                x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
            else:  # best class only
                conf, j = cls.max(1, keepdim=True)
                x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

            # Filter by class
            if classes is not None:
                x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

            # Apply finite constraint
            # if not torch.isfinite(x).all():
            #     x = x[torch.isfinite(x).all(1)]

            # Check shape
            n = x.shape[0]  # number of boxes
            if not n:  # no boxes
                continue
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

            # Batched NMS
            c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
            boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores

        return x, boxes, scores

This approach works when I use it for inference, but torch.jit.script fails on it in the same way as above.

Attempt 5 — Replace the PyTorch model in the wrapper class with the previously exported YOLO TorchScript model

For this implementation, the changes are:

  1. Replaced the passed-in model in __init__()
  2. Accordingly, in forward, since the model returns only one result, we do prediction = self.model(input_tensor) [while previously we needed to take the index-0 result]
  3. Also removed quite a lot of conditional code that I am not using (TorchScript scripting does NOT like some of those conditions)
  4. Finally, since inference is applied to only one image at a time, the function returns after processing the first result (see the indentation of the return, compared to the previous attempt); this was a trial-and-error change to find what the JIT can handle
class WrapperModel2(nn.Module):
    def __init__(self, model: torch.jit._script.RecursiveScriptModule):
        super().__init__()
        # model is the YOLO exported torchscript model
        self.model = model

    def forward(self, input_tensor: torch.Tensor, conf_thres=0.25):
        multi_label = False
        max_time_img = 0.05
        labels = ()
        classes = None
        max_nms = 30000
        max_wh = 7680
        agnostic = False
        # conf_thres = 0

        prediction = self.model(input_tensor)

        # Copied from /ultralytics/yolo/utils/ops.py non_max_suppression
        bs = prediction.shape[0]  # batch size
        nc = prediction.shape[1] - 4  # number of classes
        nm = prediction.shape[1] - nc - 4
        mi = 4 + nc  # mask start index
        xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

        # Settings
        # min_wh = 2  # (pixels) minimum box width and height
        time_limit = 0.5 + max_time_img * bs  # seconds to quit after
        redundant = True  # require redundant detections
        multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
        merge = False  # use merge-NMS

        # t = time.time()
        output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
        for xi, x in enumerate(prediction):  # image index, image inference
            print(f"prediction: {xi}")
            # Apply constraints
            # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
            x = x.transpose(0, -1)[xc[xi]]  # confidence

            # Cat apriori labels if autolabelling
            # if labels and len(labels[xi]):
            #     lb = labels[xi]
            #     v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            #     v[:, :4] = lb[:, 1:5]  # box
            #     v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            #     x = torch.cat((x, v), 0)

            # If none remain process next image
            print(x.shape)
            if not x.shape[0]:
                continue

            # Detections matrix nx6 (xyxy, conf, cls)
            box, cls, mask = x.split((4, nc, nm), 1)
            box = xywh2xyxy(box)  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
            # if multi_label:
            #     i, j = (cls > conf_thres).nonzero(as_tuple=False).T
            #     x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
            # else:  # best class only
            #     conf, j = cls.max(1, keepdim=True)
            #     x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

            # * This is the "else" branch above
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

            # Filter by class
            if classes is not None:
                x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

            # Apply finite constraint
            # if not torch.isfinite(x).all():
            #     x = x[torch.isfinite(x).all(1)]

            # Check shape
            n = x.shape[0]  # number of boxes
            if not n:  # no boxes
                continue
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

            # Batched NMS
            c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
            boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores

            # forcefully return on the first prediction (which is OK for single-image inference)
            return x, boxes, scores  # actually only x is needed; the others could be derived from x on the Android side

In attempts 4 and 5, we can verify the result as follows:

import torchvision

w_model_2 = WrapperModel2(torch_script_model)  # or, for attempt 4, WrapperModel(model)

with torch.inference_mode():
    w_p2, w_boxes2, w_scores2 = w_model_2(input_tensor)

selected2 = torchvision.ops.nms(w_boxes2, w_scores2, 0.45)
w_p2[selected2]

# we are expecting to see a result like the following:
# tensor([[ 2.88526e+02, 1.18719e+02, 4.01752e+02, 2.11924e+02, 9.42204e-01, 0.00000e+00],
# [ 2.98364e+02, 1.71492e+02, 3.33414e+02, 2.08354e+02, 8.26868e-01, 1.00000e+00],
# [ 3.57114e+02, 1.59409e+02, 4.00666e+02, 2.02675e+02, 8.01369e-01, 1.00000e+00],
# [ 3.70031e+02, 2.37801e+02, 4.13906e+02, 2.77529e+02, 7.05622e-01, 1.00000e+00],
# [ 3.22101e+02, 1.62586e+02, 4.58572e+02, 2.97309e+02, 6.52932e-01, 0.00000e+00],
# [ 4.20349e+02, 1.17764e+01, 5.71449e+02, 5.01500e+01, 6.01493e-01, 0.00000e+00],
# [ 1.76297e+02, -3.57257e+00, 2.83910e+02, 3.98838e+01, 3.84342e-01, 0.00000e+00],
# [ 5.12204e+02, 1.41641e+01, 5.40370e+02, 4.03212e+01, 3.52462e-01, 1.00000e+00]])

Anyhow, this attempt 5 finally survives the JIT export:

from torch.utils.mobile_optimizer import optimize_for_mobile
scripted_wrapped_model2 = torch.jit.script(w_model_2)
optimized_torchscript_b4_nms_model = optimize_for_mobile(scripted_wrapped_model2)
optimized_torchscript_b4_nms_model.save(optimized_b4_nms_model_path)
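
As an aside (not what I do below): if one went with the _lite Android packages instead, my understanding is that the model would need to be saved for the lite interpreter rather than with the regular save, along these lines (the output path is illustrative):

optimized_torchscript_b4_nms_model._save_for_lite_interpreter("best_b4nms_optimized.ptl")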

Android Side

For a minimal test, here is the core code to load the model and run inference (the model and the jpg to test are put in the assets folder).

Note that I am using:

org.pytorch:pytorch_android:1.13.1 and org.pytorch:pytorch_android_torchvision:1.13.1

instead of the lite versions:

org.pytorch:pytorch_android_lite:1.13.1 and org.pytorch:pytorch_android_torchvision_lite:1.13.1

// some key imports related (not the only ones)
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.torchvision.TensorImageUtils

fun assetFilePath(context: Context, assetName: String): String? {
    val file = File(context.filesDir, assetName)

    try {
        context.assets.open(assetName).use { `is` ->
            FileOutputStream(file).use { os ->
                val buffer = ByteArray(4 * 1024)
                while (true) {
                    val length = `is`.read(buffer)
                    if (length <= 0)
                        break
                    os.write(buffer, 0, length)
                }
                os.flush()
            }
            return file.absolutePath
        }
    } catch (e: IOException) {
        Log.e("pytorchandroid", "Error processing asset $assetName to file path")
    }

    return null
}

// ... I put this code piece in the Activity onCreate(...)
val bitmap = BitmapFactory.decodeStream(assets.open("30.jpg"))
val moduleFilePath = assetFilePath(this, "best_b4nms_optimized.pth")
val module = Module.load(moduleFilePath)

val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
    bitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB
)
// the model has 3 outputs, so the forward result is a tuple
val (x, boxes, scores) = module.forward(IValue.from(inputTensor)).toTuple()
// This is a self-implemented NMS (if torchvision provided this on Android, we would not need to implement it ourselves)
var detResult = nms(x.toTensor(), 0.45f) // 0.45 is the IoU threshold

A note on the NMS implementation: an NMS implementation would normally first sort the results by score in descending order before removing boxes, but that sorting is already done inside the TorchScript wrapper model on the Python side, so I can safely assume the input here is already sorted.

import android.graphics.RectF
import org.pytorch.Tensor

data class DetectResult(
    val boundingBox: RectF,
    val classId: Int,
    val score: Float,
)

fun nms(x: Tensor, threshold: Float): List<DetectResult> {
    // x: [0:4] - box, [4] - score, [5] - class
    val data = x.dataAsFloatArray
    val numElem = x.shape()[0].toInt()
    val innerShape = x.shape()[1].toInt()
    val selected_indices = (0 until numElem).toMutableList()

    val scores = data.sliceArray((0 until numElem).flatMap { r -> (r * innerShape) + 4 until (r * innerShape) + 5 })
    val boxes = data.sliceArray((0 until numElem).flatMap { r -> (r * innerShape) until (r * innerShape) + 4 })
    val classes = data.sliceArray((0 until numElem).flatMap { r -> (r * innerShape) + 5 until (r * innerShape) + 6 })

    for (i in 0 until numElem) {
        val current_class = classes[i].toInt()
        for (j in i + 1 until numElem) {
            val box_i = boxes.sliceArray(i * 4 until (i * 4) + 4)
            val box_j = boxes.sliceArray(j * 4 until (j * 4) + 4)
            val iou = calculate_iou(box_i, box_j)
            if (iou > threshold && classes[j].toInt() == current_class) {
                if (scores[j] > scores[i]) {
                    selected_indices.remove(i)
                    break
                } else {
                    selected_indices.remove(j)
                }
            }
        }
    }

    val result = mutableListOf<DetectResult>()
    for (i in 0 until numElem) {
        if (selected_indices.contains(i)) {
            val box = boxes.slice((i * 4) until (i * 4) + 4)
            val detection = DetectResult(boundingBox = RectF(box[0], box[1], box[2], box[3]), score = scores[i], classId = classes[i].toInt())
            result.add(detection)
        }
    }

    return result
}

fun calculate_iou(box1: FloatArray, box2: FloatArray): Float {
    val x1 = maxOf(box1[0], box2[0])
    val y1 = maxOf(box1[1], box2[1])
    val x2 = minOf(box1[2], box2[2])
    val y2 = minOf(box1[3], box2[3])

    val intersection = maxOf(0f, x2 - x1) * maxOf(0f, y2 - y1)
    val area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    val area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    val union = area1 + area2 - intersection

    return intersection / union
}

Some learnings from implementing tensor slicing on the Android side

When we process multidimensional tensors in Python, whether with PyTorch or Numpy, there is a lot of slicing help from the library (i.e. the bracket syntax like [:, 4:5] to take the score “column”).

In Android, after converting the tensor to a float array (with the dataAsFloatArray member), we get a 1D array, and that’s where I like Kotlin (over Java): slice and sliceArray can take a collection of indices to pick elements, much like what we do in Numpy and PyTorch.

Another alternative might be converting them using other vector libraries (maybe kotlin-numpy or multik?).

Conclusion and what’s next

Above, we have explored how to run a YOLOv8 model with PyTorch Mobile on Android.

I am going to explore next:

  1. Using the raw YOLO-exported TorchScript model, and updating the Android implementation to take care of the whole post-processing
  2. Using the pytorch_android_lite package instead of pytorch_android and comparing the performance
  3. Comparing the performance of the same model between TFLite and PyTorch Mobile
  4. Exploring NNAPI usage for PyTorch Mobile
  5. Using ONNX Runtime on mobile and comparing the performance

