Pytorch — sensitivity 计算

Pytorch — sensitivity 计算Pytorchsensi 敏感度计算 1 sensitivity 是一种局部性的指标 表达正确识别正类个数 正类总个数 Sensitivity TPR TP TP FN 2 specificity 同理 不同之处为 正确识别负类个数 负类总个数 Specificity TNR TN TN FP 1 代码如下 defsensitivi output target

Pytorch — sensitivity 敏感度计算

1. sensitivity是一种局部性的指标,表达 正确识别正类个数 / 正类总个数 - Sensitivity/TPR = TP / (TP + FN) 2. specificity同理,不同之处为,正确识别负类个数 / 负类总个数 - Specificity/TNR = TN / (TN + FP) 
  • 1、代码如下:
def sensitivity(output, target, sensi): ''' 这里类别数为3 传入参数: sensi = np.array([-1] * 3) (首次,后面变为sensitivity的值) output --> tensor(80,3) 从outputs, _ = net(inputs)中获取 target --> tensor(80) 返回值: sensitivity --> np.array ''' # 取得到分类分数最大的值,返回第一维度是value,第二维度是index _, pred = output.max(1) # 将 pred 展开成 one-hot编码形式 pre_mask = torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.) # 将 target 也展开成 one-hot编码形式 tar_mask = torch.zeros(output.size()).scatter_(1, target.data.cpu().view(-1, 1), 1.) # 计算 acc 的one-hot编码形式 acc_mask = pre_mask * tar_mask # 计算 sensitivity sensitivity = acc_mask.sum(0) / tar_mask.sum(0) # 转换成numpy() sensitivity = sensitivity.numpy() if sensi[0] != -1 : #不是第一次计算sensivity, 计算求平均值 sensitivity = (sensitivity + sensi) / 2 return sensitivity 
  • 2、详细的具体解析
Batch_size = 80 print(output) tensor([[-0.0082, -0.1216, 0.0823], [ 0.0433, -0.1183, -0.0050], ..................................... , [ 0.0682, -0.1924, 0.0039]],device='cuda:0') :softmax计算得到值,3分类,故有3个值 print(target) tensor([1, 2, ... ,1], device='cuda:0') :目标标签值 print(output.max(1)) torch.return_types.max( values=tensor([ 0.0823, 0.0433, ... 0.0682], device='cuda:0'), indices=tensor([2, 0, ... 0], device='cuda:0'))max(1) --> values 对应ouput每一行中最大值,indices 下标 _, pred = output.max(1) print(pred) tensor([2, 0, ... 0], device='cuda:0') :取得预测的下标值 print(ouput.size()) torch.Size([80, 3]) print(target.size()) torch.Size([80]) :类似numpy的shape print(pred.eq(target)) tensor([0, 0, ... 0], device='cuda:0') :值相同为1,不同为0 print(pred.eq(target).sum()) tensor(21, device='cuda:0') :将所有的值相加 print(pred.eq(target).sum().item()) 21 :取出tensor里面的值 print(pred.cpu()) tensor([2, 0, ... 0]) :少了" device='cuda:0' " 应该是转移到了cpu中 print(pred.cpu().view(-1,1)) tensor([[2], [0], ... [0]]) :由180列,变成801列,view(-1,1)表示张量维度,-1表缺省,但可推断值 print(pred_mask) tensor([[0., 0., 1.], [1., 0., 0.], ....... [1., 0., 0.]]) :转换成one-hot编码形式 print(pred_mask.sum(0)) tensor([32., 11., 37.])sum(0)0表示以行为基本单位,列项相加 
  • 3、解析代码
 print(output) print(target) _, pred = output.max(1) print(output.max(1)) print(pred) print(output.size()) print(target.size()) print(pred.eq(target)) print(pred.eq(target).sum()) print(pred.eq(target).sum().item()) print(pred.cpu()) print(pred.cpu().view(-1, 1)) print(torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.)) pred_mask = torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.) print(pred_mask.sum(0)) tar_mask = torch.zeros(output.size()).scatter_(1, target.data.cpu().view(-1, 1), 1.) print(tar_mask) acc_mask = pred_mask * tar_mask print(acc_mask) 
  • 4、scatter_()函数具体解析
    https://www.cnblogs.com/daremosiranaihana/p/12538512.html
    注:scatter() 与 scatter_() 的区别在于 后者直接修改源数据




版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/214530.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月18日 下午4:02
下一篇 2026年3月18日 下午4:02


相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号