• 全部
  • 互动交流
  • 文章分享

tensorflow 代码问题求解

m0_57756523 2021-05-05 04:52:55
import tensorflow as tf import numpy as np import pandas as pd from sklearn.datasets import load_iris#导入数据集合 x_data=load_iris().data#导入特征数据集合 y_data=load_iris().target#导入标签数据集合 x_train=x_data[:-30]#特征数据数据训练集 x_test=x_data[-30:]#特征数据数据测试集 y_train=y_data[:-30]#标签数据训练集 y_test=y_data[-30:]#标签数据测试集 np.random.seed(1)#随机种子,保持看乱序一致 np.random.shuffle(x_train)#乱序 np.random.seed(1) np.random.shuffle(y_train)#乱序 tf.random.set_seed(111) train_db=tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(32)#数据切割把数组切位单独张量后,对数据打包 # for i in train_db: #     print(i) w=tf.Variable(tf.random.truncated_normal([4,3],seed=1))#初始化参数,并设定为可训练变量 b=tf.Variable(tf.random.truncated_normal([3],seed=1))#初始化参数,并设定为可训练变量 epoch=100 loss_all=0 train_loss_result=[] test_acc=[] lr=0.2 # for i in epoch for step,(x,y) in enumerate(train_db):     with tf.GradientTape() as tape:        x=tf.cast(x,dtype="float64") #        print(x)        w=tf.cast(w,dtype="float64") #        print(w)        b=tf.cast(b,dtype="float64")        y_train=tf.one_hot(tf.cast(y,dtype="int32"),depth=3)        y_test=tf.one_hot(tf.cast(y_test,dtype="int32"),depth=3)        y_pred=tf.matmul(x,w)+b        loss=tf.nn.softmax_cross_entropy_with_logits(y_train,y_pred)        loss_mean=tf.reduce_mean(loss)  #        print(loss_mean)     grad=tape.gradient(loss_mean,[w,b])      #     w=w-lr*grad[0] #     b=b-lr*grad[1]     print(type(grad[0]))     if step==1:         break # # loss_all+=loss_mean 运行结果: <class 'tensorflow.python.framework.ops.EagerTensor'> <class 'NoneType'> 上面代码,在进行循环是,第一次step循环,在计算w,b 梯度是,有值,进入第二次循环是,报出空值,感谢大神求解问题在哪?
...全文
20 1 收藏 2
写回复
2 条回复
切换为时间正序
请发表友善的回复…
发表回复
今天检查了一下,知道原因了,因为在更改数据为float64时,把b,w 的viriable 类型改为了tensor类型,因此导致数据无法在求梯度的时候,无法更新参数
回复
哪位大神给予解答
回复
相关推荐
发帖
脚本语言
创建于2007-08-27

3.7w+

社区成员

JavaScript,VBScript,AngleScript,ActionScript,Shell,Perl,Ruby,Lua,Tcl,Scala,MaxScript 等脚本语言交流。
申请成为版主
帖子事件
创建了帖子
2021-05-05 04:52
社区公告

CSDN 脚本语言社区接受专栏投稿(专栏会在顶部创建专属你的栏目),投稿需满足以下要求:

  • 脚本语言技术相关;
  • 文章持续更新,保持活跃;
  • 内容清晰明了,干货为主;
  • 文章排版有序,有条有理。

本社区开通招聘专栏,发布招聘信息请联系版主,发布者需要保证招聘信息真实有效,CSDN 平台和版主不对招聘内容负责!

联系方式:私聊版主、发送邮件、QQ联系等均可: