|
| 1 | +import io |
| 2 | +import logging |
| 3 | +from PIL import Image |
| 4 | +from keras.backend import clear_session |
| 5 | +from keras import models |
| 6 | +from keras.preprocessing.image import img_to_array |
| 7 | +from keras.applications import imagenet_utils |
| 8 | +import numpy as np |
| 9 | +from maxfw.model import MAXModelWrapper |
| 10 | +from config import DEFAULT_MODEL_PATH |
| 11 | + |
| 12 | + |
| 13 | +logger = logging.getLogger() |
| 14 | + |
| 15 | + |
| 16 | +class ModelWrapper(MAXModelWrapper): |
| 17 | + """Model wrapper for Keras models""" |
| 18 | + |
| 19 | + MODEL_NAME = 'resnet50' |
| 20 | + MODEL_INPUT_IMG_SIZE = (224, 224) |
| 21 | + MODEL_LICENSE = 'MIT' |
| 22 | + MODEL_MODE = 'caffe' |
| 23 | + MODEL_META_DATA = { |
| 24 | + 'id': '{}-keras-imagenet'.format(MODEL_NAME.lower()), |
| 25 | + 'name': '{} Keras Model'.format(MODEL_NAME), |
| 26 | + 'description': '{} Keras model trained on ImageNet'.format(MODEL_NAME), |
| 27 | + 'type': 'image_classification', |
| 28 | + 'license': '{}'.format(MODEL_LICENSE), |
| 29 | + 'source': 'https://developer.ibm.com/exchanges/models/all/max-resnet-50/' |
| 30 | + } |
| 31 | + |
| 32 | + def __init__(self, path=DEFAULT_MODEL_PATH): |
| 33 | + logger.info('Loading model from: {}...'.format(path)) |
| 34 | + clear_session() |
| 35 | + self.model = models.load_model(path) |
| 36 | + # this seems to be required to make Keras models play nicely with threads |
| 37 | + self.model._make_predict_function() |
| 38 | + logger.info('Loaded model: {}'.format(self.model.name)) |
| 39 | + |
| 40 | + def read_image(self, image_data): |
| 41 | + return Image.open(io.BytesIO(image_data)) |
| 42 | + |
| 43 | + def _pre_process(self, image): |
| 44 | + image = image.resize(self.MODEL_INPUT_IMG_SIZE) |
| 45 | + image = img_to_array(image) |
| 46 | + image = np.expand_dims(image, axis=0) |
| 47 | + return imagenet_utils.preprocess_input(image, mode=self.MODEL_MODE) |
| 48 | + |
| 49 | + def _post_process(self, preds): |
| 50 | + return imagenet_utils.decode_predictions(preds)[0] |
| 51 | + |
| 52 | + def _predict(self, x): |
| 53 | + return self.model.predict(x) |
0 commit comments