跳到主要内容
版本:release-2.2.0

实现推理逻辑

在项目创建后,用户只需要完成下面的任务即可完成推理服务的在athenaServing上适配和部署。

  • 填写模型服务的输入输出
  • 编写模型的推理服务流程(目前仅支持非流式)

注: wrapper.py中需要修改的key需要与创建能力的key值相同

示例:通过Pytorch实现MNIST手写数字识别

1. 克隆示例模型服务仓库

git clone https://github.com/sea-wyq/pytorch-MNIST.git

root@34c3e20d9572:/home# tree pytorch-MNIST/
pytorch-MNIST/
|-- 0.png # 测试推理使用的手写数字图片
|-- README.md
|-- Model.py # 服务模型结构
|-- inference.py # 模型推理逻辑
|-- requirements.txt # 服务依赖的包
`-- train.py # 模型训练逻辑

执行 python3 pytorch-MNIST/train.py会获得模型文件mnist.pkl和mnist数据集data

root@34c3e20d9572:/home# tree pytorch-MNIST/
pytorch-MNIST/
|-- 0.png
|-- Model.py
|-- README.md
|-- data
| `-- MNIST
| `-- raw
| |-- t10k-images-idx3-ubyte
| |-- t10k-images-idx3-ubyte.gz
| |-- t10k-labels-idx1-ubyte
| |-- t10k-labels-idx1-ubyte.gz
| |-- train-images-idx3-ubyte
| |-- train-images-idx3-ubyte.gz
| |-- train-labels-idx1-ubyte
| `-- train-labels-idx1-ubyte.gz
|-- inference.py
|-- mnist.pkl
|-- requirements.txt
`-- train.py

2. 查看服务的推理逻辑(inference.py)

import cv2
import torch
import torchvision.transforms as transforms

from Model import MNIST


def images2tensor(image):
img = cv2.imread(image)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
transf = transforms.ToTensor()
img_tensor = torch.unsqueeze(transf(img), dim=0)
return img_tensor


if __name__ == "__main__":
device = torch.device('cpu')
model = MNIST().to(device)
model.load_state_dict(torch.load('mnist.pkl')) # 加载服务模型
input_data = images2tensor("0.png") # 加载测试数据
res = model(input_data)
print("手写数字图片检测的结果为:", res.argmax()) # 打印推理结果

3. 安装服务依赖(requirements.txt)

cd pytorch-MNIST
pip3 install -r requirements.txt

4. 适配通过aiges构建的mnist项目结构

  • 将0.png移动到mnist/wrapper/test_data目录下
  • 将mnist.pkl和Model.py文件放到mnist/wrapper目录下
  • 将pytorch-MNIST目录下的requirements.txt添加到mnist/requirements.txt中
  • 编写Dockerfile(下章节介绍)
  • 编写wrapper目录下的wrapper.py推理逻辑

最终的mnist目录结构为:

root@34c3e20d9572:/home# tree mnist/
mnist/
|-- Dockerfile
|-- README.md
|-- requirements.txt
`-- wrapper
|-- mnist.pkl
|-- Model.py
|-- test_data
| `-- 0.png
`-- wrapper.py

5. 编写wrapper.py

#!/usr/bin/env python
# coding:utf-8
"""
@license: Apache License2
@file: wrapper.py
@time: 2022-08-19 02:05:07.467170
@project: mnist
@project: ./
"""

import sys
import hashlib
try:
from aiges_embed import ResponseData, Response, DataListNode, DataListCls # c++
except:
from aiges.dto import Response, ResponseData, DataListNode, DataListCls,Once, DataText

from aiges.sdk import WrapperBase, \
StringParamField, \
ImageBodyField, \
StringBodyField
from aiges.utils.log import log


# 导入inference.py中的依赖包
import cv2
import torch
import torchvision.transforms as transforms
from Model import MNIST
import numpy as np


# 定义模型的超参数和输入参数
class UserRequest(object):

input1 = ImageBodyField(key="img", path="test_data/0.png")

# 定义模型的输出参数
class UserResponse(object):
accept1 = StringBodyField(key="number")

# 定义服务推理逻辑
class Wrapper(WrapperBase):
serviceId = "mnist"
version = "v1"
requestCls = UserRequest()
responseCls = UserResponse()
model = None

# 加载模型文件
def wrapperInit(cls, config: {}) -> int:
log.info("Initializing ...")
device = torch.device('cpu')
cls.model = MNIST().to(device)
cls.model.load_state_dict(torch.load('mnist.pkl'))
return 0

# 编写推理逻辑
def wrapperOnceExec(cls, params: {}, reqData: DataListCls) -> Response:

# 读取测试图片并进行模型推理
log.info("got reqdata , %s" % reqData.list)
imagebytes = reqData.get("img").data # 读取到的是二进制数据
image = [cv2.imdecode(np.frombuffer(imagebytes, np.uint8), cv2.COLOR_BGR2GRAY)]
image_tensor = torch.unsqueeze(torch.Tensor(image), dim=0)
result = cls.model(image_tensor).argmax()
print(result)

# 使用Response封装result
res = Response()
resd = ResponseData()
resd.key = "img"
resd.type = DataText
resd.status = Once
resd.data = result
resd.len = 1
res.list = [resd]

def wrapperFini(cls) -> int:
return 0


def wrapperError(cls, ret: int) -> str:
if ret == 100:
return "user error defined here"
return ""

'''
此函数保留测试用,不可删除
'''

def wrapperTestFunc(cls, data: [], respData: []):
pass

if __name__ == '__main__':
m = Wrapper()
m.run()

6. 测试wrapper.py

root@34c3e20d9572:/home/mnist/wrapper# python3 wrapper.py 
Found proto: WrapperBase
Found proto: Wrapper
2022-08-19 03:32:06,707 - root - INFO - Initializing ...
2022-08-19 03:32:06,885 - root - INFO - got reqdata , [<aiges.dto.DataListNode object at 0x7f6bfc7beda0>]
/home/mnist/wrapper/wrapper.py:64: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:204.)
image_tensor = torch.unsqueeze(torch.Tensor(image), dim=0)
[W NNPACK.cpp:51] Could not initialize NNPACK! Reason: Unsupported hardware.
tensor(0)
['img']

mnist完整代码