未加星标

TensorFlow模型保存和提取的方法

字体大小 | |
[开发(python) 所属分类 开发(python) | 发布者 店小二04 | 时间 | 作者 红领巾 ] 0人收藏点击收藏

一、TensorFlow模型保存和提取方法

1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt") ,实际在这个文件目录下会生成4个人文件:

TensorFlow模型保存和提取的方法

checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。

2. 加载这个已保存的TensorFlow模型的方法是saver.restore(sess,"./Model/model.ckpt") ,加载模型的代码中也要定义TensorFlow计算图上的所有运算并声明一个tf.train.Saver类,不同的是加载模型时不需要进行变量的初始化,而是将变量的取值通过保存的模型加载进来,注意加载路径的写法。若不希望重复定义计算图上的运算,可直接加载已经持久化的图,saver =tf.train.import_meta_graph("Model/model.ckpt.meta") 。

3.tf.train.Saver类也支持在保存和加载时给变量重命名,声明Saver类对象的时候使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名},saver = tf.train.Saver({"v1":u1, "v2": u2})即原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中。

4. 上一条做的目的之一就是方便使用变量的滑动平均值。如果在加载模型时直接将影子变量映射到变量自身,则在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。载入时,声明Saver类对象时通过一个字典将滑动平均值直接加载到新的变量中,saver = tf.train.Saver({"v/ExponentialMovingAverage": v}),另通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典。

此外,通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中。

二、TensorFlow程序实现

# 本文件程序为配合教材及学习进度渐进进行,请按照注释分段执行
# 执行时要注意IDE的当前工作过路径,最好每段重启控制器一次,输出结果更准确


# Part1: 通过tf.train.Saver类实现保存和载入神经网络模型

# 执行本段程序时注意当前的工作路径
import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, "Model/model.ckpt")


# Part2: 加载TensorFlow模型的方法

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./"
print(sess.run(result)) # [ 3.]


# Part3: 若不希望重复定义计算图上的运算,可直接加载已经持久化的图

import tensorflow as tf

saver = tf.train.import_meta_graph("Model/model.ckpt.meta")

with tf.Session() as sess:
saver.restore(sess, "./Model/model.ckpt") # 注意路径写法
print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.]


# Part4: tf.train.Saver类也支持在保存和加载时给变量重命名

import tensorflow as tf

# 声明的变量名称name与已保存的模型中的变量名称name不一致
u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = u1 + u2

# 若直接生命Saver类对象,会报错变量找不到
# 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名}
# 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中
saver = tf.train.Saver({"v1": u1, "v2": u2})

with tf.Session() as sess:
saver.restore(sess, "./Model/model.ckpt")
print(sess.run(result)) # [ 3.]


# Part5: 保存滑动平均模型

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables():
print(variables.name) # v:0

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
print(variables.name) # v:0
# v/ExponentialMovingAverage:0

saver = tf.train.Saver()

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.assign(v, 10))
sess.run(maintain_averages_op)
saver.save(sess, "Model/model_ema.ckpt")
print(sess.run([v, ema.average(v)])) # [10.0, 0.099999905]


# Part6: 通过变量重命名直接读取变量的滑动平均值

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})

with tf.Session() as sess:
saver.restore(sess, "./Model/model_ema.ckpt")
print(sess.run(v)) # 0.0999999


# Part7: 通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
# 注意此处的变量名称name一定要与已保存的变量名称一致
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
# {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
# 此处的v取自上面变量v的名称name="v"

saver = tf.train.Saver(ema.variables_to_restore())

with tf.Session() as sess:
saver.restore(sess, "./Model/model_ema.ckpt")
print(sess.run(v)) # 0.0999999


# Part8: 通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中

import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])

with tf.gfile.GFile("Model/combined_model.pb", 'wb') as f:
f.write(output_graph_def.SerializeToString())


# Part9: 载入包含变量及其取值的模型

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
model_filename = "Model/combined_model.pb"
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

result = tf.import_graph_def(graph_def, return_elements=["add:0"])
print(sess.run(result)) # [array([ 3.], dtype=float32)]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。


您可能感兴趣的文章:TensorFlow模型保存/载入的两种方法利用TensorFlow训练简单的二分类神经网络模型的方法用tensorflow构建线性回归模型的示例代码python使用tensorflow保存、加载和使用模型的方法

本文开发(python)相关术语:python基础教程 python多线程 web开发工程师 软件开发工程师 软件开发流程

主题: 神经网络变量
tags: tf,sess,name,saver,train,model,变量,variables,Saver,v1,graph,import,Model,tensorflow
分页:12
转载请注明
本文标题:TensorFlow模型保存和提取的方法
本站链接:https://www.codesec.net/view/573802.html


1.凡CodeSecTeam转载的文章,均出自其它媒体或其他官网介绍,目的在于传递更多的信息,并不代表本站赞同其观点和其真实性负责;
2.转载的文章仅代表原创作者观点,与本站无关。其原创性以及文中陈述文字和内容未经本站证实,本站对该文以及其中全部或者部分内容、文字的真实性、完整性、及时性,不作出任何保证或承若;
3.如本站转载稿涉及版权等问题,请作者及时联系本站,我们会及时处理。
登录后可拥有收藏文章、关注作者等权限...
技术大类 技术大类 | 开发(python) | 评论(0) | 阅读(186)