Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions tests/test_mmar_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import tempfile
import unittest
from urllib.error import ContentTooShortError, HTTPError

import numpy as np
import torch
Expand Down Expand Up @@ -85,18 +86,30 @@ class TestMMMARDownload(unittest.TestCase):
@skip_if_quick
@SkipIfBeforePyTorchVersion((1, 6))
def test_download(self, idx):
download_mmar(idx)
download_mmar(idx, progress=False) # repeated to check caching
with tempfile.TemporaryDirectory() as tmp_dir:
download_mmar(idx, mmar_dir=tmp_dir, progress=False)
download_mmar(idx, mmar_dir=tmp_dir, progress=False) # repeated to check caching
self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx)))
try:
download_mmar(idx)
download_mmar(idx, progress=False) # repeated to check caching
with tempfile.TemporaryDirectory() as tmp_dir:
download_mmar(idx, mmar_dir=tmp_dir, progress=False)
download_mmar(idx, mmar_dir=tmp_dir, progress=False) # repeated to check caching
self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx)))
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))

@Nic-Ma Nic-Ma Jun 14, 2021

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a basic test for the error message? Just like:
https://github.com/Project-MONAI/MONAI/blob/dev/tests/test_download_and_extract.py#L36

Thanks.

if isinstance(e, HTTPError):
self.assertTrue("500" in str(e)) # http error has the code 500
return # skipping this test due the network connection errors

@parameterized.expand(TEST_EXTRACT_CASES)
@skip_if_quick
@SkipIfBeforePyTorchVersion((1, 6))
def test_load_ckpt(self, input_args, expected_name, expected_val):
output = load_from_mmar(**input_args)
try:
output = load_from_mmar(**input_args)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, HTTPError):
self.assertTrue("500" in str(e)) # http error has the code 500
return
self.assertEqual(output.__class__.__name__, expected_name)
x = next(output.parameters()) # verify the first element
np.testing.assert_allclose(x[0][0].detach().cpu().numpy(), expected_val, rtol=1e-3, atol=1e-3)
Expand Down