tf工具相关

论坛 期权论坛 脚本     
匿名技术用户   2020-12-23 17:50   11   0

1. 从event file解析loss:

import tensorflow as tf
from tensorflow.python.summary import summary_iterator
# Path of the TensorBoard event file to scan (placeholder; replace it).
event_file = 'events.filename'

# Print every scalar value logged under the tag 'loss', in file order.
for event in summary_iterator.summary_iterator(event_file):
    # Skip events that carry no summary payload.
    if not event.HasField('summary'):
        continue
    for value in event.summary.value:
        # Only scalars stored in the simple_value field are reported.
        if value.tag == 'loss' and value.HasField('simple_value'):
            print(value.simple_value)

2. 多进程找出有问题的tfrecords:

import tensorflow as tf
import glob
import multiprocessing



def task2(id, q):
    """Drain file paths from ``q`` and print the ones that fail to read as TFRecords.

    Each path is iterated end to end with ``tf.python_io.tf_record_iterator``.
    If iteration raises, the path is printed with a ``=====`` prefix. The
    worker returns once the queue is empty.

    Args:
        id: Worker index. It is not used, but kept for callers passing ``(i, q)``.
        q: Queue of file paths shared by all workers.

    Returns:
        None.
    """
    import queue

    while True:
        # Use get_nowait() rather than `while not q.empty(): q.get()`.
        # With several workers sharing q, another worker can take the last
        # item between the empty() check and get(). A blocking get() would
        # then wait forever and hang the pool.
        try:
            path = q.get_nowait()
        except queue.Empty:
            return None
        try:
            # Read every record only to surface read errors; records are discarded.
            for _ in tf.python_io.tf_record_iterator(path):
                pass
        except Exception:
            # Broad on purpose: any failure while reading flags the file.
            print("=====", path)

if __name__ == "__main__":
    # Guard: start methods that re-import the main module in child processes
    # (e.g. "spawn") would otherwise re-run this pool setup in every worker.
    train_files = sorted(glob.glob('/path/to/tfrecords_dir/*'))
    cpus = multiprocessing.cpu_count()
    # Both resources are shut down on exit. The pool is listed last, so it is
    # torn down before the Manager whose queue its workers use.
    with multiprocessing.Manager() as m, multiprocessing.Pool() as pool:
        q = m.Queue()
        for each in train_files:
            q.put(each)
        # One draining worker per CPU; each task2 pulls paths until q is empty.
        results = [pool.apply_async(task2, args=(i, q)) for i in range(cpus)]
        pool.close()
        pool.join()
        # get() re-raises any exception that escaped a worker.
        for result in results:
            result.get()

3. 往ckpt里面添加变量:

import os,sys
import tensorflow as tf
import horovod.tensorflow as hvd
# NOTE(review): horovod is presumably imported so that any Horovod ops in the
# .meta graph are registered before import_meta_graph; confirm whether
# hvd.init() itself is required.
hvd.init()

# Usage: <script> <ckpt_prefix> <output_dir>
ckpt_path = sys.argv[1]
path = sys.argv[2]

# Create the new int32 scalar variable 'var_name' (initialized to 0) before
# importing the existing graph.
tf.get_variable('var_name', dtype=tf.int32, shape=[],
                initializer=tf.constant_initializer(0))

with tf.Session() as sess:
    # Built before import_meta_graph, so this initializer covers only the new
    # variable. The imported variables get their values from restore() below.
    sess.run(tf.global_variables_initializer())
    saver = tf.train.import_meta_graph(ckpt_path + '.meta')
    saver.restore(sess, ckpt_path)

    # A fresh Saver covers all global variables, including the new one.
    saver2 = tf.train.Saver()
    # basename() also works when ckpt_path has no directory part, where
    # ckpt_path.rindex('/') raised ValueError.
    ckpt_name = os.path.basename(ckpt_path)
    tf.gfile.MakeDirs(path)
    # Save under a relative name from inside the output directory.
    # NOTE(review): presumably this keeps the recorded checkpoint path relative.
    os.chdir(path)
    print(saver2.save(sess, ckpt_name))

分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:7942463
帖子:1588486
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP