pytorch中y.data.norm()的含义

本文最后更新于:2023年4月7日 下午

import torch
x = torch.randn(3, requires_grad=True)
y = x*2
print(y.data.norm())
print(torch.sqrt(torch.sum(torch.pow(y,2))))  #其实就是对y张量L2范数,先对y中每一项取平方,之后累加,最后取根号
i=0
while y.data.norm()<1000:
  y = y*2
  i+=1
print(y)
print(i)

结果:

tensor(3.7025)
tensor(3.7025, grad_fn=<SqrtBackward>)
tensor([ 1066.4563, -1511.3652,  -414.6933], grad_fn=<MulBackward0>)
9

 


打赏支持
“如果你觉得我的文章不错,不妨鼓励我继续写作。”

pytorch中y.data.norm()的含义
https://dreamoneyou.github.io/2020/pytorch中y.data.norm()的含义/
作者
九叶草
发布于
2020年6月25日
更新于
2023年4月7日
许可协议