您好,欢迎来到飒榕旅游知识分享网。
搜索
您的当前位置:首页pytorch框架--简单模型预测

pytorch框架--简单模型预测

来源:飒榕旅游知识分享网

模型预测示例

使用训练好的模型进行预测

import torchvision

from model import Tudui
import torch
from PIL import Image

# 读取图像
img = Image.open("./data/train/Dog/9.jpg")
# 数据预处理

# 缩放
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
image = transform(img)
print(image.shape)

# 根据保存方式加载
model = torch.load("tudui_99.pth", map_location=torch.device('cpu'))

# 注意维度转换,单张图片
image1 = torch.reshape(image, (1, 3, 32, 32))

# 测试开关
model.eval()
# 节约性能
with torch.no_grad():
    output = model(image1)
print(output)
# print(output.argmax(1))
# 定义类别对应字典
dist = {0: "飞机", 1: "汽车", 2: "鸟", 3: "猫", 4: "鹿", 5: "狗", 6: "青蛙", 7: "马", 8: "船", 9: "卡车"}
# 转numpy格式,列表内取第一个
a = dist[output.argmax(1).numpy()[0]]
img.show()
print(a)

因篇幅问题不能全部显示,请点此查看更多更全内容

Copyright © 2019- sarr.cn 版权所有 赣ICP备2024042794号-1

违法及侵权请联系:TEL:199 1889 7713 E-MAIL:2724546146@qq.com

本站由北京市万商天勤律师事务所王兴未律师提供法律服务