博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Pytorch保存和加载模型
阅读量:4136 次
发布时间:2019-05-25

本文共 3874 字,大约阅读时间需要 12 分钟。

文章目录

一、保存加载模型基本用法

1、保存加载整个模型(不推荐)

保存整个网络模型(网络结构+权重参数)。

torch.save(model, 'net.pkl')

直接加载整个网络模型(可能比较耗时)。

model = torch.load('net.pkl')

2、只保存加载模型参数(推荐)

只保存模型的权重参数(速度快,占内存少)。

torch.save(model.state_dict(), 'net_params.pkl')

因为我们只保存了模型的参数,所以需要先定义一个网络对象,然后再加载模型参数。

# 构建一个网络结构model = ClassNet()# 将模型参数加载到新模型中state_dict = torch.load('net_params.pkl')model.load_state_dict(state_dict)

二、保存加载自定义模型

上面保存加载的 net.pkl 其实一个字典,通常包含如下内容:

  1. 网络结构:输入尺寸、输出尺寸以及隐藏层信息,以便能够在加载时重建模型。
  2. 模型的权重参数:包含各网络层训练后的可学习参数,可以在模型实例上调用 state_dict()方法来获取,比如前面介绍只保存模型权重参数时用到的 model.state_dict()
  3. 优化器参数:有时保存模型的参数需要稍后接着训练,那么就必须保存优化器的状态和所其使用的超参数,也是在优化器实例上调用 state_dict() 方法来获取这些参数。
  4. 其他信息:有时我们需要保存一些其他的信息,比如 epochbatch_size 等超参数。

知道了这些,那么我们就可以自定义需要保存的内容,比如:

# saving a checkpoint assuming the network class named ClassNetcheckpoint = {
'model': ClassNet(), 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch}torch.save(checkpoint, 'checkpoint.pkl')

上面的 checkpoint 是个字典,里面有4个键值对,分别表示网络模型的不同信息。

然后我们要加载上面保存的自定义的模型:

def load_checkpoint(filepath):    checkpoint = torch.load(filepath)    model = checkpoint['model']  # 提取网络结构    model.load_state_dict(checkpoint['model_state_dict'])  # 加载网络权重参数    optimizer = TheOptimizerClass()    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # 加载优化器参数        for parameter in model.parameters():        parameter.requires_grad = False    model.eval()        return model    model = load_checkpoint('checkpoint.pkl')

三、跨设备保存加载模型

1、在 CPU 上加载在 GPU 上训练并保存的模型(Save on GPU, Load on CPU):

device = torch.device('cpu')model = TheModelClass()# Load all tensors onto the CPU devicemodel.load_state_dict(torch.load('net_params.pkl', map_location=device))

map_location:a function, torch.device, string or a dict specifying how to remap storage locations

torch.load() 函数的 map_location 参数等于 torch.device('cpu') 即可。 这里令 map_location 参数等于 'cpu' 也同样可以。

2、在 GPU 上加载在 GPU 上训练并保存的模型(Save on GPU, Load on GPU):

device = torch.device("cuda")model = TheModelClass()model.load_state_dict(torch.load('net_params.pkl'))model.to(device)

在这里使用 map_location 参数不起作用,要使用 model.to(torch.device("cuda")) 将模型转换为CUDA优化的模型。

还需要对将要输入模型的数据调用 data = data.to(device),即将数据从CPU转移到GPU。请注意,调用 my_tensor.to(device) 会返回一个 my_tensor 在 GPU 上的副本,它不会覆盖 my_tensor。因此需要手动覆盖张量:my_tensor = my_tensor.to(device)

3、在 GPU 上加载在 GPU 上训练并保存的模型(Save on CPU, Load on GPU)

device = torch.device("cuda")model = TheModelClass()model.load_state_dict(torch.load('net_params.pkl', map_location="cuda:0"))model.to(device)

当加载包含GPU tensors的模型时,这些tensors 会被默认加载到GPU上,不过是同一个GPU设备。

当有多个GPU设备时,可以通过将 map_location 设定为 *cuda:device_id* 来指定使用哪一个GPU设备,上面例子是指定编号为0的GPU设备。

其实也可以将 torch.device("cuda") 改为 torch.device("cuda:0") 来指定编号为0的GPU设备。

最后调用 model.to(torch.device('cuda')) 来将模型的tensors转换为 CUDA tensors。

下面是PyTorch官方文档上的用法,可以进行参考:

>>> torch.load('tensors.pt')# Load all tensors onto the CPU>>> torch.load('tensors.pt', map_location=torch.device('cpu'))# Load all tensors onto the CPU, using a function>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)# Load all tensors onto GPU 1>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))# Map tensors from GPU 1 to GPU 0>>> torch.load('tensors.pt', map_location={
'cuda:1':'cuda:0'})

四、CUDA 的用法

在PyTorch中和GPU相关的几个函数:

import torch# 判断cuda是否可用;print(torch.cuda.is_available())# 获取gpu数量;print(torch.cuda.device_count())# 获取gpu名字;print(torch.cuda.get_device_name(0))# 返回当前gpu设备索引,默认从0开始;print(torch.cuda.current_device())

有时我们需要把数据和模型从cpu移到gpu中,有以下两种方法:

use_cuda = torch.cuda.is_available()# 方法一:if use_cuda:    data = data.cuda()    model.cuda()# 方法二:device = torch.device("cuda" if use_cuda else "cpu")data = data.to(device)model.to(device)

个人比较习惯第二种方法,可以少一个 if 语句。而且该方法还可以通过设备号指定使用哪个GPU设备,比如使用0号设备:

device = torch.device("cuda:0" if use_cuda else "cpu")

参考

转载地址:http://csvvi.baihongyu.com/

你可能感兴趣的文章
将有序数组转换为平衡二叉搜索树
查看>>
最长递增子序列
查看>>
从一列数中筛除尽可能少的数,使得从左往右看这些数是从小到大再从大到小...
查看>>
判断一个整数是否是回文数
查看>>
经典shell面试题整理
查看>>
腾讯的一道面试题—不用除法求数字乘积
查看>>
素数算法
查看>>
java多线程环境单例模式实现详解
查看>>
将一个数插入到有序的数列中,插入后的数列仍然有序
查看>>
在有序的数列中查找某数,若该数在此数列中,则输出它所在的位置,否则输出no found
查看>>
万年历
查看>>
作为码农你希望面试官当场指出你错误么?有面试官这样遭到投诉!
查看>>
好多程序员都认为写ppt是很虚的技能,可事实真的是这样么?
查看>>
如果按照代码行数发薪水会怎样?码农:我能刷到公司破产!
查看>>
程序员失误造成服务停用3小时,只得到半月辞退补偿,发帖喊冤
查看>>
码农:很多人称我“技术”,感觉这是不尊重!纠正无果后果断辞职
查看>>
php程序员看过来,这老外是在吐糟你吗?看看你中了几点!
查看>>
为什么说程序员是“培训班出来的”就是鄙视呢?
查看>>
码农吐糟同事:写代码低调点不行么?空格回车键与你有仇吗?
查看>>
阿里p8程序员四年提交6000次代码的确有功,但一次错误让人唏嘘!
查看>>