diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index 7149cc5a20..31c73b8b8f 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch @@ -85,18 +86,30 @@ class TestMMMARDownload(unittest.TestCase): @skip_if_quick @SkipIfBeforePyTorchVersion((1, 6)) def test_download(self, idx): - download_mmar(idx) - download_mmar(idx, progress=False) # repeated to check caching - with tempfile.TemporaryDirectory() as tmp_dir: - download_mmar(idx, mmar_dir=tmp_dir, progress=False) - download_mmar(idx, mmar_dir=tmp_dir, progress=False) # repeated to check caching - self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx))) + try: + download_mmar(idx) + download_mmar(idx, progress=False) # repeated to check caching + with tempfile.TemporaryDirectory() as tmp_dir: + download_mmar(idx, mmar_dir=tmp_dir, progress=False) + download_mmar(idx, mmar_dir=tmp_dir, progress=False) # repeated to check caching + self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx))) + except (ContentTooShortError, HTTPError, RuntimeError) as e: + print(str(e)) + if isinstance(e, HTTPError): + self.assertTrue("500" in str(e)) # http error has the code 500 + return # skipping this test due the network connection errors @parameterized.expand(TEST_EXTRACT_CASES) @skip_if_quick @SkipIfBeforePyTorchVersion((1, 6)) def test_load_ckpt(self, input_args, expected_name, expected_val): - output = load_from_mmar(**input_args) + try: + output = load_from_mmar(**input_args) + except (ContentTooShortError, HTTPError, RuntimeError) as e: + print(str(e)) + if isinstance(e, HTTPError): + self.assertTrue("500" in str(e)) # http error has the code 500 + return self.assertEqual(output.__class__.__name__, expected_name) x = next(output.parameters()) # verify the first element np.testing.assert_allclose(x[0][0].detach().cpu().numpy(), expected_val, rtol=1e-3, atol=1e-3)