diff --git a/tableaudocumentapi/datasource.py b/tableaudocumentapi/datasource.py
index 617004a..b4fb8ed 100644
--- a/tableaudocumentapi/datasource.py
+++ b/tableaudocumentapi/datasource.py
@@ -3,8 +3,11 @@
 # Datasource - A class for writing datasources to Tableau files
 #
 ###############################################################################
+import os
+import zipfile
+
 import xml.etree.ElementTree as ET
-from tableaudocumentapi import Connection
+from tableaudocumentapi import Connection, xfile
 
 
 class ConnectionParser(object):
@@ -56,7 +59,11 @@ def __init__(self, dsxml, filename=None):
     @classmethod
     def from_file(cls, filename):
         "Initialize datasource from file (.tds)"
-        dsxml = ET.parse(filename).getroot()
+
+        if zipfile.is_zipfile(filename):
+            dsxml = xfile.get_xml_from_archive(filename).getroot()
+        else:
+            dsxml = ET.parse(filename).getroot()
         return cls(dsxml, filename)
 
     def save(self):
@@ -72,7 +79,8 @@ def save(self):
         """
 
         # save the file
-        self._datasourceTree.write(self._filename, encoding="utf-8", xml_declaration=True)
+
+        xfile._save_file(self._filename, self._datasourceTree)
 
     def save_as(self, new_filename):
         """
@@ -85,7 +93,7 @@ def save_as(self, new_filename):
             Nothing.
 
         """
-        self._datasourceTree.write(new_filename, encoding="utf-8", xml_declaration=True)
+        xfile._save_file(self._filename, self._datasourceTree, new_filename)
 
     ###########
     # name
diff --git a/tableaudocumentapi/workbook.py b/tableaudocumentapi/workbook.py
index 0da1827..9e29973 100644
--- a/tableaudocumentapi/workbook.py
+++ b/tableaudocumentapi/workbook.py
@@ -3,15 +3,12 @@
 # Workbook - A class for writing Tableau workbook files
 #
 ###############################################################################
-import contextlib
 import os
-import shutil
-import tempfile
 import zipfile
 
 import xml.etree.ElementTree as ET
 
-from tableaudocumentapi import Datasource
+from tableaudocumentapi import Datasource, xfile
 
 ###########################################################################
 #
@@ -20,41 +17,6 @@
 ###########################################################################
 
 
-@contextlib.contextmanager
-def temporary_directory(*args, **kwargs):
-    d = tempfile.mkdtemp(*args, **kwargs)
-    try:
-        yield d
-    finally:
-        shutil.rmtree(d)
-
-
-def find_twb_in_zip(zip):
-    for filename in zip.namelist():
-        if os.path.splitext(filename)[-1].lower() == '.twb':
-            return filename
-
-
-def get_twb_xml_from_twbx(filename):
-    with temporary_directory() as temp:
-        with zipfile.ZipFile(filename) as zf:
-            zf.extractall(temp)
-            twb_file = find_twb_in_zip(zf)
-            twb_xml = ET.parse(os.path.join(temp, twb_file))
-
-    return twb_xml
-
-
-def build_twbx_file(twbx_contents, zip):
-    for root_dir, _, files in os.walk(twbx_contents):
-        relative_dir = os.path.relpath(root_dir, twbx_contents)
-        for f in files:
-            temp_file_full_path = os.path.join(
-                twbx_contents, relative_dir, f)
-            zipname = os.path.join(relative_dir, f)
-            zip.write(temp_file_full_path, arcname=zipname)
-
-
 class Workbook(object):
     """
     A class for writing Tableau workbook files.
@@ -75,7 +37,8 @@ def __init__(self, filename):
 
         # Determine if this is a twb or twbx and get the xml root
         if zipfile.is_zipfile(self._filename):
-            self._workbookTree = get_twb_xml_from_twbx(self._filename)
+            self._workbookTree = xfile.get_xml_from_archive(
+                self._filename)
         else:
             self._workbookTree = ET.parse(self._filename)
 
