Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 60 additions & 8 deletions system_tests/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ def _assert_coordinate(self, coordinate):
self.assertIsInstance(coordinate, (int, float))
self.assertNotEqual(coordinate, 0.0)

def _assert_likelihood(self, likelihood):
from google.cloud.vision.likelihood import Likelihood

levels = [Likelihood.UNKNOWN, Likelihood.VERY_LIKELY,
Likelihood.UNLIKELY, Likelihood.POSSIBLE, Likelihood.LIKELY,
Likelihood.VERY_UNLIKELY]
self.assertIn(likelihood, levels)


class TestVisionClientLogo(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -130,14 +138,6 @@ def tearDown(self):
for value in self.to_delete_by_case:
value.delete()

def _assert_likelihood(self, likelihood):
from google.cloud.vision.likelihood import Likelihood

levels = [Likelihood.UNKNOWN, Likelihood.VERY_LIKELY,
Likelihood.UNLIKELY, Likelihood.POSSIBLE, Likelihood.LIKELY,
Likelihood.VERY_UNLIKELY]
self.assertIn(likelihood, levels)

def _assert_landmarks(self, landmarks):
from google.cloud.vision.face import Landmark
from google.cloud.vision.face import LandmarkTypes
Expand Down Expand Up @@ -340,3 +340,55 @@ def test_detect_landmark_filename(self):
self.assertEqual(len(landmarks), 1)
landmark = landmarks[0]
self._assert_landmark(landmark)


class TestVisionClientSafeSearch(BaseVisionTestCase):
def setUp(self):
self.to_delete_by_case = []

def tearDown(self):
for value in self.to_delete_by_case:
value.delete()

def _assert_safe_search(self, safe_search):
from google.cloud.vision.safe import SafeSearchAnnotation

self.assertIsInstance(safe_search, SafeSearchAnnotation)
self._assert_likelihood(safe_search.adult)
self._assert_likelihood(safe_search.spoof)
self._assert_likelihood(safe_search.medical)
self._assert_likelihood(safe_search.violence)

def test_detect_safe_search_content(self):
client = Config.CLIENT
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
safe_searches = image.detect_safe_search()
self.assertEqual(len(safe_searches), 1)
safe_search = safe_searches[0]
self._assert_safe_search(safe_search)

def test_detect_safe_search_gcs(self):
bucket_name = Config.TEST_BUCKET.name
blob_name = 'faces.jpg'
blob = Config.TEST_BUCKET.blob(blob_name)
self.to_delete_by_case.append(blob) # Clean-up.
with open(FACE_FILE, 'rb') as file_obj:
blob.upload_from_file(file_obj)

source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
image = client.image(source_uri=source_uri)
safe_searches = image.detect_safe_search()
self.assertEqual(len(safe_searches), 1)
safe_search = safe_searches[0]
self._assert_safe_search(safe_search)

def test_detect_safe_search_filename(self):
client = Config.CLIENT
image = client.image(filename=FACE_FILE)
safe_searches = image.detect_safe_search()
self.assertEqual(len(safe_searches), 1)
safe_search = safe_searches[0]
self._assert_safe_search(safe_search)