tensorflow中的条件语句和循环语句

论坛 期权论坛 脚本     
匿名技术用户   2021-1-6 02:14   503   0

tensorflow中,不可以直接拿tensor比较的结果作为 if 语句的条件,因此tensorflow中实现了自己的条件语句:

a = tf.get_variable("a",initializer=1)
b = tf.get_variable("b",initializer=2)

pred = tf.equal(a,b)

## 下面这种写法是正确的
def fun1():
 return a
def fun2():
 return b
c = tf.cond(pred, fun1, fun2)

## 下面这种写法是错误的
# if tf.equal(a,b):
#  c = a
# else:
#  c = b

with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 print(sess.run(c))

同理,tensorflow中,while函数也是需要有条件判断语句的,所以tensorflow实现了自己的while循环:

a = tf.get_variable("a",initializer=1)
b = tf.get_variable("b",initializer=5)

def cond(a, b):
 # 输入为loop_vars, 输出为布尔值
 return tf.less(a, b)

def body(a, b):
 # 输入为loop_vars, 输出为lop_vars
 a = a + 2
 b = b + 1
 return a, b
a, b = tf.while_loop(cond, body, loop_vars=[a, b])
with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 print(sess.run([a,b])) # 9,9

下面我们借助条件语句和循环语句实现一个在numpy中非常容易实现的功能:

给定两个placeholder,input_a和input_b,两者的维度都是[None,5],而实际输入是维度分别为[seq_len,5]和[seq_len+1,5],得到一个output_c,形状为[seq_len*2+1, 5],其奇数位来自于input_b,偶数位来自于input_a。即,对于output_c而言,[2*i,5]的位置来自于input_a[i,5],[2*i+1,5]的位置来自于input_b[i,5]。

题目如下:

input_a = tf.placeholder(dtype=tf.int64, shape=[None, 5]) #实际输入时为[seq_len, 5]
input_b = tf.placeholder(dtype=tf.int64, shape=[None, 5]) #实际输入时为[seq_len+1, 5]

"""
请在此处键入你的代码
"""

with tf.Session() as sess:
 
    feed_a = [[1,2,3,4,5],[2,3,4,5,6]] #shape=[3,5]
    feed_b = [[0,0,0,0,0],[0,0,0,0,0],[0,0,0,0,0]] #shape=[4,5]

    sess.run(tf.global_variables_initializer())
 output_c_value = sess.run(output_c, feed_dict={input_a:feed_a, input_b:feed_b})

 print(output_c_value)
 #期望得到的值是[[0,0,0,0,0],[1,2,3,4,5],[0,0,0,0,0],[2,3,4,5,6],[0,0,0,0,0]]

在实现的过程中,有几个关键点:

1. 因为placeholder形状的第一维是None,所以取出来的seq_len是一个Tensor

2. 因为seq_len是一个Tensor,所以无法用seq_len来实现python原生的条件和循环语句

3. Tensor不支持对特定位进行赋值,所以必须建一个新的Tensor,然后把原有的Tensor中的特定位赋给新的Tensor

最终实现的代码如下:

import tensorflow as tf

input_a = tf.placeholder(dtype=tf.int64, shape=[None, 5]) #实际输入时为[seq_len, 5]
input_b = tf.placeholder(dtype=tf.int64, shape=[None, 5]) #实际输入时为[seq_len+1, 5]

seq_len = tf.shape(input_a)[0]

i = tf.get_variable("i", initializer=0)
output_c = tf.expand_dims(input_b[0,:], axis=0)

def loop_cond(input_a, input_b, output_c, i, seq_len):
 return tf.less(i, 2*seq_len)

def loop_body(input_a, input_b, output_c, i, seq_len):

 def concat_a():
  return tf.concat([output_c, tf.expand_dims(input_a[i//2,:], axis=0)], axis=0)
 def concat_b():
  return tf.concat([output_c, tf.expand_dims(input_b[i//2,:], axis=0)], axis=0)
 pred = tf.equal(i%2, 0)
 output_c = tf.cond(pred, concat_a, concat_b)
 i += 1
 return input_a, input_b, output_c, i, seq_len

_, _, output_c, _, _ = tf.while_loop(cond=loop_cond, body=loop_body, 
                                     loop_vars=[input_a, input_b, output_c, i, seq_len],                                                                          
                                     shape_invariants=[input_a.get_shape(), 
                                                       input_b.get_shape(),      
                                                       tf.TensorShape([None,5]), 
                                                       i.get_shape(), 
                                                    seq_len.get_shape()])

with tf.Session() as sess:
 
 feed_a = np.array([[1,2,3,4,5],[2,3,4,5,6]]) #shape=[2,5]
 feed_b = np.array([[0,0,0,0,0],[0,0,0,0,0],[0,0,0,0,0]]) #shape=[3,5]

 sess.run(tf.global_variables_initializer())
 output_c_value = sess.run(output_c, feed_dict={input_a:feed_a, input_b:feed_b})

 print(output_c_value)
 #得到的值是[[0,0,0,0,0],[1,2,3,4,5],[0,0,0,0,0],[2,3,4,5,6],[0,0,0,0,0]]

有趣之处在于,tensorflow的while_loop是先循环再判断,也就是说,循环终止条件需要提前一步。

分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:7942463
帖子:1588486
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP