Skip to content

Commit 136aed2

Browse files
authored
Add MAX Framework package (#17)
1 parent e838bb0 commit 136aed2

File tree

11 files changed

+132
-168
lines changed

11 files changed

+132
-168
lines changed

.travis.yml

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
language: python
22
python:
3-
- 3.6
4-
3+
- 3.6
54
services:
6-
- docker
7-
5+
- docker
86
install:
9-
- docker build -t max-resnet-50 .
10-
- docker run -it -d -p 5000:5000 max-resnet-50
11-
- sleep 30 # container needs a few seconds before it will accept requests
12-
7+
- docker build -t max-resnet-50 .
8+
- docker run -it -d -p 5000:5000 max-resnet-50
9+
- pip install pytest requests flake8
1310
before_script:
14-
- pip install pytest requests
15-
11+
- flake8 . --max-line-length=127
12+
- sleep 30 # container needs a few seconds before it will accept requests
1613
script:
17-
- pytest tests/test.py
14+
- pytest tests/test.py

Dockerfile

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM codait/max-base:v1.0.0
1+
FROM codait/max-base:v1.1.0
22

33
ARG model_bucket=http://max-assets.s3-api.us-geo.objectstorage.softlayer.net/keras
44
ARG model_file=resnet50.h5
@@ -10,7 +10,9 @@ COPY requirements.txt /workspace
1010
RUN pip install -r requirements.txt
1111

1212
COPY . /workspace
13-
RUN md5sum -c md5sums.txt # check file integrity
13+
14+
# check file integrity
15+
RUN md5sum -c md5sums.txt
1416

1517
EXPOSE 5000
1618

api/__init__.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,2 @@
1-
from flask_restplus import Api
2-
3-
from config import API_TITLE, API_VERSION, API_DESC
4-
from .model import api as model_ns
5-
6-
api = Api(
7-
title=API_TITLE,
8-
version=API_VERSION,
9-
description=API_DESC)
10-
11-
api.namespaces.clear()
12-
api.add_namespace(model_ns)
1+
from .metadata import ModelMetadataAPI # noqa
2+
from .predict import ModelPredictAPI # noqa

api/metadata.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from core.model import ModelWrapper
2+
from maxfw.core import MAX_API, MetadataAPI, METADATA_SCHEMA
3+
4+
5+
class ModelMetadataAPI(MetadataAPI):
6+
7+
@MAX_API.marshal_with(METADATA_SCHEMA)
8+
def get(self):
9+
"""Return the metadata associated with the model"""
10+
return ModelWrapper.MODEL_META_DATA

api/model.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

api/predict.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from flask_restplus import fields
2+
from werkzeug.datastructures import FileStorage
3+
from maxfw.core import MAX_API, PredictAPI
4+
from core.model import ModelWrapper
5+
6+
input_parser = MAX_API.parser()
7+
input_parser.add_argument('image', type=FileStorage, location='files', required=True, help="An image file")
8+
9+
label_prediction = MAX_API.model('LabelPrediction', {
10+
'label_id': fields.String(required=False, description='Class label identifier'),
11+
'label': fields.String(required=True, description='Class label'),
12+
'probability': fields.Float(required=True)
13+
})
14+
15+
predict_response = MAX_API.model('ModelPredictResponse', {
16+
'status': fields.String(required=True, description='Response status message'),
17+
'predictions': fields.List(fields.Nested(label_prediction), description='Predicted labels and probabilities')
18+
})
19+
20+
21+
class ModelPredictAPI(PredictAPI):
22+
23+
model_wrapper = ModelWrapper()
24+
25+
@MAX_API.doc('predict')
26+
@MAX_API.expect(input_parser)
27+
@MAX_API.marshal_with(predict_response)
28+
def post(self):
29+
"""Make a prediction given input data"""
30+
result = {'status': 'error'}
31+
32+
args = input_parser.parse_args()
33+
image_data = args['image'].read()
34+
image = self.model_wrapper.read_image(image_data)
35+
preds = self.model_wrapper.predict(image)
36+
37+
label_preds = [{'label_id': p[0], 'label': p[1], 'probability': p[2]} for p in [x for x in preds]]
38+
result['predictions'] = label_preds
39+
result['status'] = 'ok'
40+
41+
return result

app.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
1-
import os
2-
from flask import Flask
3-
from api import api
1+
from maxfw.core import MAXApp
2+
from api import ModelMetadataAPI, ModelPredictAPI
3+
from config import API_TITLE, API_DESC, API_VERSION
44

5-
app = Flask(__name__)
6-
# load default config
7-
app.config.from_object('config')
8-
# load override config if exists
9-
if 'APP_CONFIG' in os.environ:
10-
app.config.from_envvar('APP_CONFIG')
11-
api.init_app(app)
12-
13-
if __name__ == '__main__':
14-
app.run(host='0.0.0.0')
5+
max_app = MAXApp(API_TITLE, API_DESC, API_VERSION)
6+
max_app.add_api(ModelMetadataAPI, '/metadata')
7+
max_app.add_api(ModelPredictAPI, '/predict')
8+
max_app.run()

config.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,15 @@
11
# Application settings
22

3-
# Flask settings
3+
# Flask settings
44
DEBUG = False
55

66
# Flask-restplus settings
77
RESTPLUS_MASK_SWAGGER = False
88
SWAGGER_UI_DOC_EXPANSION = 'none'
99

1010
# API metadata
11-
API_TITLE = 'Model Asset Exchange Server'
12-
API_DESC = 'An API for serving models'
11+
API_TITLE = 'MAX ResNet 50'
12+
API_DESC = 'Identify objects in images using a first-generation deep residual network.'
1313
API_VERSION = '0.1'
1414

15-
# Model settings
16-
models = {
17-
'resnet50': {'size': (224, 224), 'license': 'MIT'}
18-
}
19-
20-
# default model
21-
MODEL_NAME = 'resnet50'
22-
DEFAULT_MODEL_PATH = 'assets/{}.h5'.format(MODEL_NAME)
23-
# for image models, may not be required
24-
MODEL_INPUT_IMG_SIZE = models[MODEL_NAME]['size']
25-
MODEL_LICENSE = models[MODEL_NAME]['license']
26-
27-
MODEL_META_DATA = {
28-
'id': '{}-keras-imagenet'.format(MODEL_NAME.lower()),
29-
'name': '{} Keras Model'.format(MODEL_NAME),
30-
'description': '{} Keras model trained on ImageNet'.format(MODEL_NAME),
31-
'type': 'image_classification',
32-
'license': '{}'.format(MODEL_LICENSE)
33-
}
15+
DEFAULT_MODEL_PATH = 'assets/resnet50.h5'

core/backend.py

Lines changed: 0 additions & 44 deletions
This file was deleted.

core/model.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import io
2+
import logging
3+
from PIL import Image
4+
from keras.backend import clear_session
5+
from keras import models
6+
from keras.preprocessing.image import img_to_array
7+
from keras.applications import imagenet_utils
8+
import numpy as np
9+
from maxfw.model import MAXModelWrapper
10+
from config import DEFAULT_MODEL_PATH
11+
12+
13+
logger = logging.getLogger()
14+
15+
16+
class ModelWrapper(MAXModelWrapper):
17+
"""Model wrapper for Keras models"""
18+
19+
MODEL_NAME = 'resnet50'
20+
MODEL_INPUT_IMG_SIZE = (224, 224)
21+
MODEL_LICENSE = 'MIT'
22+
MODEL_MODE = 'caffe'
23+
MODEL_META_DATA = {
24+
'id': '{}-keras-imagenet'.format(MODEL_NAME.lower()),
25+
'name': '{} Keras Model'.format(MODEL_NAME),
26+
'description': '{} Keras model trained on ImageNet'.format(MODEL_NAME),
27+
'type': 'image_classification',
28+
'license': '{}'.format(MODEL_LICENSE),
29+
'source': 'https://developer.ibm.com/exchanges/models/all/max-resnet-50/'
30+
}
31+
32+
def __init__(self, path=DEFAULT_MODEL_PATH):
33+
logger.info('Loading model from: {}...'.format(path))
34+
clear_session()
35+
self.model = models.load_model(path)
36+
# this seems to be required to make Keras models play nicely with threads
37+
self.model._make_predict_function()
38+
logger.info('Loaded model: {}'.format(self.model.name))
39+
40+
def read_image(self, image_data):
41+
return Image.open(io.BytesIO(image_data))
42+
43+
def _pre_process(self, image):
44+
image = image.resize(self.MODEL_INPUT_IMG_SIZE)
45+
image = img_to_array(image)
46+
image = np.expand_dims(image, axis=0)
47+
return imagenet_utils.preprocess_input(image, mode=self.MODEL_MODE)
48+
49+
def _post_process(self, preds):
50+
return imagenet_utils.decode_predictions(preds)[0]
51+
52+
def _predict(self, x):
53+
return self.model.predict(x)

0 commit comments

Comments
 (0)