Tensorflow1.0 实现Sptial Transformer Networks遇到的问题

大米饭盖不住四喜丸子 2017-02-28 11:01:06
最近找到githurb上stn的项目,是基于Tensorflow0.7的
本人使用环境 tensorflow1.0 python3.5
以下为报错
File "C:\Users\hasee\STN_tf_test01.py", line 38, in <module>
h_trans=transformer(x,h_fc1,out_size)
File "C:\Users\hasee\STN_tf_01.py", line 145, in transformer
output = _transform(theta, U, out_size)
File "C:\Users\hasee\STN_tf_01.py", line 122, in _transform
grid = _meshgrid(out_height, out_width)
File "C:\Users\hasee\STN_tf_01.py", line 102, in _meshgrid
grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
File "C:\Users\hasee\AppData\Local\Programs\Python\Python35\Lib\site-packages\tensorflow\python\ops\array_ops.py", line 1047, in concat
dtype=dtypes.int32).get_shape(
File "C:\Users\hasee\AppData\Local\Programs\Python\Python35\Lib\site-packages\tensorflow\python\framework\ops.py", line 651, in convert_to_tensor
as_ref=False)
File "C:\Users\hasee\AppData\Local\Programs\Python\Python35\Lib\site-packages\tensorflow\python\framework\ops.py", line 716, in internal_convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
File "C:\Users\hasee\AppData\Local\Programs\Python\Python35\Lib\site-packages\tensorflow\python\framework\constant_op.py", line 176, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)
File "C:\Users\hasee\AppData\Local\Programs\Python\Python35\Lib\site-packages\tensorflow\python\framework\constant_op.py", line 165, in constant
tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))
File "C:\Users\hasee\AppData\Local\Programs\Python\Python35\Lib\site-packages\tensorflow\python\framework\tensor_util.py", line 367, in make_tensor_proto
_AssertCompatible(values, dtype)
File "C:\Users\hasee\AppData\Local\Programs\Python\Python35\Lib\site-packages\tensorflow\python\framework\tensor_util.py", line 302, in _AssertCompatible
(dtype.name, repr(mismatch), type(mismatch).__name__))

builtins.TypeError: Expected int32, got list containing Tensors of type '_Message' instead.



报错部分代码如下:

def _meshgrid(height, width):
print('begin--meshgrid')
with tf.variable_scope('_meshgrid'):
# This should be equivalent to:
# x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
# np.linspace(-1, 1, height))
# ones = np.ones(np.prod(x_t.shape))
# grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])

x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
print('meshgrid_x_t_ok')
y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
tf.ones(shape=tf.stack([1, width])))
print('meshgrid_y_t_ok')
x_t_flat = tf.reshape(x_t, (1, -1))
y_t_flat = tf.reshape(y_t, (1, -1))
print('meshgrid_flat_t_ok')
ones = tf.ones_like(x_t_flat)
print('meshgrid_ones_ok')
print(x_t_flat)
print(y_t_flat)
print(ones)
grid = tf.concat(0, [x_t_flat, y_t_flat, ones])#报错语句
print ('over_meshgrid')
return grid


以及上级代码如下:
im=ndimage.imread('C:\\Users\\hasee\\Desktop\\cat.jpg')
im=im/255.
#im=tf.reshape(im, [1,1200,1600,3])

im=im.reshape(1,1200,1600,3)

im=im.astype('float32')
print('img-over')
out_size=(600,800)
batch=np.append(im,im,axis=0)
batch=np.append(batch,im,axis=0)
num_batch=3

x=tf.placeholder(tf.float32,[None,1200,1600,3])
x=tf.cast(batch,'float32')
print('begin---')
with tf.variable_scope('spatial_transformer_0'):
n_fc=6
w_fc1=tf.Variable(tf.Variable(tf.zeros([1200*1600*3,n_fc]),name='W_fc1'))
initial=np.array([[0.5,0,0],[0,0.5,0]])
initial=initial.astype('float32')
initial=initial.flatten()


b_fc1=tf.Variable(initial_value=initial,name='b_fc1')


h_fc1=tf.matmul(tf.zeros([num_batch,1200*1600*3]),w_fc1)+b_fc1

print(x,h_fc1,out_size)

h_trans=transformer(x,h_fc1,out_size)


sess=tf.Session()
sess.run(tf.global_variables_initializer())
y=sess.run(h_trans,feed_dict={x:batch})
plt.imshow(y[0])
plt.show()


...全文
3303 7 打赏 收藏 转发到动态 举报
写回复
用AI写文章
7 条回复
切换为时间正序
请发表友善的回复…
发表回复
无比滴 2017-03-19
  • 打赏
  • 举报
回复
通过楼主的博客也找到解决方法啦,tf.contact()里面的参数换下位置就OK啦。
grid = tf.concat(0, [x_t_flat, y_t_flat, ones])#报错语句
grid = tf.concat( [x_t_flat, y_t_flat, ones],0) #楼主改后的代码
无比滴 2017-03-19
  • 打赏
  • 举报
回复
遇到同样问题了,期待大神看到帖子回来分享解决方法。
on2way 2017-03-14
  • 打赏
  • 举报
回复
怎么解决的 原因呢 这个网友。。
heavenpeien 2017-03-07
  • 打赏
  • 举报
回复
so why ?
qq_26240809 2017-03-01
  • 打赏
  • 举报
回复
我也碰到了这个问题,能说下是怎么解决的吗?谢谢!
  • 打赏
  • 举报
回复
已解决所有问题
  • 打赏
  • 举报
回复
在tensorflow0.12下运行正常,可疑点貌似是tf.expand_dims api修改后,带来的问题

4,450

社区成员

发帖
与我相关
我的任务
社区描述
图形图像/机器视觉
社区管理员
  • 机器视觉
  • 迪菲赫尔曼
加入社区
  • 近7日
  • 近30日
  • 至今
社区公告
暂无公告

试试用AI创作助手写篇文章吧