-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
117 lines (93 loc) · 3.67 KB
/
predict.py
File metadata and controls
117 lines (93 loc) · 3.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
import argparse
import torch
from torchvision import transforms, models
import numpy as np
from train import load_model
from PIL import Image
# Define command line arguments. NOTE: they are parsed at import time and
# predict() below reads the module-level `args` to override its parameters.
parser = argparse.ArgumentParser(
    description='Predict the class of an image using a trained checkpoint.')
parser.add_argument('--checkpoint_path', type=str, required=True,
                    help='Name of checkpoint file to use for predicting')
parser.add_argument('--gpu', action='store_true', help='Use GPU if available')
parser.add_argument('--image_path', type=str, required=True,
                    help='Path for image file which will be used for prediction')
parser.add_argument('--label_file', type=str,
                    help='JSON file containing mapping of number to labels')
# A default keeps `topk` from receiving None when --top_k is omitted.
parser.add_argument('--top_k', type=int, default=5,
                    help='Return top k predictions (default: 5)')
args = parser.parse_args()
# Report the available device: a CUDA device on a CUDA machine, else "cpu".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# load the checkpoint
def load_checkpoint(checkpoint_path, device=None):
    """Rebuild a VGG classifier model from a saved checkpoint.

    The checkpoint must be a dict with the keys 'arch' ("vgg16" or "vgg19"),
    'classifier', 'state_dict' and 'class_to_idx' (presumably written by the
    training script -- confirm against train.py).

    Args:
        checkpoint_path: Path or file-like object passed to ``torch.load``.
        device: Device the model is moved to. Defaults to ``cuda:0`` when
            CUDA is available, otherwise ``cpu``.

    Returns:
        The model with frozen backbone parameters, the checkpoint's classifier
        and weights, a ``class_to_idx`` attribute, in eval mode.

    Raises:
        ValueError: If the checkpoint names an unsupported architecture.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles arbitrary objects (including the classifier
    # module); only load checkpoints from trusted sources.
    ckp = torch.load(checkpoint_path, map_location="cpu")
    arch = ckp['arch']
    if arch == "vgg16":
        model = models.vgg16(pretrained=True)
    elif arch == "vgg19":
        model = models.vgg19(pretrained=True)
    else:
        raise ValueError(f"Unknown architecture: {arch!r}")
    for param in model.parameters():
        param.requires_grad = False
    model.classifier = ckp["classifier"]
    # Call (not assign) load_state_dict so the trained weights are applied.
    model.load_state_dict(ckp["state_dict"])
    model.class_to_idx = ckp['class_to_idx']
    model.to(device)
    # Inference only: disable dropout.
    model.eval()
    return model
# Take image file as an input and predict the class for it
def predict(image_path, top_k=5, checkpoint_path=None, gpu=False,
            label_file=None):
    """Predict the ``top_k`` most likely classes for an image.

    Command-line values in the module-level ``args`` take precedence over the
    corresponding function arguments whenever they are set.

    Args:
        image_path: Path of the image to classify.
        top_k: Number of predictions to return.
        checkpoint_path: Checkpoint file passed to ``load_checkpoint``.
        gpu: Run on ``cuda:0`` when True and CUDA is available, else on CPU.
        label_file: Optional JSON file mapping class labels to display names.
            When not given, the class labels are returned as the names.

    Returns:
        ``(top_probs, top_labels, top_flowers)``: lists of probabilities,
        class labels (keys of ``model.class_to_idx``) and display names.

    Raises:
        ValueError: If no checkpoint path is available.
    """
    # Use command line arguments if specified
    if args.checkpoint_path:
        checkpoint_path = args.checkpoint_path
    if args.gpu:
        gpu = args.gpu
    if args.image_path:
        image_path = args.image_path
    if args.label_file:
        label_file = args.label_file
    if args.top_k:
        top_k = args.top_k
    if not checkpoint_path:
        raise ValueError("A checkpoint path is required")

    compute_device = torch.device(
        "cuda:0" if gpu and torch.cuda.is_available() else "cpu")
    model = load_checkpoint(checkpoint_path)
    # Keep model and input on the same device; honour the gpu flag.
    model.to(compute_device)
    # Disable dropout so predictions are deterministic.
    model.eval()

    # Numpy (C, H, W) -> float tensor with a leading batch dimension of 1.
    img = process_image(image_path)
    model_input = (torch.from_numpy(img).type(torch.FloatTensor)
                   .unsqueeze(0).to(compute_device))

    with torch.no_grad():
        # NOTE(review): exp() assumes the classifier ends in LogSoftmax --
        # confirm against the training code.
        probs = torch.exp(model(model_input))
        top_probs, top_labs = probs.topk(top_k)
    top_probs = top_probs.cpu().numpy().tolist()[0]
    top_labs = top_labs.cpu().numpy().tolist()[0]

    # Convert output indices back to class labels.
    idx_to_class = {idx: cls for cls, idx in model.class_to_idx.items()}
    top_labels = [idx_to_class[lab] for lab in top_labs]
    if label_file:
        with open(label_file, 'r', encoding='utf-8') as f:
            cat_to_name = json.load(f)
        top_flowers = [cat_to_name[label] for label in top_labels]
    else:
        top_flowers = list(top_labels)
    return top_probs, top_labels, top_flowers
def process_image(image_path):
    """Load an image file and prepare it as model input.

    The image is converted to RGB, resized so its shorter side is 256 pixels,
    center-cropped to 224x224, scaled to [0, 1] by ``ToTensor`` and normalized
    with the per-channel mean/std used by torchvision's ImageNet-pretrained
    models.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        A float64 NumPy array of shape (3, 224, 224), channel-first.
    """
    img_loader = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()])
    # The context manager closes the file handle. convert('RGB') gives
    # grayscale/RGBA/palette images the 3 channels the statistics expect.
    with Image.open(image_path) as pil_image:
        tensor = img_loader(pil_image.convert('RGB')).float()
    np_image = tensor.numpy()
    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    # Broadcast the per-channel statistics over (C, H, W). The float32 pixels
    # are promoted to float64, as before.
    return (np_image - mean) / std
if __name__ == "__main__":
    # Print (probabilities, class labels, names) for the CLI-provided image.
    print(predict(args.image_path))