@@ -111,12 +74,7 @@ def save(self):
         """
 
         # save the file
-
-        if zipfile.is_zipfile(self._filename):
-            self._save_into_twbx(self._filename)
-        else:
-            self._workbookTree.write(
-                self._filename, encoding="utf-8", xml_declaration=True)
+        xfile._save_file(self._filename, self._workbookTree)
 
     def save_as(self, new_filename):
         """
@@ -129,12 +87,8 @@ def save_as(self, new_filename):
             Nothing.
 
         """
-
-        if zipfile.is_zipfile(self._filename):
-            self._save_into_twbx(new_filename)
-        else:
-            self._workbookTree.write(
-                new_filename, encoding="utf-8", xml_declaration=True)
+        xfile._save_file(
+            self._filename, self._workbookTree, new_filename)
 
     ###########################################################################
     #
@@ -150,31 +104,3 @@ def _prepare_datasources(self, xmlRoot):
             datasources.append(ds)
 
         return datasources
-
-    def _save_into_twbx(self, filename=None):
-        # Save reuses existing filename, 'save as' takes a new one
-        if filename is None:
-            filename = self._filename
-
-        # Saving a twbx means extracting the contents into a temp folder,
-        # saving the changes over the twb in that folder, and then
-        # packaging it back up into a specifically formatted zip with the correct
-        # relative file paths
-
-        # Extract to temp directory
-        with temporary_directory() as temp_path:
-            with zipfile.ZipFile(self._filename) as zf:
-                twb_file = find_twb_in_zip(zf)
-                zf.extractall(temp_path)
-                # Write the new version of the twb to the temp directory
-                self._workbookTree.write(os.path.join(
-                    temp_path, twb_file), encoding="utf-8", xml_declaration=True)
-
-            # Write the new twbx with the contents of the temp folder
-            with zipfile.ZipFile(filename, "w", compression=zipfile.ZIP_DEFLATED) as new_twbx:
-                build_twbx_file(temp_path, new_twbx)
-
-    @staticmethod
-    def _is_valid_file(filename):
-        fileExtension = os.path.splitext(filename)[-1].lower()
-        return fileExtension in ('.twb', '.tds')
diff --git a/tableaudocumentapi/xfile.py b/tableaudocumentapi/xfile.py
new file mode 100644
index 0000000..13e08c7
--- /dev/null
+++ b/tableaudocumentapi/xfile.py
@@ -0,0 +1,152 @@
+"""Helpers for reading and writing Tableau document files.
+
+Handles both plain XML documents (.twb/.tds) and zip-packaged archives
+(.twbx/.tdsx).
+"""
+
+import contextlib
+import os
+import shutil
+import tempfile
+import zipfile
+
+import xml.etree.ElementTree as ET
+
+
+@contextlib.contextmanager
+def temporary_directory(*args, **kwargs):
+    """Yield a new temporary directory and delete it on exit.
+
+    All arguments are forwarded to tempfile.mkdtemp.
+    """
+    d = tempfile.mkdtemp(*args, **kwargs)
+    try:
+        yield d
+    finally:
+        shutil.rmtree(d)
+
+
+def find_file_in_zip(zip):
+    """Find the workbook/datasource XML document inside an open archive.
+
+    Args:
+        zip: An open zipfile.ZipFile.
+
+    Returns:
+        The name of the first archive member that parses as XML and whose
+        root element is 'workbook' or 'datasource', or None if there is
+        no such member.
+    """
+    for filename in zip.namelist():
+        try:
+            with zip.open(filename) as xml_candidate:
+                root_tag = ET.parse(xml_candidate).getroot().tag
+        except ET.ParseError:
+            # That's not an XML file by gosh
+            continue
+
+        # Archives may carry other XML files too; only a document rooted at
+        # one of these tags is the one we are looking for.
+        if root_tag in ('workbook', 'datasource'):
+            return filename
+
+    return None
+
+
+def get_xml_from_archive(filename):
+    """Parse the workbook/datasource XML stored in a packaged archive.
+
+    Args:
+        filename: Path to the archive (.twbx/.tdsx).
+
+    Returns:
+        The parsed xml.etree.ElementTree.ElementTree of the document.
+
+    Raises:
+        ValueError: The archive holds no workbook or datasource XML.
+    """
+    with zipfile.ZipFile(filename) as zf:
+        xml_name = find_file_in_zip(zf)
+        if xml_name is None:
+            raise ValueError(
+                'No workbook or datasource XML found in %r' % (filename,))
+        with zf.open(xml_name) as xml_file:
+            xml_tree = ET.parse(xml_file)
+
+    return xml_tree
+
+
+def build_archive_file(archive_contents, zip):
+    """Add every file under a directory to an open, writable archive.
+
+    Args:
+        archive_contents: Directory whose files are to be packaged.
+        zip: A zipfile.ZipFile opened for writing.
+
+    Member names are paths relative to archive_contents, so the directory
+    layout is preserved inside the archive.
+    """
+    for root_dir, _, files in os.walk(archive_contents):
+        relative_dir = os.path.relpath(root_dir, archive_contents)
+        for f in files:
+            temp_file_full_path = os.path.join(
+                archive_contents, relative_dir, f)
+            zipname = os.path.join(relative_dir, f)
+            zip.write(temp_file_full_path, arcname=zipname)
+
+
+def save_into_archive(xml_tree, filename, new_filename=None):
+    """Write xml_tree back into a packaged archive.
+
+    Args:
+        xml_tree: The ElementTree to store as the archive's document.
+        filename: Path of the existing archive to take the contents from.
+        new_filename: Path of the archive to write. Defaults to filename,
+            i.e. the existing archive is overwritten.
+
+    Raises:
+        ValueError: The archive holds no workbook or datasource XML.
+    """
+    # Saving an archive means extracting the contents into a temp folder,
+    # saving the changes over the twb/tds in that folder, and then
+    # packaging it back up into a specifically formatted zip with the correct
+    # relative file paths
+
+    if new_filename is None:
+        new_filename = filename
+
+    # Extract to temp directory
+    with temporary_directory() as temp_path:
+        with zipfile.ZipFile(filename) as zf:
+            xml_file = find_file_in_zip(zf)
+            if xml_file is None:
+                raise ValueError(
+                    'No workbook or datasource XML found in %r' % (filename,))
+            zf.extractall(temp_path)
+        # Write the new version of the file to the temp directory
+        xml_tree.write(os.path.join(
+            temp_path, xml_file), encoding="utf-8", xml_declaration=True)
+
+        # Write the new archive with the contents of the temp folder
+        with zipfile.ZipFile(new_filename, "w", compression=zipfile.ZIP_DEFLATED) as new_archive:
+            build_archive_file(temp_path, new_archive)
+
+
+def _save_file(container_file, xml_tree, new_filename=None):
+    """Save xml_tree as a plain XML file or a packaged archive.
+
+    The format of container_file decides how the tree is written: if it is
+    a zip archive the tree is packaged, otherwise it is written as XML.
+
+    Args:
+        container_file: Path of the file the tree was loaded from.
+        xml_tree: The ElementTree to save.
+        new_filename: Path to save to. Defaults to container_file.
+    """
+    if new_filename is None:
+        new_filename = container_file
+
+    if zipfile.is_zipfile(container_file):
+        save_into_archive(xml_tree, container_file, new_filename)
+    else:
+        xml_tree.write(new_filename, encoding="utf-8", xml_declaration=True)
diff --git a/test/assets/TABLEAU_10_TDSX.tdsx b/test/assets/TABLEAU_10_TDSX.tdsx
new file mode 100644
index 0000000..f94b678
Binary files /dev/null and b/test/assets/TABLEAU_10_TDSX.tdsx differ
diff --git a/test/bvt.py b/test/bvt.py
index aa4a247..1dedd57 100644
--- a/test/bvt.py
+++ b/test/bvt.py
@@ -15,22 +15,12 @@
 
 TABLEAU_10_TWB = os.path.join(TEST_DIR, 'assets', 'TABLEAU_10_TWB.twb')
 
-TABLEAU_CONNECTION_XML = ET.parse(os.path.join(TEST_DIR, 'assets', 'CONNECTION.xml')).getroot()
+TABLEAU_CONNECTION_XML = ET.parse(os.path.join(
+    TEST_DIR, 'assets', 'CONNECTION.xml')).getroot()
 
 TABLEAU_10_TWBX = os.path.join(TEST_DIR, 'assets', 'TABLEAU_10_TWBX.twbx')
 
-
-class HelperMethodTests(unittest.TestCase):
-
-    def test_is_valid_file_with_valid_inputs(self):
-        self.assertTrue(Workbook._is_valid_file('file1.tds'))
-        self.assertTrue(Workbook._is_valid_file('file2.twb'))
-        self.assertTrue(Workbook._is_valid_file('tds.twb'))
-
-    def test_is_valid_file_with_invalid_inputs(self):
-        self.assertFalse(Workbook._is_valid_file(''))
-        self.assertFalse(Workbook._is_valid_file('file1.tds2'))
-        self.assertFalse(Workbook._is_valid_file('file2.twb3'))
+TABLEAU_10_TDSX = os.path.join(TEST_DIR, 'assets', 'TABLEAU_10_TDSX.tdsx')
 
 
 class ConnectionParserTests(unittest.TestCase):
@@ -144,6 +134,43 @@ def test_can_save_tds(self):
         self.assertEqual(new_tds.connections[0].dbname, 'newdb.test.tsi.lan')
 
 
+class DatasourceModelV10TDSXTests(unittest.TestCase):
+
+    def setUp(self):
+        with open(TABLEAU_10_TDSX, 'rb') as in_file, open('test.tdsx', 'wb') as out_file:
+            out_file.write(in_file.read())
+            self.tdsx_file = out_file
+
+    def tearDown(self):
+        self.tdsx_file.close()
+        os.unlink(self.tdsx_file.name)
+
+    def test_can_open_tdsx(self):
+        ds = Datasource.from_file(self.tdsx_file.name)
+        self.assertTrue(ds.connections)
+        self.assertTrue(ds.name)
+
+    def test_can_open_tdsx_and_save_changes(self):
+        original_tdsx = Datasource.from_file(self.tdsx_file.name)
+        original_tdsx.connections[0].server = 'newdb.test.tsi.lan'
+        original_tdsx.save()
+
+        new_tdsx = Datasource.from_file(self.tdsx_file.name)
+        self.assertEqual(new_tdsx.connections[
+            0].server, 'newdb.test.tsi.lan')
+
+    def test_can_open_tdsx_and_save_as_changes(self):
+        new_tdsx_filename = 'newtdsx.tdsx'
+        original_wb = Datasource.from_file(self.tdsx_file.name)
+        original_wb.connections[0].server = 'newdb.test.tsi.lan'
+        original_wb.save_as(new_tdsx_filename)
+
+        new_wb = Datasource.from_file(new_tdsx_filename)
+        self.assertEqual(new_wb.connections[
+            0].server, 'newdb.test.tsi.lan')
+        os.unlink(new_tdsx_filename)
+
+
 class WorkbookModelTests(unittest.TestCase):
 
     def setUp(self):
@@ -240,7 +267,7 @@ def test_can_open_twbx_and_save_changes(self):
             0].server, 'newdb.test.tsi.lan')
 
     def test_can_open_twbx_and_save_as_changes(self):
-        new_twbx_filename = self.workbook_file.name + "_TEST_SAVE_AS"
+        new_twbx_filename = 'newtwbx.twbx'
         original_wb = Workbook(self.workbook_file.name)
         original_wb.datasources[0].connections[0].server = 'newdb.test.tsi.lan'
         original_wb.save_as(new_twbx_filename)