python – 使用tf.Dataset训练的模型进行推理

weixin_38081402 2019-09-12 01:01:36
我使用tf.data.Dataset API训练了一个模型,所以我的训练代码看起来像这样 with graph.as_default(): dataset = tf.data.TFRecordDataset(tfrecord_path) dataset = dataset.map(scale_features, num_parallel_calls=n_workers) dataset = dataset.shuffle(10000) dataset = dataset.padded_batch(batch_size, padded_shapes={...}) handle = tf.placeholder(tf.string, shape=[]) iterator = tf.data.Iterator.from_string_handle(handle, train_dataset.output_types, train_dataset.output_shapes) batch = iterator.get_next() ... # Model code ... iterator = dataset.make_initializable_iterator() with tf.Session(graph=graph) as sess: train_handle = sess.run(iterator.string_handle()) sess.run(tf.global_variables_initializer()) for epoch in range(n_epochs): sess.run(train_iterator.initializer) while True: try: sess.run(optimizer, feed_dict={handle: train_handle}) except tf.errors.OutOfRangeError: break 现在,在训练模型之后,我想推断出不在数据集中的示例,我不确定如何去做. 为了清楚起见,我知道如何使用另一个数据集,例如我只是在测试时将句柄传递给我的测试集. 问题是关于给定扩展方案和网络期望句柄的事实,如果我想对未写入TFRecord的新示例进行预测,我将如何去做? 如果我修改批处理,我会事先负责缩放,如果可能的话我想避免这种情况. 那么我应该如何推断模型traiend tf.data.Dataset方式的单个例子呢?(这不是出于生产目的,它用于评估如果我更改特定功能会发生什么)
...全文
154 1 打赏 收藏 转发到动态 举报
写回复
用AI写文章
1 条回复
切换为时间正序
请发表友善的回复…
发表回复
weixin_38108918 2019-09-12
  • 打赏
  • 举报
回复
实际上图中有一个名为“IteratorGetNext:0”的张量名称 当您使用数据集api时,您可以使用以下方式直接设置 输入: #get a tensor from a graph input tensor : input = graph.get_tensor_by_name("IteratorGetNext:0") # difine the target tensor you want evaluate for your prediction prediction tensor: predictions=... # finally call session to run then sess.run(predictions, feed_dict={input: np.asanyarray(images), ...})

433

社区成员

发帖
与我相关
我的任务
社区描述
其他技术讨论专区
其他 技术论坛(原bbs)
社区管理员
  • 其他技术讨论专区社区
加入社区
  • 近7日
  • 近30日
  • 至今
社区公告
暂无公告

试试用AI创作助手写篇文章吧