-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_images.py
More file actions
51 lines (38 loc) · 2.36 KB
/
test_images.py
File metadata and controls
51 lines (38 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import cv2
import torch
import configargparse
from model import DeblurGAN
from utils import selectDevice, tensorToImage, imageToTensor
# Image extensions (lower-case) picked up from the "original" folder.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')


def strToBool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``, so boolean flags need an explicit parser.

    Args:
        value: Raw string from the command line or config file. A value that
            is already a ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If ``value`` is not a recognised boolean spelling
            (argparse reports this as a usage error).
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f"not a boolean: {value!r}")


def collectImagePaths(directory, extensions=IMAGE_EXTENSIONS):
    """Return the sorted paths of all image files found recursively under ``directory``.

    Args:
        directory: Folder to walk.
        extensions: Tuple of lower-case file suffixes to accept; matching is
            case-insensitive on the file name.

    Returns:
        A sorted list of file paths (sorted so the processing order is
        deterministic, unlike raw ``os.walk`` order).
    """
    image_paths = []
    for root, _, files in os.walk(directory):
        for file_name in files:
            if file_name.lower().endswith(extensions):
                image_paths.append(os.path.join(root, file_name))
    return sorted(image_paths)


def main():
    """Blur every image under ``<dataset_path>/original`` and deblur it with the generator.

    The Gaussian-blurred input is written to ``<dataset_path>/blurred`` and the
    generator output to ``<dataset_path>/deblurred``, mirroring the
    sub-directory layout of ``original``.
    """
    # Select parameters for testing
    arg = configargparse.ArgumentParser()
    arg.add_argument('--dataset_path', type=str, default='test_images', help='Dataset path.')
    arg.add_argument('--log_dir', type=str, default='deblurGAN_bs1_lr0.0001_numresblocks9_lambdaG100_lambdaD10', help='Name of the folder where the files of checkpoints and precision and loss values are stored.')
    arg.add_argument('--checkpoint', type=str, default='checkpoint_67_best_g.pth', help='Checkpoint to use')
    arg.add_argument('--num_resblocks', type=int, default=9, help='Number of residual blocks for the generator.')
    arg.add_argument('--GPU', type=strToBool, default=True, help='True to run the model on the GPU.')
    arg.add_argument('--kernel_size', type=int, default=5, help='Size of the (odd) Gaussian blur kernel.')
    arg.add_argument('--sigma', type=float, default=5.0, help='Standard deviation of the Gaussian blur.')
    args = arg.parse_args()
    # cv2.GaussianBlur requires a positive odd kernel size.
    if args.kernel_size <= 0 or args.kernel_size % 2 == 0:
        arg.error('--kernel_size must be a positive odd integer.')

    device = selectDevice(args)
    generator = DeblurGAN(n_resblocks=args.num_resblocks)
    checkpoint_path = os.path.join(args.log_dir, "checkpoints", args.checkpoint)
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(state_dict)
    generator.to(device)
    generator.eval()

    original_dir = os.path.join(args.dataset_path, "original")
    blurred_dir = os.path.join(args.dataset_path, "blurred")
    deblurred_dir = os.path.join(args.dataset_path, "deblurred")
    image_paths = collectImagePaths(original_dir)
    if not image_paths:
        print(f"No images found in {original_dir}")
        return

    kernel = (args.kernel_size, args.kernel_size)
    with torch.no_grad():
        for image_path in image_paths:
            image_original = cv2.imread(image_path)
            if image_original is None:
                # cv2.imread returns None instead of raising for unreadable files.
                print(f"Skipping unreadable image: {image_path}")
                continue
            image_blurred = cv2.GaussianBlur(image_original, kernel, sigmaX=args.sigma)
            image_blurred_tensor = imageToTensor(image_blurred)
            # Add a batch dimension of size 1 before feeding the generator.
            image_blurred_tensor = torch.unsqueeze(image_blurred_tensor, dim=0).to(device)
            image_deblurred = tensorToImage(generator(image_blurred_tensor))

            # Mirror the layout of "original" so same-named files in different
            # sub-folders do not overwrite each other.
            relative_path = os.path.relpath(image_path, original_dir)
            for output_dir, image in ((blurred_dir, image_blurred), (deblurred_dir, image_deblurred)):
                output_path = os.path.join(output_dir, relative_path)
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                # cv2.imwrite signals failure by returning False rather than raising.
                if not cv2.imwrite(output_path, image):
                    print(f"Failed to write {output_path}")


if __name__ == "__main__":
    main()