[python人工智能] 十一.tensorflow如何保存神经网络参数 丨【百变ai秀】-4008云顶国际网站

eastmount 发表于 2021/09/10 10:41:19 2021/09/10
【摘要】 该系列主要研究python深度学习、神经网络及人工智能相关知识。这篇文章将讲解tensorflow如何保存变量和神经网络参数,通过saver保存神经网络,再通过restore调用训练好的神经网络。基础性文章,希望对您有所帮助。

从本专栏开始,作者正式开始研究python深度学习、神经网络及人工智能相关知识。前一篇详细讲解了tensorflow opencv实现cnn自定义图像分类案例,它能解决我们现实论文或实践中的图像分类问题,并与机器学习的图像分类算法进行对比实验。这篇文章将讲解tensorflow如何保存变量和神经网络参数,通过saver保存神经网络,再通过restore调用训练好的神经网络。

本专栏主要结合作者之前的博客、ai经验和相关文章及论文介绍,后面随着深入会讲解更多的python人工智能案例及应用。基础性文章,希望对您有所帮助,如果文章中存在错误或不足之处,还请海涵~作者作为人工智能的菜鸟,希望大家能与我在这一笔一划的博客中成长起来。

代码下载地址(欢迎大家关注点赞):



通过tf.variable()定义权重和偏置变量,然后调用tf.train.saver()存储变量,将数据保存至本地“my_net/save_net.ckpt”文件中。

# -*- coding: utf-8 -*-
"""
created on thu jan  2 20:04:57 2020
@author: xiuzhang eastmount csdn
"""
import tensorflow as tf
import numpy as np
#---------------------------------------保存文件---------------------------------------
w = tf.variable([[1,2,3], [3,4,5]], dtype=tf.float32, name='weights') #2行3列的数据
b = tf.variable([[1,2,3]], dtype=tf.float32, name='biases')
# 初始化
init = tf.initialize_all_variables()
# 定义saver 存储各种变量
saver = tf.train.saver()
# 使用session运行初始化
with tf.session() as sess:
    sess.run(init)
    # 保存 官方保存格式为ckpt
    save_path = saver.save(sess, "my_net/save_net.ckpt")
    print("save to path:", save_path)

“save to path: my_net/save_net.ckpt”保存成功如下图所示:

打开内容如下图所示:

接着定义标记变量train,通过restore操作使用我们保存好的变量。注意,在restore时需要定义相同的dtype和shape,不需要再定义init。最后直接通过 saver.restore(sess, “my_net/save_net.ckpt”) 提取保存的变量并输出即可。

# -*- coding: utf-8 -*-
"""
created on thu jan  2 20:04:57 2020
@author: xiuzhang eastmount csdn
"""
import tensorflow as tf
import numpy as np
# 标记变量
train = false
#---------------------------------------保存文件---------------------------------------
# save
if train==true:
    # 定义变量
    w = tf.variable([[1,2,3], [3,4,5]], dtype=tf.float32, name='weights') #2行3列的数据
    b = tf.variable([[1,2,3]], dtype=tf.float32, name='biases')
    # 初始化
    init = tf.global_variables_initializer()
    
    # 定义saver 存储各种变量
    saver = tf.train.saver()
    
    # 使用session运行初始化
    with tf.session() as sess:
        sess.run(init)
        # 保存 官方保存格式为ckpt
        save_path = saver.save(sess, "my_net/save_net.ckpt")
        print("save to path:", save_path)
#---------------------------------------restore变量-------------------------------------
# restore
if train==false:
    # 记住在restore时定义相同的dtype和shape
    # redefine the same shape and same type for your variables
    w = tf.variable(np.arange(6).reshape((2,3)), dtype=tf.float32, name='weights') #空变量
    b = tf.variable(np.arange(3).reshape((1,3)), dtype=tf.float32, name='biases') #空变量
    
    # restore不需要定义init
    saver = tf.train.saver()
    with tf.session() as sess:
        # 提取保存的变量
        saver.restore(sess, "my_net/save_net.ckpt")
        # 寻找相同名字和标识的变量并存储在w和b中
        print("weights", sess.run(w))
        print("biases", sess.run(b))

运行代码,如果报错“notfounderror: restoring from checkpoint failed. this is most likely due to a variable name or other graph key that is missing from the checkpoint. please ensure that you have not altered the graph expected based on the checkpoint. ”,则需要重置spyder即可。

最后输出之前所保存的变量,weights为 [[1,2,3], [3,4,5]],偏置为 [[1,2,3]]。


那么,tensorflow如何保存我们的神经网络框架呢?我们需要把整个网络训练好再进行保存,其方法和上面类似,完整代码如下:

"""
created on sun dec 29 19:21:08 2019
@author: xiuzhang eastmount csdn
"""
import os
import glob
import cv2
import numpy as np
import tensorflow as tf
# 定义图片路径
path = 'photo/'
#---------------------------------第一步 读取图像-----------------------------------
def read_img(path):
    cate = [path   x for x in os.listdir(path) if os.path.isdir(path   x)]
    imgs = []
    labels = []
    fpath = []
    for idx, folder in enumerate(cate):
        # 遍历整个目录判断每个文件是不是符合
        for im in glob.glob(folder   '/*.jpg'):
            #print('reading the images:%s' % (im))
            img = cv2.imread(im)             #调用opencv库读取像素点
            img = cv2.resize(img, (32, 32))  #图像像素大小一致
            imgs.append(img)                 #图像数据
            labels.append(idx)               #图像类标
            fpath.append(path im)            #图像路径名
            #print(path im, idx)
            
    return np.asarray(fpath, np.string_), np.asarray(imgs, np.float32), np.asarray(labels, np.int32)
# 读取图像
fpaths, data, label = read_img(path)
print(data.shape)  # (1000, 256, 256, 3)
# 计算有多少类图片
num_classes = len(set(label))
print(num_classes)
# 生成等差数列随机调整图像顺序
num_example = data.shape[0]
arr = np.arange(num_example)
np.random.shuffle(arr)
data = data[arr]
label = label[arr]
fpaths = fpaths[arr]
# 拆分训练集和测试集 80%训练集 20%测试集
ratio = 0.8
s = np.int(num_example * ratio)
x_train = data[:s]
y_train = label[:s]
fpaths_train = fpaths[:s] 
x_val = data[s:]
y_val = label[s:]
fpaths_test = fpaths[s:] 
print(len(x_train),len(y_train),len(x_val),len(y_val)) #800 800 200 200
print(y_val)
#---------------------------------第二步 建立神经网络-----------------------------------
# 定义placeholder
xs = tf.placeholder(tf.float32, [none, 32, 32, 3])  #每张图片32*32*3个点
ys = tf.placeholder(tf.int32, [none])               #每个样本有1个输出
# 存放dropout参数的容器 
drop = tf.placeholder(tf.float32)                   #训练时为0.25 测试时为0
# 定义卷积层 conv0
conv0 = tf.layers.conv2d(xs, 20, 5, activation=tf.nn.relu)    #20个卷积核 卷积核大小为5 relu激活
# 定义max-pooling层 pool0
pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])        #pooling窗口为2x2 步长为2x2
print("layer0:\n", conv0, pool0)
 
# 定义卷积层 conv1
conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu) #40个卷积核 卷积核大小为4 relu激活
# 定义max-pooling层 pool1
pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])        #pooling窗口为2x2 步长为2x2
print("layer1:\n", conv1, pool1)
# 将3维特征转换为1维向量
flatten = tf.layers.flatten(pool1)
# 全连接层 转换为长度为400的特征向量
fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
print("layer2:\n", fc)
# 加上dropout防止过拟合
dropout_fc = tf.layers.dropout(fc, drop)
# 未激活的输出层
logits = tf.layers.dense(dropout_fc, num_classes)
print("output:\n", logits)
# 定义输出结果
predicted_labels = tf.arg_max(logits, 1)
#---------------------------------第三步 定义损失函数和优化器---------------------------------
# 利用交叉熵定义损失
losses = tf.nn.softmax_cross_entropy_with_logits(
        labels = tf.one_hot(ys, num_classes),       #将input转化为one-hot类型数据输出
        logits = logits)
# 平均损失
mean_loss = tf.reduce_mean(losses)
# 定义优化器 学习效率设置为0.0001
optimizer = tf.train.adamoptimizer(learning_rate=1e-4).minimize(losses)
#------------------------------------第四步 模型训练和预测-----------------------------------
# 用于保存和载入模型
saver = tf.train.saver()
# 训练或预测
train = false
# 模型文件路径
model_path = "model/image_model"
with tf.session() as sess:
    if train:
        print("训练模式")
        # 训练初始化参数
        sess.run(tf.global_variables_initializer())
        # 定义输入和label以填充容器 训练时dropout为0.25
        train_feed_dict = {
                xs: x_train,
                ys: y_train,
                drop: 0.25
        }
        # 训练学习1000次
        for step in range(1000):
            _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
            if step % 50 == 0:  #每隔50次输出一次结果
                print("step = {}\t mean loss = {}".format(step, mean_loss_val))
        # 保存模型
        saver.save(sess, model_path)
        print("训练结束,保存模型到{}".format(model_path))
    else:
        print("测试模式")
        # 测试载入参数
        saver.restore(sess, model_path)
        print("从{}载入模型".format(model_path))
        # label和名称的对照关系
        label_name_dict = {
            0: "人类",
            1: "沙滩",
            2: "建筑",
            3: "公交",
            4: "恐龙",
            5: "大象",
            6: "花朵",
            7: "野马",
            8: "雪山",
            9: "美食"
        }
        # 定义输入和label以填充容器 测试时dropout为0
        test_feed_dict = {
            xs: x_val,
            ys: y_val,
            drop: 0
        }
        
        # 真实label与模型预测label
        predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
        for fpath, real_label, predicted_label in zip(fpaths_test, y_val, predicted_labels_val):
            # 将label id转换为label名
            real_label_name = label_name_dict[real_label]
            predicted_label_name = label_name_dict[predicted_label]
            print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))
        # 评价结果
        print("正确预测个数:", sum(y_val==predicted_labels_val))
        print("准确度为:", 1.0*sum(y_val==predicted_labels_val) / len(y_val))

核心步骤为:

saver = tf.train.saver()
model_path = "model/image_model"
with tf.session() as sess:
    if train:
    	#保存神经网络
    	sess.run(tf.global_variables_initializer())
    	for step in range(1000):
            _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
            if step % 50 == 0:
                print("step = {}\t mean loss = {}".format(step, mean_loss_val))
        saver.save(sess, model_path)
    else:
    	#载入神经网络
    	saver.restore(sess, model_path)
        predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
        for fpath, real_label, predicted_label in zip(fpaths_test, y_val, predicted_labels_val):
            real_label_name = label_name_dict[real_label]
            predicted_label_name = label_name_dict[predicted_label]
            print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))	

预测输出结果如下图所示,最终预测正确181张图片,准确度为0.905。相比之前机器学习knn的0.500有非常高的提升。

测试模式
info:tensorflow:restoring parameters from model/image_model
从model/image_model载入模型
b'photo/photo/3\\335.jpg'       公交 => 公交
b'photo/photo/1\\129.jpg'       沙滩 => 沙滩
b'photo/photo/7\\740.jpg'       野马 => 野马
b'photo/photo/5\\564.jpg'       大象 => 大象
...
b'photo/photo/9\\974.jpg'       美食 => 美食
b'photo/photo/2\\220.jpg'       建筑 => 公交
b'photo/photo/9\\912.jpg'       美食 => 美食
b'photo/photo/4\\459.jpg'       恐龙 => 恐龙
b'photo/photo/5\\525.jpg'       大象 => 大象
b'photo/photo/0\\44.jpg'        人类 => 人类
正确预测个数: 181
准确度为: 0.905

写到这里,这篇文章就讲解完毕,更多tensorflow深度学习文章会继续分享,接下来我们会分享rnn、lstm、文本识别等内容。如果读者有什么想学习的,也可以私聊我,我去学习并应用到你的领域。

最后,希望这篇基础性文章对您有所帮助,如果文章中存在错误或不足之处,还请海涵~作为人工智能的菜鸟,我希望自己能不断进步并深入,后续将它应用于图像识别、网络安全、对抗样本等领域,指导大家撰写简单的学术论文,一起加油!

感恩能与大家在华为云遇见!
希望能与大家一起在华为云社区共同成长。原文地址:https://blog.csdn.net/eastmount/article/details/103757386
【百变ai秀】有奖征文火热进行中:

(by:娜璋之家 eastmount 2021-09-10 夜于武汉)


参考文献:

[1] 冈萨雷斯著. 数字图像处理(第3版)[m]. 北京:电子工业出版社,2013.
[2] 杨秀璋, 颜娜. python网络数据爬取及分析从入门到精通(分析篇)[m]. 北京:北京航天航空大学出版社, 2018.
[3] 罗子江等. python中的图像处理[m]. 科学出版社, 2020.
[4] 
[5] 
[6] 
[7] 
[8] 

【4008云顶国际集团的版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请发送邮件至:;如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容。
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。