Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 88 additions & 13 deletions cloudquery/sdk/schema/table.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

from typing import List, Generator, Any
import copy
import fnmatch
from typing import List

import pyarrow as pa

from cloudquery.sdk.schema import arrow
Expand Down Expand Up @@ -87,17 +89,90 @@ def tables_to_arrow_schemas(tables: List[Table]):


def filter_dfs(
    tables: List[Table],
    include_tables: List[str],
    skip_tables: List[str],
    skip_dependent_tables: bool = False,
) -> List[Table]:
    """Filter table trees by fnmatch-style name patterns.

    Every pattern in ``include_tables`` and ``skip_tables`` must match at
    least one table anywhere in the (flattened) tree, otherwise a
    ``ValueError`` is raised so that pattern typos fail loudly.  Returns
    deep-copied top-level tables with non-matching relations pruned.
    """
    all_tables = flatten_tables(tables)

    # Shared validation: reject any pattern that matches nothing.
    def _ensure_matches(patterns: List[str], label: str) -> None:
        for pattern in patterns:
            if not any(fnmatch.fnmatch(tbl.name, pattern) for tbl in all_tables):
                raise ValueError(f"{label} include a pattern {pattern} with no matches")

    _ensure_matches(include_tables, "tables")
    _ensure_matches(skip_tables, "skip_tables")

    def _matches_any(table, patterns):
        return any(fnmatch.fnmatch(table.name, pattern) for pattern in patterns)

    return filter_dfs_func(
        tables,
        lambda t: _matches_any(t, include_tables),
        lambda t: _matches_any(t, skip_tables),
        skip_dependent_tables,
    )


def filter_dfs_func(tt: List[Table], include, exclude, skip_dependent_tables: bool):
    """Filter table trees using ``include``/``exclude`` predicates.

    Operates on deep copies, so the caller's trees are never mutated.
    Top-level tables whose whole subtree is filtered out are dropped.
    """
    kept: List[Table] = []
    for table in tt:
        candidate = _filter_dfs_impl(
            copy.deepcopy(table), False, include, exclude, skip_dependent_tables
        )
        if candidate is not None:
            kept.append(candidate)
    return kept


def _filter_dfs_impl(t, parent_matched, include, exclude, skip_dependent_tables):
def filter_dfs_child(r, matched, include, exclude, skip_dependent_tables):
filtered_child = _filter_dfs_impl(
r, matched, include, exclude, skip_dependent_tables
)
if filtered_child is not None:
return True, r
return matched, None

if exclude(t):
return None

matched = parent_matched and not skip_dependent_tables
if include(t):
matched = True

filtered_relations = []
for r in t.relations:
matched, filtered_child = filter_dfs_child(
r, matched, include, exclude, skip_dependent_tables
)
if filtered_child is not None:
filtered_relations.append(filtered_child)

t.relations = filtered_relations

if matched:
return t
return None


def flatten_tables(tables: List[Table]) -> List[Table]:
    """Return *tables* and all nested relations in depth-first pre-order."""
    flattened: List[Table] = []
    stack = list(reversed(tables))
    while stack:
        current = stack.pop()
        flattened.append(current)
        # Push children reversed so they pop back off in their original order.
        stack.extend(reversed(current.relations))
    return flattened
84 changes: 83 additions & 1 deletion tests/schema/test_table.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,90 @@
import pyarrow as pa
import pytest

from cloudquery.sdk.schema import Table, Column
from cloudquery.sdk.schema import Table, Column, filter_dfs
from cloudquery.sdk.schema.table import flatten_tables


def test_table():
    # Smoke test: a minimal one-column table converts to an Arrow schema
    # without raising.
    columns = [Column("test_column", pa.int32())]
    Table("test_table", columns).to_arrow_schema()


def test_filter_dfs_warns_no_matches():
    """filter_dfs raises ValueError for include/skip patterns matching nothing."""
    # Build fixtures OUTSIDE pytest.raises so only filter_dfs itself can
    # satisfy the expected exception — a failure in Table() would otherwise
    # make the test pass spuriously.
    tables = [Table("test1", []), Table("test2", [])]

    with pytest.raises(ValueError):
        filter_dfs(tables, include_tables=["test3"], skip_tables=[])

    with pytest.raises(ValueError):
        filter_dfs(tables, include_tables=["*"], skip_tables=["test3"])


def test_filter_dfs():
    """End-to-end cases for filter_dfs across include/skip/dependent options."""
    grandchild = Table("test_grandchild", [Column("test_column", pa.int32())])
    child = Table(
        "test_child",
        [Column("test_column", pa.int32())],
        relations=[grandchild],
    )
    top1 = Table(
        "test_top1",
        [Column("test_column", pa.int32())],
        relations=[child],
    )
    top2 = Table("test_top2", [Column("test_column", pa.int32())])
    tables = [top1, top2]

    # (include, skip, skip_dependent, expected top-level, expected flattened)
    cases = [
        (
            ["*"],
            [],
            False,
            ["test_top1", "test_top2"],
            ["test_top1", "test_top2", "test_child", "test_grandchild"],
        ),
        (["*"], ["test_top1"], False, ["test_top2"], ["test_top2"]),
        (["test_top1"], ["test_top2"], True, ["test_top1"], ["test_top1"]),
        (["test_child"], [], True, ["test_top1"], ["test_top1", "test_child"]),
    ]
    for include, skip, skip_dependent, expect_top, expect_flattened in cases:
        case = (include, skip, skip_dependent)
        got = filter_dfs(
            tables=tables,
            include_tables=include,
            skip_tables=skip,
            skip_dependent_tables=skip_dependent,
        )
        assert sorted(t.name for t in got) == sorted(expect_top), case

        flat_names = sorted(t.name for t in flatten_tables(got))
        assert flat_names == sorted(expect_flattened), case