Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,18 @@ on:
jobs:
build:

runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.6, 3.7, 3.8]
include:
- python-version: 3.8
push-package: true
- os: windows-latest
python-version: 3.8
- os: macos-latest
python-version: 3.8

steps:
- uses: actions/checkout@v2
Expand All @@ -24,25 +32,25 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
make dev-venv
make dev-venv SYSTEM_PYTHON=python
- name: Lint
run: |
make dev-lint
- name: Test with pytest
run: |
make dev-pytest
- name: Build dist
if: matrix.python-version == '3.8'
if: matrix.push-package == true
run: |
make dev-remove-dist dev-build-dist dev-list-dist-contents dev-test-install-dist
- name: Publish distribution to Test PyPI
if: matrix.python-version == '3.8'
if: matrix.push-package == true
uses: pypa/gh-action-pypi-publish@master
with:
password: ${{ secrets.test_pypi_password }}
repository_url: https://test.pypi.org/legacy/
- name: Publish distribution to PyPI
if: matrix.python-version == '3.8' && startsWith(github.ref, 'refs/tags')
if: matrix.push-package == true && startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@master
with:
password: ${{ secrets.pypi_password }}
15 changes: 12 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
VENV = venv
PIP = $(VENV)/bin/pip
PYTHON = $(VENV)/bin/python

ifeq ($(OS),Windows_NT)
VENV_BIN = $(VENV)/Scripts
else
VENV_BIN = $(VENV)/bin
endif

PYTHON = $(VENV_BIN)/python
PIP = $(VENV_BIN)/python -m pip

SYSTEM_PYTHON = python3

VENV_TEMP = venv_temp

Expand Down Expand Up @@ -30,7 +39,7 @@ venv-clean:


venv-create:
python3 -m venv $(VENV)
$(SYSTEM_PYTHON) -m venv $(VENV)


dev-install:
Expand Down
18 changes: 18 additions & 0 deletions tests/cli_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import logging
from pathlib import Path

from tf_bodypix.download import BodyPixModelPaths
from tf_bodypix.cli import main


LOGGER = logging.getLogger(__name__)


EXAMPLE_IMAGE_URL = (
r'https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/'
r'Person_Of_Interest_-_Panel_%289353656298%29.jpg/'
Expand Down Expand Up @@ -69,3 +74,16 @@ def test_should_not_fail_to_replace_background(self, temp_dir: Path):
'--background=%s' % EXAMPLE_IMAGE_URL,
'--output=%s' % output_image_path
])

def test_should_list_all_default_model_urls(self, capsys):
expected_urls = [
value
for key, value in BodyPixModelPaths.__dict__.items()
if not key.startswith('_')
]
main(['list-models'])
captured = capsys.readouterr()
output_urls = captured.out.splitlines()
LOGGER.debug('output_urls: %s', output_urls)
missing_urls = set(expected_urls) - set(output_urls)
assert not missing_urls
2 changes: 1 addition & 1 deletion tf_bodypix/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def add_arguments(self, parser: argparse.ArgumentParser):
add_common_arguments(parser)
parser.add_argument(
"--storage-url",
default="https://storage.googleapis.com/tfjs-models/",
default="https://storage.googleapis.com/tfjs-models",
help="The base URL for the storage containing the models"
)

Expand Down
2 changes: 1 addition & 1 deletion tf_bodypix/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def download_model(model_path: str) -> str:
for weights_manifest_path in weights_manifest_paths:
local_model_json_path = tf.keras.utils.get_file(
os.path.basename(weights_manifest_path),
os.path.join(model_base_path, weights_manifest_path),
model_base_path + '/' + weights_manifest_path,
cache_subdir=cache_subdir,
)
return local_model_path
5 changes: 3 additions & 2 deletions tf_bodypix/utils/s3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging

import os
import urllib.request
from xml.etree import ElementTree
from typing import Iterable
Expand All @@ -17,6 +16,8 @@


def iter_s3_file_urls(base_url: str) -> Iterable[str]:
if not base_url.endswith('/'):
base_url += '/'
marker = None
while True:
current_url = base_url
Expand All @@ -29,7 +30,7 @@ def iter_s3_file_urls(base_url: str) -> Iterable[str]:
for item in root.findall(S3_CONTENTS):
key = item.findtext(S3_KEY)
LOGGER.debug('key: %s', key)
yield os.path.join(base_url, key)
yield base_url + key
next_marker = root.findtext(S3_NEXT_MARKER)
if not next_marker or next_marker == marker:
break
Expand Down