如何解决tf.train.batch只生成一个batch的数据的问题

CZlin. 2019-12-08 08:16:14
使用tensorflow训练网上下载好的flowers数据集的时候出现了问题,按照下面代码操作后用sess.run()获取batch数据并打印,只显示一个batch(10个)的数据,求问大神们如何解决
keys_to_features = {
'image/encoded': tf.FixedLenFeature([], default_value='', dtype=tf.string, ),
'image/format': tf.FixedLenFeature([], default_value='jpeg', dtype=tf.string),
'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=0),
'image/height': tf.FixedLenFeature([], tf.int64, default_value=0),
'image/width': tf.FixedLenFeature([], tf.int64, default_value=0)
}

items_to_handlers = {
'image': slim.tfexample_decoder.Image(image_key='image/encoded', format_key='image/format', channels=3),
'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),
'height': slim.tfexample_decoder.Tensor('image/height', shape=[]),
'width': slim.tfexample_decoder.Tensor('image/width', shape=[])
}
decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

labels_to_names = None
items_to_descriptions = {
'image': 'An image with shape image_shape.',
'label': 'A single integer between 0 and 9.',
'height': 'float number',
'width': 'float number'}

dataset = slim.dataset.Dataset(
data_sources=tfrecord_path,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=1000,
items_to_descriptions=None,
num_classes=num_classes,
)

provider = slim.dataset_data_provider.DatasetDataProvider(dataset=dataset,
num_readers=4,
shuffle=False, # 这个改成False以后每次生成的batch都一样了
common_queue_capacity=256,
common_queue_min=128,
seed=None)

[image, label, height, width] = provider.get(['image', 'label', 'height', 'width'])

resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[resize_height, resize_width]))

images, labels = tf.train.batch([resized_image, label], batch_size=bsize, allow_smaller_final_batch=True, num_threads=1, capacity=5*bsize)
...全文
114 回复 打赏 收藏 转发到动态 举报
AI 作业
写回复
用AI写文章
回复
切换为时间正序
请发表友善的回复…
发表回复

4,499

社区成员

发帖
与我相关
我的任务
社区描述
图形图像/机器视觉
社区管理员
  • 机器视觉
  • 迪菲赫尔曼
加入社区
  • 近7日
  • 近30日
  • 至今
社区公告
暂无公告

试试用AI创作助手写篇文章吧