python – 使用tf.Dataset训练的模型进行推理
我使用tf.data.Dataset API训练了一个模型,所以我的训练代码看起来像这样
with graph.as_default():
dataset = tf.data.TFRecordDataset(tfrecord_path)
dataset = dataset.map(scale_features, num_parallel_calls=n_workers)
dataset = dataset.shuffle(10000)
dataset = dataset.padded_batch(batch_size, padded_shapes={...})
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
train_dataset.output_types,
train_dataset.output_shapes)
batch = iterator.get_next()
...
# Model code
...
iterator = dataset.make_initializable_iterator()
with tf.Session(graph=graph) as sess:
train_handle = sess.run(iterator.string_handle())
sess.run(tf.global_variables_initializer())
for epoch in range(n_epochs):
sess.run(train_iterator.initializer)
while True:
try:
sess.run(optimizer, feed_dict={handle: train_handle})
except tf.errors.OutOfRangeError:
break
现在,在训练模型之后,我想推断出不在数据集中的示例,我不确定如何去做.
为了清楚起见,我知道如何使用另一个数据集,例如我只是在测试时将句柄传递给我的测试集.
问题是关于给定扩展方案和网络期望句柄的事实,如果我想对未写入TFRecord的新示例进行预测,我将如何去做?
如果我修改批处理,我会事先负责缩放,如果可能的话我想避免这种情况.
那么我应该如何推断模型traiend tf.data.Dataset方式的单个例子呢?(这不是出于生产目的,它用于评估如果我更改特定功能会发生什么)