-
Notifications
You must be signed in to change notification settings - Fork 113
Expand file tree
/
Copy pathGreedyHash.py
More file actions
135 lines (103 loc) · 4.27 KB
/
GreedyHash.py
File metadata and controls
135 lines (103 loc) · 4.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from utils.tools import *
from network import *
import os
import torch
import torch.optim as optim
import time
import numpy as np
# Select how torch.multiprocessing shares tensors between processes (e.g.
# DataLoader workers). NOTE(review): 'file_system' is presumably chosen to avoid
# "too many open files" errors seen with the default strategy — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
# GreedyHash (NIPS 2018): train a CNN whose output is binarised with sign(),
# using an identity (straight-through) gradient for the sign step.
# paper [Greedy Hash: Towards Fast Optimization for Accurate Hash Coding in CNN](https://papers.nips.cc/paper/7360-greedy-hash-towards-fast-optimization-for-accurate-hash-coding-in-cnn.pdf)
# code [GreedyHash](https://github.com/ssppp/GreedyHash)
def get_config():
    """Return the GreedyHash experiment configuration as a dict.

    The base settings are passed through ``config_dataset`` (from utils.tools).
    It presumably adds dataset-specific entries such as ``"n_class"``, which
    ``GreedyHashLoss`` reads — confirm there.

    For the "imagenet" dataset, the penalty weight ``alpha`` becomes 1. The
    learning-rate step interval ``epoch_lr_decrease`` also becomes 80 epochs.
    """
    # Swap-in alternatives (replace the active value below with one of these):
    #   optimizer -> {"type": optim.RMSprop, "epoch_lr_decrease": 30,
    #                 "optim_params": {"lr": 5e-5, "weight_decay": 5e-4}}
    #   net       -> ResNet
    #   dataset   -> "cifar10", "cifar10-2", "coco", "mirflickr", "voc2012",
    #                "imagenet", "nuswide_21", "nuswide_21_m", "nuswide_81_m"
    #   device    -> torch.device("cpu")
    optimizer_spec = dict(
        type=optim.SGD,
        epoch_lr_decrease=30,
        optim_params=dict(lr=0.001, weight_decay=5e-4, momentum=0.9),
    )
    # Keyword order == key insertion order, which is what print(config) shows.
    base = dict(
        alpha=0.1,
        optimizer=optimizer_spec,
        info="[GreedyHash]",
        resize_size=256,
        crop_size=224,
        batch_size=64,
        net=AlexNet,
        dataset="cifar10-1",
        epoch=200,
        test_map=3,
        device=torch.device("cuda:1"),
        bit_list=[48],
    )
    config = config_dataset(base)
    if config["dataset"] != "imagenet":
        return config
    # ImageNet: larger penalty weight and a longer interval between lr drops.
    config["alpha"] = 1
    config["optimizer"]["epoch_lr_decrease"] = 80
    return config
class GreedyHashLoss(torch.nn.Module):
    """GreedyHash objective: cross-entropy on binary codes plus a cubic penalty.

    The network output ``u`` is binarised with :class:`Hash`: sign in the
    forward pass, identity gradient in the backward pass. The binary codes are
    classified by a bias-free linear layer and scored with cross-entropy. An
    extra term ``alpha * mean(| |u| - 1 |^3)`` pulls ``u`` towards +/-1.

    Args:
        config: experiment config. ``"n_class"`` and ``"device"`` are read
            here; ``"alpha"`` is read in :meth:`forward`.
        bit: hash code length, i.e. the last dimension of ``u``.
    """

    def __init__(self, config, bit):
        super().__init__()
        # Classifier on top of the binary codes. Because it is a submodule, its
        # weight is exposed through self.parameters() and can be optimised.
        self.fc = torch.nn.Linear(bit, config["n_class"], bias=False).to(config["device"])
        # CrossEntropyLoss() holds no tensors, so it needs no device move.
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, u, onehot_y, ind, config):
        """Return the scalar GreedyHash loss for one batch.

        Args:
            u: continuous network outputs of shape (batch, bit).
            onehot_y: one-hot labels of shape (batch, n_class).
            ind: sample indices; unused here (presumably kept so every
                criterion in the project shares one call signature).
            config: ``config["alpha"]`` weights the penalty term.
        """
        b = GreedyHashLoss.Hash.apply(u)
        # CrossEntropyLoss expects class indices. argmax returns the first
        # maximal index, so for multi-label rows only the first positive class
        # is used. NOTE(review): consider this before using multi-label datasets.
        y = onehot_y.argmax(dim=1)
        loss1 = self.criterion(self.fc(b), y)
        # The penalty acts on the continuous u; loss1 only sees the binary b.
        loss2 = config["alpha"] * (u.abs() - 1).pow(3).abs().mean()
        return loss1 + loss2

    class Hash(torch.autograd.Function):
        """Sign binarisation with a straight-through (identity) gradient."""

        @staticmethod
        def forward(ctx, x):
            # Note: sign(0) == 0, so exact zeros are not mapped to +/-1.
            return x.sign()

        @staticmethod
        def backward(ctx, grad_output):
            # Pass the gradient through unchanged, as if forward were identity.
            return grad_output
def train_val(config, bit):
    """Train a ``bit``-bit GreedyHash model and periodically call ``validate``.

    Side effect: ``config["num_train"]`` is set to the training-set size
    returned by ``get_data``.

    Learning-rate schedule: the base lr is multiplied by 0.1 every
    ``config["optimizer"]["epoch_lr_decrease"]`` epochs.

    Validation: every ``config["test_map"]`` epochs, ``validate`` (from
    utils.tools) is called, and its return value is carried forward as the
    best mAP so far.

    Args:
        config: experiment config, as built by ``get_config``.
        bit: hash code length for the network and the loss.
    """
    device = config["device"]
    (train_loader, test_loader, dataset_loader,
     num_train, _num_test, num_dataset) = get_data(config)
    config["num_train"] = num_train

    net = config["net"](bit).to(device)
    criterion = GreedyHashLoss(config, bit)

    opt_spec = config["optimizer"]
    # The loss owns a trainable classifier, so its parameters are optimised too.
    params = [*net.parameters(), *criterion.parameters()]
    optimizer = opt_spec["type"](params, **opt_spec["optim_params"])

    base_lr = opt_spec["optim_params"]["lr"]
    decay_every = opt_spec["epoch_lr_decrease"]
    num_epochs = config["epoch"]
    best_map = 0

    for epoch in range(num_epochs):
        lr = base_lr * (0.1 ** (epoch // decay_every))
        for group in optimizer.param_groups:
            group['lr'] = lr

        now = time.strftime('%H:%M:%S', time.localtime(time.time()))
        print("%s[%2d/%2d][%s] bit:%d, lr:%.9f, dataset:%s, training...." % (
            config["info"], epoch + 1, num_epochs, now, bit, lr, config["dataset"]), end="")

        net.train()
        loss_sum = 0
        for image, label, ind in train_loader:
            image, label = image.to(device), label.to(device)
            optimizer.zero_grad()
            loss = criterion(net(image), label.float(), ind, config)
            loss_sum += loss.item()
            loss.backward()
            optimizer.step()
        mean_loss = loss_sum / len(train_loader)
        # The backspaces overwrite the tail of "training...." on the same line.
        print("\b\b\b\b\b\b\b loss:%.3f" % (mean_loss))

        if (epoch + 1) % config["test_map"] == 0:
            best_map = validate(config, best_map, test_loader, dataset_loader,
                                net, bit, epoch, num_dataset)
if __name__ == "__main__":
    # Build the configuration once, then train one model per code length.
    config = get_config()
    print(config)
    for code_bits in config["bit_list"]:
        train_val(config, code_bits)