使用TensorFlow Python API时,我创建了一个变量(没有name
在构造函数中指定它),并且它的name
属性具有值"Variable_23:0"
.当我尝试使用这个变量时tf.get_variable("Variable23")
,"Variable_23_1:0"
会创建一个名为的新变量.如何正确选择"Variable_23"
而不是创建新的?
我想要做的是按名称选择变量,并重新初始化它,以便我可以微调权重.
该get_variable()
函数创建一个新变量或返回之前创建的变量get_variable()
.它不会返回使用创建的变量tf.Variable()
.这是一个简单的例子:
>>> with tf.variable_scope("foo"): ... bar1 = tf.get_variable("bar", (2,3)) # create ... >>> with tf.variable_scope("foo", reuse=True): ... bar2 = tf.get_variable("bar") # reuse ... >>> with tf.variable_scope("", reuse=True): # root variable scope ... bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above) ... >>> (bar1 is bar2) and (bar2 is bar3) True
如果您没有使用创建变量tf.get_variable()
,则有几个选项.首先,您可以使用tf.global_variables()
(如@mrry建议的那样):
>>> bar1 = tf.Variable(0.0, name="bar") >>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0] >>> bar1 is bar2 True
或者您可以这样使用tf.get_collection()
:
>>> bar1 = tf.Variable(0.0, name="bar") >>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0] >>> bar1 is bar2 True
编辑
您还可以使用get_tensor_by_name()
:
>>> bar1 = tf.Variable(0.0, name="bar") >>> graph = tf.get_default_graph() >>> bar2 = graph.get_tensor_by_name("bar:0") >>> bar1 is bar2 False, bar2 is a Tensor througn convert_to_tensor on bar1. but bar1 equal bar2 in value.
回想一下张量是一个操作的输出.它与操作同名,另外:0
.如果操作有多个输出,它们具有相同的名称作为操作加:0
,:1
,:2
,等等.
通过名称获取变量的最简单方法是在tf.global_variables()
集合中搜索它:
var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]
这适用于现有变量的临时重用.当您想要在模型的多个部分之间共享变量时,更加结构化的方法将在" 共享变量"教程中介绍.