CLASS torch.nn.Softmax(dim=None)

 

Returns: a Tensor of the same shape as the input, with values in the range [0, 1] that sum to 1 along dim
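For reference, this is the standard softmax definition, applied to each slice of the input along dim. Every exponential is positive, so each output falls in (0, 1). The numerators over a slice add up to exactly the denominator, so the outputs in that slice sum to 1.

$$\mathrm{Softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$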

import torch
from torch import nn

m = nn.Softmax(dim=1)  # normalize along dim 1, so each row sums to 1
input = torch.randn(2,3)
output = m(input)
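The sketch below checks the two properties stated above on the same kind of (2, 3) input. It also compares the module's output with the formula computed by hand. The fixed seed and the variable names (x, row_sums, manual) are only there to illustrate the check and are not part of the API.

```python
import torch
from torch import nn

torch.manual_seed(0)  # fixed seed so the run is reproducible

m = nn.Softmax(dim=1)  # normalize within each row (across the 3 columns)
x = torch.randn(2, 3)
output = m(x)

# Every entry lies in [0, 1] ...
in_range = bool(torch.all((output >= 0) & (output <= 1)))
# ... and each row (the slice along dim=1) sums to 1.
row_sums = output.sum(dim=1)
print(in_range, row_sums)

# The module matches exp(x_i) / sum_j exp(x_j) computed manually per row.
manual = torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True)
print(torch.allclose(output, manual))
```

It is safer to always pass dim explicitly. With dim=None, PyTorch picks a dimension implicitly and emits a deprecation warning.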

 

Use .argmax(1) on the output to get the index of the largest probability in each row (e.g. the predicted class).
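A minimal sketch of that step, assuming made-up logits for 2 samples and 3 classes. The values are illustrative only. Softmax preserves the ordering within each row, so argmax on the probabilities gives the same indices as argmax on the raw logits.

```python
import torch
from torch import nn

# Hypothetical logits: 2 samples, 3 classes (values chosen for illustration).
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])

probs = nn.Softmax(dim=1)(logits)  # per-row probabilities
pred = probs.argmax(1)             # index of the largest probability in each row
print(pred)  # tensor([0, 2])

# Softmax is monotonic within a row, so argmax on the logits agrees.
print(torch.equal(pred, logits.argmax(1)))  # True
```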

 

ref. https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html?highlight=nn+softmax#torch.nn.Softmax

 
