-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
36 lines (30 loc) · 1.26 KB
/
train.py
File metadata and controls
36 lines (30 loc) · 1.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
def train(model, device, train_loader, optimizer, scheduler, epoch, l1_factor):
    """Run one training epoch and return (mean batch loss, accuracy %).

    Args:
        model: network whose forward output is log-probabilities
            (fed to ``F.nll_loss``, argmax'd for accuracy).
        device: torch device the batches are moved to.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch (OneCycleLR-style).
        epoch: current epoch number (unused here; kept for interface
            compatibility with callers).
        l1_factor: weight of the L1 penalty on model parameters; 0 disables it.

    Returns:
        (train_loss, train_acc): mean per-batch loss over the epoch and
        accuracy in percent over the whole dataset.
    """
    model.train()
    epoch_loss = 0.0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        if l1_factor > 0:
            # L1 regularization: mean absolute value per parameter tensor,
            # summed over tensors — identical to the original
            # nn.L1Loss(reduction='mean') against a zero target, but without
            # the deprecated size_average/reduce args and without
            # torch.rand_like(...)*0, which needlessly consumed RNG state.
            reg_loss = sum(param.abs().mean() for param in model.parameters())
            loss = loss + l1_factor * reg_loss
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-batch scheduler step (e.g. OneCycleLR)
        pred = output.argmax(dim=1, keepdim=True)  # index of max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
    train_loss = epoch_loss / len(train_loader)
    train_acc = 100. * correct / len(train_loader.dataset)
    # Report once per epoch. (The original printed inside the batch loop,
    # labelling the *last batch's* loss as the "Average loss".)
    print(f'Train set: Average loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}')
    return train_loss, train_acc