Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 220 additions & 26 deletions backend/pdm.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions backend/workflow_manager/endpoint/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ class TableColumns:

class DBConnectionClass:
    """Class names of the engine objects used to pick DB-specific handling.

    Each value is compared against ``engine.__class__.__name__`` so that
    callers can branch on the kind of database they are talking to.
    """

    # Class name of the Snowflake engine object.
    SNOWFLAKE = "SnowflakeConnection"
    # Class name of the BigQuery engine object
    # (presumably google.cloud.bigquery.Client -- TODO confirm).
    BIGQUERY = "Client"


class Snowflake:
Expand Down
153 changes: 87 additions & 66 deletions backend/workflow_manager/endpoint/database_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,58 +60,65 @@ def make_sql_values_for_query(
return sql_values

@staticmethod
def get_column_types(
engine: Any, table_name: str, cls: Any = None
) -> dict[str, str]:
"""Retrieve column types for a specified table from a database engine.

Args:
engine (Any): The database engine used to execute queries.
table_name (str): The name of the table for which column types
are retrieved.
cls (Any, optional): The database connection class (e.g.,
DBConnectionClass.SNOWFLAKE) for handling database-specific
queries.
Defaults to None.

Returns:
dict: A dictionary mapping column names to their respective data
types.

Raises:
Exception: If there is an error while retrieving column types,
an exception is raised. Exit.
def get_column_types_util(columns_with_types: Any) -> dict[str, str]:
    """Build a column-name to data-type mapping from query result rows.

    Args:
        columns_with_types (Any): An iterable of two-item rows, each holding
            ``(column_name, data_type)``.

    Returns:
        dict[str, str]: Mapping of each column name to its data type. If a
            column name occurs more than once, the last row wins.

    Raises:
        ValueError: If a row does not unpack into exactly two items.
    """
    # Explicit two-name unpacking (rather than ``dict(rows)``) keeps the
    # strict "exactly two items per row" check.
    return {
        column_name: data_type
        for column_name, data_type in columns_with_types
    }

Note:
- If `cls` is not provided or is None, the function assumes a
Default SQL database and queries column types accordingly.
- If `cls` is provided and matches DBConnectionClass.SNOWFLAKE,
the function queries column types using Snowflake-specific
syntax.
"""
@staticmethod
def get_column_types(
cls: Any,
table_name: str,
connector_id: str,
connector_settings: dict[str, Any],
) -> Any:
column_types: dict[str, str] = {}
try:
column_types: dict[str, str] = {}
with engine.cursor() as cursor:
if cls == DBConnectionClass.SNOWFLAKE:
results = cursor.execute(f"describe table {table_name}")
for column in results:
column_types[column[0].lower()] = column[1].split("(")[
0
]
else:
# Default to Other SQL DBs
# Postgresql treats the table names as in lowercase
# tested only with Postgresql
table_name = str.lower(table_name)
columns_with_types_query = (
"SELECT column_name, data_type FROM "
"information_schema.columns WHERE "
f"table_name = '{table_name}'"
if cls == DBConnectionClass.SNOWFLAKE:
query = f"describe table {table_name}"
results = DatabaseUtils.execute_and_fetch_data(
connector_id=connector_id,
connector_settings=connector_settings,
query=query,
)
for column in results:
column_types[column[0].lower()] = column[1].split("(")[0]
elif cls == DBConnectionClass.BIGQUERY:
table_name = str.lower(table_name)
table_list = table_name.split(".")
table_size = 3
if len(table_list) != table_size:
raise ValueError(
"Please enter project_name, dataset and table_name"
)
cursor.execute(columns_with_types_query)
columns_with_types = cursor.fetchall()
for column_name, data_type in columns_with_types:
column_types[column_name] = data_type
project_id = table_list[0]
dataset = table_list[1]
table_val = table_list[2]
query = (
"SELECT column_name, data_type FROM "
f"{project_id}.{dataset}.INFORMATION_SCHEMA.COLUMNS WHERE "
f"table_name = '{table_val}'"
)
results = DatabaseUtils.execute_and_fetch_data(
connector_id=connector_id,
connector_settings=connector_settings,
query=query,
)
column_types = DatabaseUtils.get_column_types_util(results)
else:
table_name = str.lower(table_name)
query = (
"SELECT column_name, data_type FROM "
"information_schema.columns WHERE "
f"table_name = '{table_name}'"
)
results = DatabaseUtils.execute_and_fetch_data(
connector_id=connector_id,
connector_settings=connector_settings,
query=query,
)
column_types = DatabaseUtils.get_column_types_util(results)
except Exception as e:
logger.error(
f"Error getting column types for {table_name}: {str(e)}"
Expand Down Expand Up @@ -180,13 +187,19 @@ def get_columns_and_values(

@staticmethod
def get_sql_values_for_query(
engine: Any, table_name: str, values: dict[str, Any]
engine: Any,
connector_id: str,
connector_settings: dict[str, Any],
table_name: str,
values: dict[str, Any],
) -> list[str]:
"""Generate SQL values for an insert query based on the provided values
and table schema.

Args:
engine (Any): The database engine.
connector_id: The connector id of the connector provided
connector_settings: Connector settings provided by user
table_name (str): The name of the target table for the insert query.
values (dict[str, Any]): A dictionary containing column-value pairs
for the insert query.
Expand All @@ -202,24 +215,20 @@ def get_sql_values_for_query(
- For other SQL databases, it uses default SQL generation
based on column types.
"""

if engine.__class__.__name__ == DBConnectionClass.SNOWFLAKE:
# Handle Snowflake
column_types: dict[str, str] = DatabaseUtils.get_column_types(
engine=engine,
table_name=table_name,
cls=DBConnectionClass.SNOWFLAKE,
)
class_name = engine.__class__.__name__
column_types: dict[str, str] = DatabaseUtils.get_column_types(
cls=class_name,
table_name=table_name,
connector_id=connector_id,
connector_settings=connector_settings,
)
if class_name == DBConnectionClass.SNOWFLAKE:
sql_values = DatabaseUtils.make_sql_values_for_query(
values=values,
column_types=column_types,
cls=DBConnectionClass.SNOWFLAKE,
)
else:
# Default to Other SQL DBs
column_types = DatabaseUtils.get_column_types(
engine=engine, table_name=table_name
)
sql_values = DatabaseUtils.make_sql_values_for_query(
values=values, column_types=column_types
)
Expand Down Expand Up @@ -248,11 +257,13 @@ def execute_write_query(
f"INSERT INTO {table_name} ({','.join(sql_keys)}) "
f"SELECT {','.join(sql_values)}"
)

try:
with engine.cursor() as cursor:
cursor.execute(sql)
engine.commit()
if hasattr(engine, "cursor"):
with engine.cursor() as cursor:
cursor.execute(sql)
engine.commit()
else:
engine.query(sql)
except Exception as e:
logger.error(f"Error while writing data: {str(e)}")
raise e
Expand All @@ -266,3 +277,13 @@ def get_db_engine(
]
connector_class: UnstractDB = connector(connector_settings)
return connector_class.get_engine()

@staticmethod
def execute_and_fetch_data(
    connector_id: str, connector_settings: dict[str, Any], query: str
) -> Any:
    """Run a query through the connector registered under ``connector_id``.

    Args:
        connector_id (str): Key used to look up the connector in
            ``db_connectors``.
        connector_settings (dict[str, Any]): Settings passed to the
            connector's constructor.
        query (str): The query handed to the connector's ``execute``.

    Returns:
        Any: Whatever the connector's ``execute`` call returns.
    """
    connector_metadata = db_connectors[connector_id][Common.METADATA]
    # A new connector instance is built from the settings on every call.
    db_connector: UnstractDB = connector_metadata[Common.CONNECTOR](
        connector_settings
    )
    return db_connector.execute(query=query)
25 changes: 8 additions & 17 deletions backend/workflow_manager/endpoint/destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

import fsspec
import magic
from account.models import EncryptionSecret
from connector.models import ConnectorInstance
from cryptography.fernet import Fernet
from django.db import connection
from fsspec.implementations.local import LocalFileSystem
from unstract.sdk.constants import ToolExecKey
Expand Down Expand Up @@ -77,17 +75,6 @@ def _get_endpoint_for_workflow(
workflow=workflow,
endpoint_type=WorkflowEndpoint.EndpointType.DESTINATION,
)
if endpoint.connector_instance:
encryption_secret: EncryptionSecret = EncryptionSecret.objects.get()
f: Fernet = Fernet(encryption_secret.key.encode("utf-8"))
endpoint.connector_instance.connector_metadata = json.loads(
f.decrypt(
bytes(endpoint.connector_instance.connector_metadata_b
).decode(
"utf-8"
)
)
)
return endpoint

def validate(self) -> None:
Expand Down Expand Up @@ -191,9 +178,9 @@ def copy_output_to_output_directory(self) -> None:
def insert_into_db(self, file_history: Optional[FileHistory]) -> None:
"""Insert data into the database."""
connector_instance: ConnectorInstance = self.endpoint.connector_instance
connector_settings: dict[
str, Any
] = connector_instance.connector_metadata
connector_settings: dict[str, Any] = (
connector_instance.connector_metadata
)
destination_configurations: dict[str, Any] = self.endpoint.configuration
table_name: str = str(
destination_configurations.get(DestinationKey.TABLE)
Expand Down Expand Up @@ -234,7 +221,11 @@ def insert_into_db(self, file_history: Optional[FileHistory]) -> None:
connector_settings=connector_settings,
)
sql_values = DatabaseUtils.get_sql_values_for_query(
engine=engine, table_name=table_name, values=values
engine=engine,
connector_id=connector_instance.connector_id,
connector_settings=connector_settings,
table_name=table_name,
values=values,
)

DatabaseUtils.execute_write_query(
Expand Down
20 changes: 4 additions & 16 deletions backend/workflow_manager/endpoint/source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import fnmatch
import json
import logging
import os
import shutil
Expand All @@ -8,10 +7,8 @@
from typing import Any, Optional

import fsspec
from account.models import EncryptionSecret
from connector.models import ConnectorInstance
from connector_processor.constants import ConnectorKeys
from cryptography.fernet import Fernet
from django.core.files.uploadedfile import UploadedFile
from django.db import connection
from unstract.workflow_execution.enums import LogState
Expand Down Expand Up @@ -92,18 +89,6 @@ def _get_endpoint_for_workflow(
workflow=workflow,
endpoint_type=WorkflowEndpoint.EndpointType.SOURCE,
)
if endpoint.connector_instance:
encryption_secret: EncryptionSecret = EncryptionSecret.objects.get()
f: Fernet = Fernet(encryption_secret.key.encode("utf-8"))
endpoint.connector_instance.connector_metadata = json.loads(
f.decrypt(
bytes(endpoint.connector_instance.connector_metadata_b
).decode(
"utf-8"
)
)
)

return endpoint

def validate(self) -> None:
Expand Down Expand Up @@ -173,7 +158,10 @@ def list_files_from_file_connector(self) -> list[str]:
input_directory = str(
source_configurations.get(SourceKey.ROOT_FOLDER, "")
)
input_directory = str(Path(root_dir_path, input_directory.lstrip("/")))
if root_dir_path: # user needs to manually type the optional file path
input_directory = str(
Path(root_dir_path, input_directory.lstrip("/"))
)
if not isinstance(required_patterns, list):
required_patterns = [required_patterns]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ function ConfigureConnectorModal({
formDataConfig,
setFormDataConfig,
isSpecConfigLoading,
connType,
}) {
const [activeKey, setActiveKey] = useState("1");

const tabItems = [
{
key: "1",
Expand All @@ -32,7 +32,7 @@ function ConfigureConnectorModal({
{
key: "2",
label: "File System",
disabled: !connectorId,
disabled: !connectorId || connType === "DATABASE",
},
];

Expand Down Expand Up @@ -112,6 +112,7 @@ ConfigureConnectorModal.propTypes = {
formDataConfig: PropTypes.object,
setFormDataConfig: PropTypes.func.isRequired,
isSpecConfigLoading: PropTypes.bool.isRequired,
connType: PropTypes.string.isRequired,
};

export { ConfigureConnectorModal };
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ function DsSettingsCard({ type, endpointDetails, message }) {
</Tooltip>
</Space>
<div className="display-flex-align-center">
{connDetails?.icon ? (
{connDetails?.connector_name ? (
<Space>
<Image
src={connDetails?.icon}
Expand Down Expand Up @@ -322,6 +322,7 @@ function DsSettingsCard({ type, endpointDetails, message }) {
formDataConfig={formDataConfig}
setFormDataConfig={setFormDataConfig}
isSpecConfigLoading={isSpecConfigLoading}
connType={connType}
/>
</>
);
Expand Down
Loading