您好,欢迎来到飒榕旅游知识分享网。
搜索
您的当前位置:首页pytorch框架--网络方面--pytorch自带模型(增、改)

pytorch框架--网络方面--pytorch自带模型(增、改)

来源:飒榕旅游知识分享网

自带模型的增、改

import torchvision
from torch import nn

# 加载vgg16网络模型,pretrained 是否使用优质网络的参数,并不是权重参数
vgg16_f = torchvision.models.vgg16(pretrained=False)
# 加载vgg16网络模型,pretrained 是否使用优质参数
vgg16_t = torchvision.models.vgg16(pretrained=True)

print(vgg16_t)

train_data = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor)

# 增加一层 网络 当前为全连接层
vgg16_f.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_f)

# 改层
vgg16_t.classifier[6] = nn.Linear(4096, 10)
print(vgg16_t)

因篇幅问题不能全部显示,请点此查看更多更全内容

Copyright © 2019- sarr.cn 版权所有 赣ICP备2024042794号-1

违法及侵权请联系:TEL:199 1889 7713 E-MAIL:2724546146@qq.com

本站由北京市万商天勤律师事务所王兴未律师提供法律服务