Differences between TensorFlow 1.x and TensorFlow 2.0

期权论坛 (Option Forum) · Anonymous user · 2021-5-29 23:40

import numpy as np
import tensorflow as tf  # TensorFlow 1.x style code

def flatten(x):
    """
    Input:
    - TensorFlow Tensor of shape (N, D1, ..., DM)

    Output:
    - TensorFlow Tensor of shape (N, D1 * ... * DM)
    """
    N = tf.shape(x)[0]
    return tf.reshape(x, (N, -1))
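The `-1` in `tf.reshape(x, (N, -1))` asks TensorFlow to infer the remaining dimension from the total element count. NumPy's `reshape` uses the same convention, so the behavior can be checked without building a graph; a quick sketch:

```python
import numpy as np

# A (2, 3, 4) array, mirroring the shapes used in test_flatten below.
x = np.arange(24).reshape((2, 3, 4))

# Keep the first (batch) dimension; -1 tells reshape to infer 3 * 4 = 12.
flat = x.reshape(x.shape[0], -1)
print(flat.shape)  # (2, 12)
```

The same inference rule is what lets the flatten function above work for any number of trailing dimensions.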

def test_flatten():
    # Clear the current TensorFlow graph.
    tf.reset_default_graph()

    # Stage I: Define the TensorFlow graph describing our computation.
    # In this case the computation is trivial: we just want to flatten
    # a Tensor using the flatten function defined above.

    # Our computation will have a single input, x. We don't know its
    # value yet, so we define a placeholder which will hold the value
    # when the graph is run. We then pass this placeholder Tensor to
    # the flatten function; this gives us a new Tensor which will hold
    # a flattened view of x when the graph is run. The tf.device
    # context manager tells TensorFlow whether to place these Tensors
    # on CPU or GPU.
    device = '/cpu:0'  # use '/gpu:0' if a GPU is available
    with tf.device(device):
        x = tf.placeholder(tf.float32)
        x_flat = flatten(x)

    # At this point we have just built the graph describing our computation,
    # but we haven't actually computed anything yet. If we print x and x_flat
    # we see that they don't hold any data; they are just TensorFlow Tensors
    # representing values that will be computed when the graph is run.
    print('x: ', type(x), x)
    print('x_flat: ', type(x_flat), x_flat)
    print()

    # Stage II: We need a TensorFlow Session object to actually run the graph.
    with tf.Session() as sess:
        # Construct concrete values of the input data x using numpy.
        x_np = np.arange(24).reshape((2, 3, 4))
        print('x_np:\n', x_np, '\n')

        # Run our computational graph to compute a concrete output value.
        # The first argument to sess.run tells TensorFlow which Tensor
        # we want it to compute the value of; the feed_dict specifies
        # values to plug into all placeholder nodes in the graph. The
        # resulting value of x_flat is returned from sess.run as a
        # numpy array.
        x_flat_np = sess.run(x_flat, feed_dict={x: x_np})
        print('x_flat_np:\n', x_flat_np, '\n')

        # We can reuse the same graph to perform the same computation
        # with different input data.
        x_np = np.arange(12).reshape((2, 3, 2))
        print('x_np:\n', x_np, '\n')
        x_flat_np = sess.run(x_flat, feed_dict={x: x_np})
        print('x_flat_np:\n', x_flat_np)

test_flatten()
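Under TensorFlow 2.0 the two-stage pattern above disappears: eager execution runs each op immediately, so there is no graph to reset, no placeholder to feed, and no Session to open. A minimal sketch of the same flatten computation in 2.x style (assuming TensorFlow 2 is installed as `tf`):

```python
import numpy as np
import tensorflow as tf  # TensorFlow 2.x, eager execution enabled by default

def flatten(x):
    # Flatten all dimensions after the first (batch) dimension.
    N = tf.shape(x)[0]
    return tf.reshape(x, (N, -1))

# No placeholder, no feed_dict, no Session: just call the function on data.
x_np = np.arange(24).reshape((2, 3, 4))
x_flat = flatten(tf.constant(x_np))
print(x_flat.numpy())  # concrete (2, 12) numpy array, available immediately

# Reuse works the same way: call again with different input.
print(flatten(tf.constant(np.arange(12).reshape((2, 3, 2)))).numpy())
```

Note that `flatten` itself is unchanged between the two versions; only the staging around it differs. In 2.x, wrapping such a function in `tf.function` recovers graph-style tracing when performance matters.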
