Skip to main content
Version: 最新版本(unreleased)

Implement inference logic

After the project is created, you only need to complete the following tasks to adapt the inference service and deploy it on athenaServing.

  • Fill in the input and output of the model service
  • Write the inference process of the service model (currently only non-streaming is supported)

Note: The keys used in wrapper.py must be identical to the keys defined when the ability was created.

Example: MNIST Handwritten Digit Recognition Via Pytorch

1. Clone the example model service repository

git clone https://github.com/sea-wyq/pytorch-MNIST.git

root@34c3e20d9572:/home# tree pytorch-MNIST/
pytorch-MNIST/
|-- 0.png # Handwritten-digit image used to test inference
|-- README.md
|-- Model.py # Service model structure
|-- inference.py # Model inference logic
|-- requirements.txt # Service dependencies
`-- train.py # model training logic

Running python3 pytorch-MNIST/train.py produces the model file mnist.pkl and downloads the MNIST dataset into the data directory.

root@34c3e20d9572:/home# tree pytorch-MNIST/
pytorch-MNIST/
|-- 0.png
|-- Model.py
|-- README.md
|-- data
| `-- MNIST
| `-- raw
| |-- t10k-images-idx3-ubyte
| |-- t10k-images-idx3-ubyte.gz
| |-- t10k-labels-idx1-ubyte
| |-- t10k-labels-idx1-ubyte.gz
| |-- train-images-idx3-ubyte
| |-- train-images-idx3-ubyte.gz
| |-- train-labels-idx1-ubyte
| `-- train-labels-idx1-ubyte.gz
|-- inference.py
|-- mnist.pkl
|-- requirements.txt
`-- train.py

2. View the inference logic of the service (inference.py)

import cv2
import torch
import torchvision.transforms as transforms

from Model import MNIST


def images2tensor(image):
    """Load an image file and convert it to a grayscale batch tensor.

    Args:
        image: Path of the image file to read.

    Returns:
        The tensor produced by ``transforms.ToTensor()`` from the grayscale
        image, with one extra leading dimension inserted at index 0.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image``.
    """
    img = cv2.imread(image)
    # cv2.imread signals failure by returning None rather than raising, so
    # check here instead of letting cvtColor fail with an opaque error.
    if img is None:
        raise FileNotFoundError("cannot read image file: %s" % image)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    transf = transforms.ToTensor()
    img_tensor = torch.unsqueeze(transf(img), dim=0)
    return img_tensor


if __name__ == "__main__":
    device = torch.device('cpu')
    model = MNIST().to(device)
    # map_location remaps the stored tensors onto the CPU device, so loading
    # does not depend on the device the weights were saved from.
    model.load_state_dict(torch.load('mnist.pkl', map_location=device))  # Load the service model
    model.eval()  # switch the module to evaluation mode before inference
    input_data = images2tensor("0.png")  # load test data
    with torch.no_grad():  # gradients are not needed for inference
        res = model(input_data)
    print("手写数字图片检测的结果为:", res.argmax())  # print inference results

3. Install service dependencies (requirements.txt)

cd pytorch-MNIST
pip3 install -r requirements.txt

4. Adapt the example to the mnist project structure generated by aiges

  • Move 0.png to mnist/wrapper/test_data directory
  • Put the mnist.pkl and Model.py files in the mnist/wrapper directory
  • Add the dependencies listed in pytorch-MNIST/requirements.txt to mnist/requirements.txt
  • Write Dockerfile (introduced in the next chapter)
  • Write the inference logic of wrapper.py in the wrapper directory

The final mnist directory structure is:

root@34c3e20d9572:/home# tree mnist/
mnist/
|-- Dockerfile
|-- README.md
|-- requirements.txt
`-- wrapper
|-- mnist.pkl
|-- Model.py
|-- test_data
| `-- 0.png
`-- wrapper.py

5. Write wrapper.py

#!/usr/bin/env python
# coding:utf-8
"""
@license: Apache License2
@file: wrapper.py
@time: 2022-08-19 02:05:07.467170
@project: mnist
@project: ./
"""


import sys
import hashlib
try:
from aiges_embed import ResponseData, Response, DataListNode, DataListCls # c++
except:
from aiges.dto import Response, ResponseData, DataListNode, DataListCls,Once
from aiges.sdk import WrapperBase, \
StringParamField, \
ImageBodyField, \
StringBodyField, \
AudioBodyField
from aiges.utils.log import log
from aiges.types import *


# Import dependencies in inference.py
import cv2
import torch
import torchvision.transforms as transforms
from Model import MNIST
import numpy as np


# Define the hyperparameters and input parameters of the model
class UserRequest:
    """Request schema of the service.

    Declares a single image body field with key "img" whose path is
    test_data/0.png.
    """

    input1 = ImageBodyField(key="img", path="test_data/0.png")

# Define the output parameters of the model
class UserResponse:
    """Response schema of the service: one string body field with key "number"."""

    accept1 = StringBodyField(key="number")

# Define service inference logic
class Wrapper(WrapperBase):
    """Service wrapper that runs the MNIST model on one request image."""

    serviceId = "mnist"
    version = "v1"
    requestCls = UserRequest()
    responseCls = UserResponse()
    model = None  # assigned by wrapperInit

    # load model file
    def wrapperInit(cls, config: {}) -> int:
        """Create the MNIST model on the CPU and load its weights from mnist.pkl.

        Args:
            config: Not used by this implementation.

        Returns:
            0 after the weights have been loaded.
        """
        log.info("Initializing ...")
        device = torch.device('cpu')
        cls.model = MNIST().to(device)
        # map_location remaps the stored tensors onto the CPU device, so
        # loading does not depend on the device the weights were saved from.
        cls.model.load_state_dict(torch.load('mnist.pkl', map_location=device))
        cls.model.eval()  # switch the module to evaluation mode for serving
        return 0

    # write inference logic
    def wrapperOnceExec(cls, params: {}, reqData: DataListCls) -> Response:
        """Classify the image sent under key "img" and return the digit as text.

        Args:
            params: Not used by this implementation.
            reqData: Request data; the entry with key "img" must hold the
                encoded image bytes.

        Returns:
            A Response whose list holds one ResponseData with key "number"
            (the key declared in UserResponse) carrying the predicted class
            index as UTF-8 decimal text.

        Raises:
            ValueError: If the "img" bytes cannot be decoded as an image.
        """
        # Read the test image and perform model inference
        log.info("got reqdata , %s" % reqData.list)
        imagebytes = reqData.get("img").data  # read binary data
        # cv2.imdecode expects an IMREAD_* flag; COLOR_BGR2GRAY is a cvtColor
        # conversion code and does not request a grayscale decode.
        image = cv2.imdecode(np.frombuffer(imagebytes, np.uint8), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError("request field 'img' is not a decodable image")
        # Apply the same preprocessing as images2tensor in inference.py
        # (ToTensor, then a leading batch dimension) so that both code paths
        # feed the model identically prepared input.
        image_tensor = torch.unsqueeze(transforms.ToTensor()(image), dim=0)
        with torch.no_grad():  # gradients are not needed for inference
            result = cls.model(image_tensor).argmax()
        print(result)

        # Use Response to encapsulate inference results
        res = Response()
        resd = ResponseData()
        resd.key = "number"  # must match the key declared in UserResponse
        resd.type = DataText
        resd.status = Once
        # Send the prediction as decimal text (the field type is DataText)
        # and report the real payload length.
        payload = str(result.item()).encode("utf-8")
        resd.data = memoryview(payload)
        resd.len = len(payload)
        res.list = [resd]
        return res

    def wrapperFini(cls) -> int:
        """Finalization hook; does nothing and returns 0."""
        return 0

    def wrapperError(cls, ret: int) -> str:
        """Map an error code to a message.

        Args:
            ret: Error code to describe.

        Returns:
            "user error defined here" when ``ret`` is 100, otherwise "".
        """
        if ret == 100:
            return "user error defined here"
        return ""

    # This function is reserved for testing and cannot be deleted
    def wrapperTestFunc(cls, data: [], respData: []):
        """Reserved test hook; intentionally left empty."""
        pass

if __name__ == '__main__':
    # Build the wrapper and invoke its run() entry point directly.
    Wrapper().run()

6. Test wrapper.py

root@34c3e20d9572:/home/mnist/wrapper# python3 wrapper.py 
Found proto: WrapperBase
Found proto: Wrapper
2022-08-19 03:32:06,707 - root - INFO - Initializing ...
2022-08-19 03:32:06,885 - root - INFO - got reqdata , [<aiges.dto.DataListNode object at 0x7f6bfc7beda0>]
/home/mnist/wrapper/wrapper.py:64: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:204.)
image_tensor = torch.unsqueeze(torch.Tensor(image), dim=0)
[W NNPACK.cpp:51] Could not initialize NNPACK! Reason: Unsupported hardware.
tensor(0)
['img']

full code for mnist