F.cross_entropy 팁
in Coding on Pytorch
loss = (F.cross_entropy(pred, lbl, reduction='none') * mask.squeeze(1)).sum() / mask.sum()
각 변수
- pred: 예측 값 (B, N, L)
- lbl: 타겟 라벨 (B, L)
- mask: 마스크 (B, 1, L)
- B: Batch 크기, N: 클래스 수, L: 길이
torch.Size([16, 5, 144]) torch.Size([16, 144]) torch.Size([16, 1, 144])