diff --git a/cloudquery/sdk/internal/memdb/memdb.py b/cloudquery/sdk/internal/memdb/memdb.py
index aa1c9e8..dd7bd5d 100644
--- a/cloudquery/sdk/internal/memdb/memdb.py
+++ b/cloudquery/sdk/internal/memdb/memdb.py
@@ -5,6 +5,7 @@
 from cloudquery.sdk import schema
 from typing import List, Generator, Dict
 import pyarrow as pa
+from cloudquery.sdk.schema.table import Table
 from cloudquery.sdk.types import JSONType
 from dataclasses import dataclass, field
 
@@ -109,5 +110,9 @@ def write(self, writer: Generator[message.WriteMessage, None, None]) -> None:
             else:
                 raise NotImplementedError(f"Unknown message type {type(msg)}")
 
+    def read(self, table: Table) -> Generator[message.ReadMessage, None, None]:
+        for record in self._db.values():
+            yield message.ReadMessage(record)
+
     def close(self) -> None:
         self._db = {}
diff --git a/cloudquery/sdk/internal/servers/plugin_v3/plugin.py b/cloudquery/sdk/internal/servers/plugin_v3/plugin.py
index ffbc581..dee7ff0 100644
--- a/cloudquery/sdk/internal/servers/plugin_v3/plugin.py
+++ b/cloudquery/sdk/internal/servers/plugin_v3/plugin.py
@@ -81,8 +81,14 @@ def Sync(self, request, context):
             # unknown sync message type
             raise NotImplementedError()
 
-    def Read(self, request, context):
-        raise NotImplementedError()
+    def Read(
+        self, request: plugin_pb2.Read.Request, context
+    ) -> Generator[plugin_pb2.Read.Response, None, None]:
+        schema = arrow.new_schema_from_bytes(request.table)
+        table = Table.from_arrow_schema(schema)
+        for msg in self._plugin.read(table):
+            buf = arrow.record_to_bytes(msg.record)
+            yield plugin_pb2.Read.Response(record=buf)
 
     def Write(
         self, request_iterator: Generator[plugin_pb2.Write.Request, None, None], context
@@ -93,7 +99,9 @@ def msg_iterator() -> Generator[WriteMessage, None, None]:
                 if field == "migrate_table":
                     sc = arrow.new_schema_from_bytes(msg.migrate_table.table)
                     table = Table.from_arrow_schema(sc)
-                    yield WriteMigrateTableMessage(table=table)
+                    yield WriteMigrateTableMessage(
+                        table=table, migrate_force=msg.migrate_table.migrate_force
+                    )
                 elif field == "insert":
                     yield WriteInsertMessage(
                         record=arrow.new_record_from_bytes(msg.insert.record)
diff --git a/cloudquery/sdk/message/__init__.py b/cloudquery/sdk/message/__init__.py
index 5ddb77a..bbcb84f 100644
--- a/cloudquery/sdk/message/__init__.py
+++ b/cloudquery/sdk/message/__init__.py
@@ -5,3 +5,4 @@
     WriteMigrateTableMessage,
     WriteDeleteStale,
 )
+from .read import ReadMessage
diff --git a/cloudquery/sdk/message/read.py b/cloudquery/sdk/message/read.py
new file mode 100644
index 0000000..227aa1e
--- /dev/null
+++ b/cloudquery/sdk/message/read.py
@@ -0,0 +1,6 @@
+import pyarrow as pa
+
+
+class ReadMessage:
+    def __init__(self, record: pa.RecordBatch):
+        self.record = record
diff --git a/cloudquery/sdk/message/write.py b/cloudquery/sdk/message/write.py
index 69ef9cc..f1beeb7 100644
--- a/cloudquery/sdk/message/write.py
+++ b/cloudquery/sdk/message/write.py
@@ -12,8 +12,9 @@ def __init__(self, record: pa.RecordBatch):
 
 
 class WriteMigrateTableMessage(WriteMessage):
-    def __init__(self, table: Table):
+    def __init__(self, table: Table, migrate_force: bool):
         self.table = table
+        self.migrate_force = migrate_force
 
 
 class WriteDeleteStale(WriteMessage):
diff --git a/cloudquery/sdk/plugin/plugin.py b/cloudquery/sdk/plugin/plugin.py
index 5b3d5f5..d03dbef 100644
--- a/cloudquery/sdk/plugin/plugin.py
+++ b/cloudquery/sdk/plugin/plugin.py
@@ -93,5 +93,8 @@ def sync(self, options: SyncOptions) -> Generator[message.SyncMessage, None, Non
     def write(self, writer: Generator[message.WriteMessage, None, None]) -> None:
         raise NotImplementedError()
 
+    def read(self, table: Table) -> Generator[message.ReadMessage, None, None]:
+        raise NotImplementedError()
+
     def close(self) -> None:
         raise NotImplementedError()
diff --git a/cloudquery/sdk/schema/__init__.py b/cloudquery/sdk/schema/__init__.py
index 649a1fd..63b4065 100644
--- a/cloudquery/sdk/schema/__init__.py
+++ b/cloudquery/sdk/schema/__init__.py
@@ -1,5 +1,15 @@
 from .column import Column
-from .table import Table, tables_to_arrow_schemas, filter_dfs
+from .table import (
+    Table,
+    tables_to_arrow_schemas,
+    filter_dfs,
+    TableColumnChangeType,
+    TableColumnChange,
+    get_table_changes,
+    get_table_column,
+    flatten_tables_recursive,
+    flatten_tables,
+)
 from .resource import Resource
 
 # from .table_resolver import TableReso
diff --git a/cloudquery/sdk/schema/column.py b/cloudquery/sdk/schema/column.py
index 2283d25..3d165a7 100644
--- a/cloudquery/sdk/schema/column.py
+++ b/cloudquery/sdk/schema/column.py
@@ -55,7 +55,9 @@ def to_arrow_field(self):
                 arrow.METADATA_TRUE if self.incremental_key else arrow.METADATA_FALSE
             ),
         }
-        return pa.field(self.name, self.type, metadata=metadata)
+        return pa.field(
+            self.name, self.type, metadata=metadata, nullable=not self.not_null
+        )
 
     @staticmethod
     def from_arrow_field(field: pa.Field) -> Column:
diff --git a/cloudquery/sdk/schema/table.py b/cloudquery/sdk/schema/table.py
index 71b3d29..a55ce76 100644
--- a/cloudquery/sdk/schema/table.py
+++ b/cloudquery/sdk/schema/table.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
 import copy
+from enum import IntEnum
 import fnmatch
-from typing import List
+from typing import List, Optional
 
 import pyarrow as pa
 
@@ -10,6 +11,10 @@
 from .column import Column
 
+CQ_SYNC_TIME_COLUMN = "cq_sync_time"
+CQ_SOURCE_NAME_COLUMN = "cq_source_name"
+
+
 class Client:
     pass
 
 
@@ -192,9 +197,131 @@ def filter_dfs_child(r, matched, include, exclude, skip_dependent_tables):
     return None
 
 
-def flatten_tables(tables: List[Table]) -> List[Table]:
-    flattened: List[Table] = []
+class TableColumnChangeType(IntEnum):
+    UNKNOWN = 0
+    ADD = 1
+    UPDATE = 2
+    REMOVE = 3
+    REMOVE_UNIQUE_CONSTRAINT = 4
+    MOVE_TO_CQ_ONLY = 5
+
+
+class TableColumnChange:
+    def __init__(
+        self,
+        type: TableColumnChangeType,
+        column_name: str = "",
+        current: Optional[Column] = None,
+        previous: Optional[Column] = None,
+    ):
+        self.type = type
+        self.column_name = column_name
+        self.current = current
+        self.previous = previous
+
+
+def get_table_changes(new: Table, old: Table) -> List[TableColumnChange]:
+    changes = []
+
+    # Special case: moving from individual PKs to a singular PK on _cq_id
+    new_pks = new.primary_keys
+    if (
+        len(new_pks) == 1
+        and new_pks[0] == "CqIDColumn"
+        and get_table_column(old, "CqIDColumn") is None
+        and len(old.primary_keys) > 0
+    ):
+        changes.append(
+            TableColumnChange(
+                type=TableColumnChangeType.MOVE_TO_CQ_ONLY,
+            )
+        )
+
+    for c in new.columns:
+        other_column = get_table_column(old, c.name)
+        # A column was added to the table definition
+        if other_column is None:
+            changes.append(
+                TableColumnChange(
+                    type=TableColumnChangeType.ADD,
+                    column_name=c.name,
+                    current=c,
+                    previous=None,
+                )
+            )
+            continue
+
+        # Column type or options (e.g. PK, Not Null) changed in the new table definition
+        if (
+            c.type != other_column.type
+            or c.not_null != other_column.not_null
+            or c.primary_key != other_column.primary_key
+        ):
+            changes.append(
+                TableColumnChange(
+                    type=TableColumnChangeType.UPDATE,
+                    column_name=c.name,
+                    current=c,
+                    previous=other_column,
+                )
+            )
+
+        # Unique constraint was removed
+        if not c.unique and other_column.unique:
+            changes.append(
+                TableColumnChange(
+                    type=TableColumnChangeType.REMOVE_UNIQUE_CONSTRAINT,
+                    column_name=c.name,
+                    current=c,
+                    previous=other_column,
+                )
+            )
+
+    # A column was removed from the table definition
+    for c in old.columns:
+        if get_table_column(new, c.name) is None:
+            changes.append(
+                TableColumnChange(
+                    type=TableColumnChangeType.REMOVE,
+                    column_name=c.name,
+                    current=None,
+                    previous=c,
+                )
+            )
+
+    return changes
+
+
+def get_table_column(table: Table, column_name: str) -> Optional[Column]:
+    for c in table.columns:
+        if c.name == column_name:
+            return c
+    return None
+
+
+def flatten_tables_recursive(original_tables: List[Table]) -> List[Table]:
+    tables = []
+    for table in original_tables:
+        table_copy = Table(
+            name=table.name,
+            columns=table.columns,
+            relations=table.relations,
+            title=table.title,
+            description=table.description,
+            is_incremental=table.is_incremental,
+            parent=table.parent,
+        )
+        tables.append(table_copy)
+        tables.extend(flatten_tables_recursive(table.relations))
+    return tables
+
+
+def flatten_tables(original_tables: List[Table]) -> List[Table]:
+    tables = flatten_tables_recursive(original_tables)
+    seen = set()
+    deduped = []
     for table in tables:
-        flattened.append(table)
-        flattened.extend(flatten_tables(table.relations))
-    return flattened
+        if table.name not in seen:
+            deduped.append(table)
+            seen.add(table.name)
+    return deduped
diff --git a/cloudquery/sdk/types/json.py b/cloudquery/sdk/types/json.py
index ed9ed1f..dfcd852 100644
--- a/cloudquery/sdk/types/json.py
+++ b/cloudquery/sdk/types/json.py
@@ -21,3 +21,6 @@ def __arrow_ext_deserialize__(self, storage_type, serialized):
         # return an instance of this subclass given the serialized
         # metadata.
         return JSONType()
+
+
+pa.register_extension_type(JSONType())
diff --git a/cloudquery/sdk/types/uuid.py b/cloudquery/sdk/types/uuid.py
index d3564a9..549774c 100644
--- a/cloudquery/sdk/types/uuid.py
+++ b/cloudquery/sdk/types/uuid.py
@@ -23,3 +23,6 @@ def __arrow_ext_deserialize__(self, storage_type, serialized):
         # return an instance of this subclass given the serialized
         # metadata.
         return UUIDType()
+
+
+pa.register_extension_type(UUIDType())
diff --git a/tests/serve/plugin.py b/tests/serve/plugin.py
index f7336bc..f07d819 100644
--- a/tests/serve/plugin.py
+++ b/tests/serve/plugin.py
@@ -1,6 +1,7 @@
 import json
 import os
 import random
+from uuid import UUID
 import grpc
 import time
 import pyarrow as pa
@@ -9,12 +10,16 @@
 from cloudquery.sdk import serve
 from cloudquery.plugin_v3 import plugin_pb2_grpc, plugin_pb2, arrow
 from cloudquery.sdk.internal.memdb import MemDB
+from cloudquery.sdk.types.json import JSONType
+from cloudquery.sdk.types.uuid import UUIDType
 
 test_table = Table(
     "test",
     [
         Column("id", pa.int64()),
         Column("name", pa.string()),
+        Column("json", JSONType()),
+        Column("uuid", UUIDType()),
     ],
 )
 
@@ -47,6 +52,15 @@ def writer_iterator():
         [
             pa.array([1, 2, 3]),
             pa.array(["a", "b", "c"]),
+            pa.array([None, b"{}", b'{"a":null}']),
+            pa.array(
+                [
+                    None,
+                    UUID("550e8400-e29b-41d4-a716-446655440000").bytes,
+                    UUID("123e4567-e89b-12d3-a456-426614174000").bytes,
+                ],
+                type=pa.binary(16),
+            ),
         ],
         schema=test_table.to_arrow_schema(),
     )
@@ -74,6 +88,77 @@ def writer_iterator():
     pool.shutdown()
 
 
+def test_plugin_read():
+    p = MemDB()
+    sample_record_1 = pa.RecordBatch.from_arrays(
+        [
+            pa.array([1, 2, 3]),
+            pa.array(["a", "b", "c"]),
+            pa.array([None, b"{}", b'{"a":null}']),
+            pa.array(
+                [
+                    None,
+                    UUID("550e8400-e29b-41d4-a716-446655440000").bytes,
+                    UUID("123e4567-e89b-12d3-a456-426614174000").bytes,
+                ],
+                type=pa.binary(16),
+            ),
+        ],
+        schema=test_table.to_arrow_schema(),
+    )
+    sample_record_2 = pa.RecordBatch.from_arrays(
+        [
+            pa.array([2, 3, 4]),
+            pa.array(["b", "c", "d"]),
+            pa.array([b'""', b'{"a":true}', b'{"b":1}']),
+            pa.array(
+                [
+                    UUID("9bba4c2a-1a37-4fbe-b489-6b40303a8a25").bytes,
+                    None,
+                    UUID("3fa85f64-5717-4562-b3fc-2c963f66afa6").bytes,
+                ],
+                type=pa.binary(16),
+            ),
+        ],
+        schema=test_table.to_arrow_schema(),
+    )
+    p._db["test_1"] = sample_record_1
+    p._db["test_2"] = sample_record_2
+
+    cmd = serve.PluginCommand(p)
+    port = random.randint(5000, 50000)
+    pool = futures.ThreadPoolExecutor(max_workers=1)
+    pool.submit(cmd.run, ["serve", "--address", f"[::]:{port}"])
+    time.sleep(1)
+    try:
+        with grpc.insecure_channel(f"localhost:{port}") as channel:
+            stub = plugin_pb2_grpc.PluginStub(channel)
+            response = stub.GetName(plugin_pb2.GetName.Request())
+            assert response.name == "memdb"
+
+            response = stub.GetVersion(plugin_pb2.GetVersion.Request())
+            assert response.version == "development"
+
+            response = stub.Init(plugin_pb2.Init.Request(spec=b""))
+            assert response is not None
+
+            request = plugin_pb2.Read.Request(
+                table=arrow.schema_to_bytes(test_table.to_arrow_schema())
+            )
+            reader_iterator = stub.Read(request)
+
+            output_records = []
+            for msg in reader_iterator:
+                output_records.append(arrow.new_record_from_bytes(msg.record))
+
+            assert len(output_records) == 2
+            assert output_records[0].equals(sample_record_1)
+            assert output_records[1].equals(sample_record_2)
+    finally:
+        cmd.stop()
+        pool.shutdown()
+
+
 def test_plugin_package():
     p = MemDB()
     cmd = serve.PluginCommand(p)
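
For context on how the new `migrate_force` flag and the `get_table_changes` helper fit together, here is a minimal, hypothetical sketch of a destination-side handler. The `handle_migrate_table` function, its `existing` argument, and the `print` placeholders are illustrative only and are not part of this change:

```python
from cloudquery.sdk import message
from cloudquery.sdk.schema import Table, TableColumnChangeType, get_table_changes


def handle_migrate_table(msg: message.WriteMigrateTableMessage, existing: Table) -> None:
    # Compare the table definition sent by the source against what the
    # destination currently stores, using the helper added in this change.
    for change in get_table_changes(msg.table, existing):
        if change.type == TableColumnChangeType.ADD:
            print(f"adding column {change.column_name}")
        elif change.type in (TableColumnChangeType.UPDATE, TableColumnChangeType.REMOVE):
            # Destructive changes are only applied when the write request
            # carries migrate_force (now surfaced on WriteMigrateTableMessage).
            if not msg.migrate_force:
                raise RuntimeError(
                    f"column {change.column_name} requires a forced migration"
                )
            print(f"recreating column {change.column_name}")
```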
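
And a small usage sketch for the reworked `flatten_tables`, which now copies each table, walks relations recursively via `flatten_tables_recursive`, and de-duplicates by table name; the `test_items` child table below is made up for illustration:

```python
import pyarrow as pa

from cloudquery.sdk.schema import Column, Table, flatten_tables

child = Table("test_items", [Column("id", pa.int64())])
parent = Table("test", [Column("id", pa.int64())], relations=[child])

# Passing the parent twice: relations are flattened in and duplicates dropped by name.
flat = flatten_tables([parent, parent])
assert [t.name for t in flat] == ["test", "test_items"]
```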