diff --git a/tableaudocumentapi/workbook.py b/tableaudocumentapi/workbook.py
index 889f746..0da1827 100644
--- a/tableaudocumentapi/workbook.py
+++ b/tableaudocumentapi/workbook.py
@@ -3,10 +3,57 @@
# Workbook - A class for writing Tableau workbook files
#
###############################################################################
+import contextlib
import os
+import shutil
+import tempfile
+import zipfile
+
import xml.etree.ElementTree as ET
+
from tableaudocumentapi import Datasource
+###########################################################################
+#
+# Utility Functions
+#
+###########################################################################
+
+
+@contextlib.contextmanager
+def temporary_directory(*args, **kwargs):
+ d = tempfile.mkdtemp(*args, **kwargs)
+ try:
+ yield d
+ finally:
+ shutil.rmtree(d)
+
+
+def find_twb_in_zip(zip):
+ for filename in zip.namelist():
+ if os.path.splitext(filename)[-1].lower() == '.twb':
+ return filename
+
+
+def get_twb_xml_from_twbx(filename):
+ with temporary_directory() as temp:
+ with zipfile.ZipFile(filename) as zf:
+ zf.extractall(temp)
+ twb_file = find_twb_in_zip(zf)
+ twb_xml = ET.parse(os.path.join(temp, twb_file))
+
+ return twb_xml
+
+
+def build_twbx_file(twbx_contents, zip):
+ for root_dir, _, files in os.walk(twbx_contents):
+ relative_dir = os.path.relpath(root_dir, twbx_contents)
+ for f in files:
+ temp_file_full_path = os.path.join(
+ twbx_contents, relative_dir, f)
+ zipname = os.path.join(relative_dir, f)
+ zip.write(temp_file_full_path, arcname=zipname)
+
class Workbook(object):
"""
@@ -24,30 +71,18 @@ def __init__(self, filename):
Constructor.
"""
- # We have a valid type of input file
- if self._is_valid_file(filename):
- # set our filename, open .twb, initialize things
- self._filename = filename
- self._workbookTree = ET.parse(filename)
- self._workbookRoot = self._workbookTree.getroot()
-
- # prepare our datasource objects
- self._datasources = self._prepare_datasources(
- self._workbookRoot) # self.workbookRoot.find('datasources')
- else:
- print('Invalid file type. Must be .twb or .tds.')
- raise Exception()
-
- @classmethod
- def from_file(cls, filename):
- "Initialize datasource from file (.tds)"
- if self._is_valid_file(filename):
- self._filename = filename
- dsxml = ET.parse(filename).getroot()
- return cls(dsxml)
+ self._filename = filename
+
+ # Determine if this is a twb or twbx and get the xml root
+ if zipfile.is_zipfile(self._filename):
+ self._workbookTree = get_twb_xml_from_twbx(self._filename)
else:
- print('Invalid file type. Must be .twb or .tds.')
- raise Exception()
+ self._workbookTree = ET.parse(self._filename)
+
+ self._workbookRoot = self._workbookTree.getroot()
+ # prepare our datasource objects
+ self._datasources = self._prepare_datasources(
+ self._workbookRoot) # self.workbookRoot.find('datasources')
###########
# datasources
@@ -76,7 +111,12 @@ def save(self):
"""
# save the file
- self._workbookTree.write(self._filename, encoding="utf-8", xml_declaration=True)
+
+ if zipfile.is_zipfile(self._filename):
+ self._save_into_twbx(self._filename)
+ else:
+ self._workbookTree.write(
+ self._filename, encoding="utf-8", xml_declaration=True)
def save_as(self, new_filename):
"""
@@ -90,7 +130,11 @@ def save_as(self, new_filename):
"""
- self._workbookTree.write(new_filename, encoding="utf-8", xml_declaration=True)
+ if zipfile.is_zipfile(self._filename):
+ self._save_into_twbx(new_filename)
+ else:
+ self._workbookTree.write(
+ new_filename, encoding="utf-8", xml_declaration=True)
###########################################################################
#
@@ -107,6 +151,29 @@ def _prepare_datasources(self, xmlRoot):
return datasources
+ def _save_into_twbx(self, filename=None):
+ # Save reuses existing filename, 'save as' takes a new one
+ if filename is None:
+ filename = self._filename
+
+ # Saving a twbx means extracting the contents into a temp folder,
+ # saving the changes over the twb in that folder, and then
+ # packaging it back up into a specifically formatted zip with the correct
+ # relative file paths
+
+ # Extract to temp directory
+ with temporary_directory() as temp_path:
+ with zipfile.ZipFile(self._filename) as zf:
+ twb_file = find_twb_in_zip(zf)
+ zf.extractall(temp_path)
+ # Write the new version of the twb to the temp directory
+ self._workbookTree.write(os.path.join(
+ temp_path, twb_file), encoding="utf-8", xml_declaration=True)
+
+ # Write the new twbx with the contents of the temp folder
+ with zipfile.ZipFile(filename, "w", compression=zipfile.ZIP_DEFLATED) as new_twbx:
+ build_twbx_file(temp_path, new_twbx)
+
@staticmethod
def _is_valid_file(filename):
fileExtension = os.path.splitext(filename)[-1].lower()
diff --git a/test/assets/CONNECTION.xml b/test/assets/CONNECTION.xml
new file mode 100644
index 0000000..392d112
--- /dev/null
+++ b/test/assets/CONNECTION.xml
@@ -0,0 +1 @@
+
diff --git a/test/assets/TABLEAU_10_TDS.tds b/test/assets/TABLEAU_10_TDS.tds
new file mode 100644
index 0000000..7a81784
--- /dev/null
+++ b/test/assets/TABLEAU_10_TDS.tds
@@ -0,0 +1 @@
+
diff --git a/test/assets/TABLEAU_10_TWB.twb b/test/assets/TABLEAU_10_TWB.twb
new file mode 100644
index 0000000..c116bdf
--- /dev/null
+++ b/test/assets/TABLEAU_10_TWB.twb
@@ -0,0 +1 @@
+
diff --git a/test/assets/TABLEAU_10_TWBX.twbx b/test/assets/TABLEAU_10_TWBX.twbx
new file mode 100644
index 0000000..ef8f910
Binary files /dev/null and b/test/assets/TABLEAU_10_TWBX.twbx differ
diff --git a/test/assets/TABLEAU_93_TDS.tds b/test/assets/TABLEAU_93_TDS.tds
new file mode 100644
index 0000000..2afa3ea
--- /dev/null
+++ b/test/assets/TABLEAU_93_TDS.tds
@@ -0,0 +1 @@
+
diff --git a/test/assets/TABLEAU_93_TWB.twb b/test/assets/TABLEAU_93_TWB.twb
new file mode 100644
index 0000000..cdb6484
--- /dev/null
+++ b/test/assets/TABLEAU_93_TWB.twb
@@ -0,0 +1 @@
+
diff --git a/test/bvt.py b/test/bvt.py
index f521465..aa4a247 100644
--- a/test/bvt.py
+++ b/test/bvt.py
@@ -1,23 +1,23 @@
-import unittest
-import io
import os
+import unittest
+
import xml.etree.ElementTree as ET
from tableaudocumentapi import Workbook, Datasource, Connection, ConnectionParser
-# Disable the 120 line limit because of the embedded XML on these lines
-# TODO: Move the XML into external files and load them when needed
+TEST_DIR = os.path.dirname(__file__)
+
+TABLEAU_93_TWB = os.path.join(TEST_DIR, 'assets', 'TABLEAU_93_TWB.twb')
-TABLEAU_93_WORKBOOK = '''''' # noqa
+TABLEAU_93_TDS = os.path.join(TEST_DIR, 'assets', 'TABLEAU_93_TDS.tds')
-TABLEAU_93_TDS = '''''' # noqa
+TABLEAU_10_TDS = os.path.join(TEST_DIR, 'assets', 'TABLEAU_10_TDS.tds')
-TABLEAU_10_TDS = '''''' # noqa
+TABLEAU_10_TWB = os.path.join(TEST_DIR, 'assets', 'TABLEAU_10_TWB.twb')
-TABLEAU_10_WORKBOOK = '''''' # noqa
+TABLEAU_CONNECTION_XML = ET.parse(os.path.join(TEST_DIR, 'assets', 'CONNECTION.xml')).getroot()
-TABLEAU_CONNECTION_XML = ET.fromstring(
- '''''') # noqa
+TABLEAU_10_TWBX = os.path.join(TEST_DIR, 'assets', 'TABLEAU_10_TWBX.twbx')
class HelperMethodTests(unittest.TestCase):
@@ -36,14 +36,14 @@ def test_is_valid_file_with_invalid_inputs(self):
class ConnectionParserTests(unittest.TestCase):
def test_can_extract_legacy_connection(self):
- parser = ConnectionParser(ET.fromstring(TABLEAU_93_TDS), '9.2')
+ parser = ConnectionParser(ET.parse(TABLEAU_93_TDS), '9.2')
connections = parser.get_connections()
self.assertIsInstance(connections, list)
self.assertIsInstance(connections[0], Connection)
self.assertEqual(connections[0].dbname, 'TestV1')
def test_can_extract_federated_connections(self):
- parser = ConnectionParser(ET.fromstring(TABLEAU_10_TDS), '10.0')
+ parser = ConnectionParser(ET.parse(TABLEAU_10_TDS), '10.0')
connections = parser.get_connections()
self.assertIsInstance(connections, list)
self.assertIsInstance(connections[0], Connection)
@@ -76,9 +76,9 @@ def test_can_write_attributes_to_connection(self):
class DatasourceModelTests(unittest.TestCase):
def setUp(self):
- self.tds_file = io.FileIO('test.tds', 'w')
- self.tds_file.write(TABLEAU_93_TDS.encode('utf8'))
- self.tds_file.seek(0)
+ with open(TABLEAU_93_TDS, 'rb') as in_file, open('test.tds', 'wb') as out_file:
+ out_file.write(in_file.read())
+ self.tds_file = out_file
def tearDown(self):
self.tds_file.close()
@@ -117,9 +117,9 @@ def test_save_has_xml_declaration(self):
class DatasourceModelV10Tests(unittest.TestCase):
def setUp(self):
- self.tds_file = io.FileIO('test10.tds', 'w')
- self.tds_file.write(TABLEAU_10_TDS.encode('utf8'))
- self.tds_file.seek(0)
+ with open(TABLEAU_10_TDS, 'rb') as in_file, open('test.twb', 'wb') as out_file:
+ out_file.write(in_file.read())
+ self.tds_file = out_file
def tearDown(self):
self.tds_file.close()
@@ -147,9 +147,9 @@ def test_can_save_tds(self):
class WorkbookModelTests(unittest.TestCase):
def setUp(self):
- self.workbook_file = io.FileIO('test.twb', 'w')
- self.workbook_file.write(TABLEAU_93_WORKBOOK.encode('utf8'))
- self.workbook_file.seek(0)
+ with open(TABLEAU_93_TWB, 'rb') as in_file, open('test.twb', 'wb') as out_file:
+ out_file.write(in_file.read())
+ self.workbook_file = out_file
def tearDown(self):
self.workbook_file.close()
@@ -175,9 +175,9 @@ def test_can_update_datasource_connection_and_save(self):
class WorkbookModelV10Tests(unittest.TestCase):
def setUp(self):
- self.workbook_file = io.FileIO('testv10.twb', 'w')
- self.workbook_file.write(TABLEAU_10_WORKBOOK.encode('utf8'))
- self.workbook_file.seek(0)
+ with open(TABLEAU_10_TWB, 'rb') as in_file, open('test.twb', 'wb') as out_file:
+ out_file.write(in_file.read())
+ self.workbook_file = out_file
def tearDown(self):
self.workbook_file.close()
@@ -213,5 +213,43 @@ def test_save_has_xml_declaration(self):
self.assertEqual(
first_line, "")
+
+class WorkbookModelV10TWBXTests(unittest.TestCase):
+
+ def setUp(self):
+ with open(TABLEAU_10_TWBX, 'rb') as in_file, open('test.twbx', 'wb') as out_file:
+ out_file.write(in_file.read())
+ self.workbook_file = out_file
+
+ def tearDown(self):
+ self.workbook_file.close()
+ os.unlink(self.workbook_file.name)
+
+ def test_can_open_twbx(self):
+ wb = Workbook(self.workbook_file.name)
+ self.assertTrue(wb.datasources)
+ self.assertTrue(wb.datasources[0].connections)
+
+ def test_can_open_twbx_and_save_changes(self):
+ original_wb = Workbook(self.workbook_file.name)
+ original_wb.datasources[0].connections[0].server = 'newdb.test.tsi.lan'
+ original_wb.save()
+
+ new_wb = Workbook(self.workbook_file.name)
+ self.assertEqual(new_wb.datasources[0].connections[
+ 0].server, 'newdb.test.tsi.lan')
+
+ def test_can_open_twbx_and_save_as_changes(self):
+ new_twbx_filename = self.workbook_file.name + "_TEST_SAVE_AS"
+ original_wb = Workbook(self.workbook_file.name)
+ original_wb.datasources[0].connections[0].server = 'newdb.test.tsi.lan'
+ original_wb.save_as(new_twbx_filename)
+
+ new_wb = Workbook(new_twbx_filename)
+ self.assertEqual(new_wb.datasources[0].connections[
+ 0].server, 'newdb.test.tsi.lan')
+
+ os.unlink(new_twbx_filename)
+
if __name__ == '__main__':
unittest.main()