Waihinchan

关于 pytorch TVloss 代码实现的一些疑惑

  •  
  •   Waihinchan · Aug 8, 2020 · 3435 views
    This topic created in 2121 days ago, the information mentioned may be changed or developed.

    网上看到普遍的答案是这个

    class TVLoss(nn.Module):
        def __init__(self,TVLoss_weight=1):
            super(TVLoss,self).__init__()
            self.TVLoss_weight = TVLoss_weight
    
        def forward(self,x):
            batch_size = x.size()[0]
            h_x = x.size()[2]
            w_x = x.size()[3]
            count_h = self._tensor_size(x[:,:,1:,:])
            count_w = self._tensor_size(x[:,:,:,1:])
            h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
            w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
            return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
    
        def _tensor_size(self,t):
            return t.size()[1]*t.size()[2]*t.size()[3]
    

    这里给出的说的是β=2,且不支持变更. 所以按照这里给出的公式 https://blog.csdn.net/yexiaogu1104/article/details/88395475 β/2, 当β=2 那就是 1 也就是不进行任何操作. 所以最后 return 这里为什么会返回一个 self.TVLoss_weight2, 为啥要2 呢..

    No Comments Yet
    About   ·   Help   ·   Advertise   ·   Blog   ·   API   ·   FAQ   ·   Solana   ·   2752 Online   Highest 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 27ms · UTC 06:26 · PVG 14:26 · LAX 23:26 · JFK 02:26
    ♥ Do have faith in what you're doing.