未加星标

tensorflow创建变量以及根据名称查找变量

字体大小 | |
[开发(python) 所属分类 开发(python) | 发布者 店小二04 | 时间 | 作者 红领巾 ] 0人收藏点击收藏

环境:Ubuntu14.04,tensorflow=1.4(bazel源码安装),Anaconda python=3.6

声明变量主要有两种方法:tf.Variable和 tf.get_variable,二者的最大区别是:

(1) tf.Variable是一个类,自带很多属性函数;而 tf.get_variable是一个函数;
(2) tf.Variable只能生成独一无二的变量,即如果给出的name已经存在,则会自动修改生成新的变量name;
(3) tf.get_variable可以用于生成共享变量。默认情况下,该函数会进行变量名检查,如果有重复则会报错。当在指定变量域中声明可
以变量共享时,可以重复使用该变量(例如RNN中的参数共享)。
下面给出简单的的示例程序:
import tensorflow as tf
with tf.variable_scope('scope1',reuse=tf.AUTO_REUSE) as scope1:
x1 = tf.Variable(tf.ones([1]),name='x1')
x2 = tf.Variable(tf.zeros([1]),name='x1')
y1 = tf.get_variable('y1',initializer=1.0)
y2 = tf.get_variable('y1',initializer=0.0)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(x1.name,x1.eval())
print(x2.name,x2.eval())
print(y1.name,y1.eval())
print(y2.name,y2.eval())

输出结果为:

scope1/x1:0 [ 1.]
scope1/x1_1:0 [ 0.]
scope1/y1:0 1.0
scope1/y1:0 1.0

1. tf.Variable(…)

tf.Variable(…)使用给定初始值来创建一个新变量,该变量会默认添加到 graph collections listed in collections, which defaults to [GraphKeys.GLOBAL_VARIABLES]。

如果trainable属性被设置为True,该变量同时也会被添加到graph collection GraphKeys.TRAINABLE_VARIABLES.

# tf.Variable
__init__(
initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None,
constraint=None
)

2. tf.get_variable(…)

tf.get_variable(…)的返回值有两种情形:

使用指定的initializer来创建一个新变量;
当变量重用时,根据变量名搜索返回一个由tf.get_variable创建的已经存在的变量;

get_variable(
name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
constraint=None
)

3. 根据名称查找变量

在创建变量时,即使我们不指定变量名称,程序也会自动进行命名。于是,我们可以很方便的根据名称来查找变量,这在抓取参数、finetune模型等很多时候都很有用。

示例1:

通过在tf.global_variables()变量列表中,根据变量名进行匹配搜索查找。 该种搜索方式,可以同时找到由tf.Variable或者tf.get_variable创建的变量。

import tensorflow as tf
x = tf.Variable(1,name='x')
y = tf.get_variable(name='y',shape=[1,2])
for var in tf.global_variables():
if var.name == 'x:0':
print(var)

示例2:

利用get_tensor_by_name()同样可以获得由tf.Variable或者tf.get_variable创建的变量。
需要注意的是,此时获得的是Tensor, 而不是Variable,因此 x不等于x1.
import tensorflow as tf
x = tf.Variable(1,name='x')
y = tf.get_variable(name='y',shape=[1,2])
graph = tf.get_default_graph()
x1 = graph.get_tensor_by_name("x:0")
y1 = graph.get_tensor_by_name("y:0")

示例3:

针对tf.get_variable创建的变量,可以利用变量重用来直接获取已经存在的变量。

with tf.variable_scope("foo"):
bar1 = tf.get_variable("bar", (2,3)) # create
with tf.variable_scope("foo", reuse=True):
bar2 = tf.get_variable("bar") # reuse
with tf.variable_scope("", reuse=True): # root variable scope
bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
print((bar1 is bar2) and (bar2 is bar3))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。


您可能感兴趣的文章:TensorFLow用Saver保存和恢复变量tensorflow获取变量维度信息TensorFlow变量管理详解TensorFlow saver指定变量的存取

本文开发(python)相关术语:python基础教程 python多线程 web开发工程师 软件开发工程师 软件开发流程

主题: RIA变量UbuntuAUTUTAU
tags: tf,variable,get,None,name,变量,Variable,x1,y1,True,graph,shape,tensorflow,scope,print
分页:12
转载请注明
本文标题:tensorflow创建变量以及根据名称查找变量
本站链接:http://www.codesec.net/view/573782.html
分享请点击:


1.凡CodeSecTeam转载的文章,均出自其它媒体或其他官网介绍,目的在于传递更多的信息,并不代表本站赞同其观点和其真实性负责;
2.转载的文章仅代表原创作者观点,与本站无关。其原创性以及文中陈述文字和内容未经本站证实,本站对该文以及其中全部或者部分内容、文字的真实性、完整性、及时性,不作出任何保证或承若;
3.如本站转载稿涉及版权等问题,请作者及时联系本站,我们会及时处理。
登录后可拥有收藏文章、关注作者等权限...
技术大类 技术大类 | 开发(python) | 评论(0) | 阅读(40)