diff --git a/Makefile b/Makefile index c0a847830..c6ee47611 100644 --- a/Makefile +++ b/Makefile @@ -112,7 +112,8 @@ REGRESS = scan \ name_validation \ jsonb_operators \ list_comprehension \ - map_projection + map_projection \ + concurrent ifneq ($(EXTRA_TESTS),) REGRESS += $(EXTRA_TESTS) diff --git a/regress/expected/concurrent.out b/regress/expected/concurrent.out new file mode 100644 index 000000000..39948dc85 --- /dev/null +++ b/regress/expected/concurrent.out @@ -0,0 +1,12 @@ +\! python3 regress/python/test_concurrent.py +Result: OK + +Result: OK + +Result: OK + +Result: OK + +Result: OK + +All threads have finished execution. diff --git a/regress/python/test_concurrent.py b/regress/python/test_concurrent.py new file mode 100644 index 000000000..8384e757c --- /dev/null +++ b/regress/python/test_concurrent.py @@ -0,0 +1,170 @@ +from contextlib import contextmanager +import psycopg2 +from psycopg2 import sql +import threading + +from concurrent.futures import ThreadPoolExecutor +from threading import Semaphore + +sqls = [ + """ SELECT * FROM cypher('test_graph', $$ + MERGE (n:DDDDD {doc_id: 'f5ce4dc2'}) + SET n.embedset_id='ae1b9b73', n.doc_id='f5ce4dc2', n.doc_hash='977b56ef' + $$) as (result agtype) """, + """ SELECT * FROM cypher('test_graph', $$ + MERGE (n:EEEEE {doc_id: 'f5ce4dc2'}) + SET n.embedset_id='ae1b9b73', n.doc_id='f5ce4dc2', n.doc_hash='1d7e79a0' + $$) as (result agtype) """, + """ SELECT * FROM cypher('test_graph', $$ + MATCH (source:EEEEE {doc_id:'f5ce4dc2'}) + MATCH (target:DDDDD {doc_id:'f5ce4dc2'}) + WITH source, target + MERGE (source)-[r:DIRECTED]->(target) + SET r.embedset_id='ae1b9b73', r.doc_id='f5ce4dc2' + RETURN r + $$) as (result agtype) """, + """ SELECT * FROM cypher('test_graph', $$ + MATCH (source:EEEEE {doc_id:'f5ce4dc2'}) + MATCH (target:DDDDD {doc_id:'f5ce4dc2'}) + WITH source, target + MERGE (source)-[r:DIRECTED]->(target) + SET r.embedset_id='ae1b9b73', r.doc_id='f5ce4dc2' + RETURN r + $$) as (result agtype) """, + """ SELECT * FROM cypher('test_graph', $$ + MATCH (source:EEEEE {doc_id:'f5ce4dc2'}) + MATCH (target:DDDDD {doc_id:'f5ce4dc2'}) + WITH source, target + MERGE (source)-[r:DIRECTED]->(target) + SET r.embedset_id='ae1b9b73', r.doc_id='f5ce4dc2' + RETURN r + $$) as (result agtype) """, +] + + +class PieGraphConnector: + host: str + port: str + user: str + password: str + database: str + warehouse: str + + def __init__(self, global_config: dict): + self.host = global_config.get("host", "") + self.port = global_config.get("port", "") + self.user = global_config.get("user", "") + self.password = global_config.get("password", "") + self.database = global_config.get("database", "") + self.warehouse = global_config.get("warehouse", "") + + @contextmanager + def conn(self): + conn = None + if self.warehouse and self.warehouse != "": + options = "'-c warehouse=" + self.warehouse + "'" + conn = psycopg2.connect( + dbname=self.database, + user=self.user, + password=self.password, + host=self.host, + port=self.port, + options=options, + ) + else: + conn = psycopg2.connect( + dbname=self.database, + user=self.user, + password=self.password, + host=self.host, + port=self.port, + ) + conn.autocommit = True + with conn.cursor() as cursor: + cursor.execute("CREATE EXTENSION IF NOT EXISTS age;") + cursor.execute("LOAD 'age';") + cursor.execute("SET search_path = ag_catalog, '$user', public;") + try: + yield conn + finally: + conn.close() + + +class BoundedThreadPoolExecutor(ThreadPoolExecutor): + def __init__(self, max_workers=5, max_task_size=32, *args, **kwargs): + if max_task_size < max_workers: + raise ValueError( + "max_task_size should be greater than or equal to max_workers" + ) + if max_workers is not None: + kwargs["max_workers"] = max_workers + super().__init__(*args, **kwargs) + self._semaphore = Semaphore(max_task_size) + + def submit(self, fn, /, *args, **kwargs): + timeout = kwargs.get("timeout", None) + if self._semaphore.acquire(timeout=timeout): + future = super().submit(fn, *args, **kwargs) + future.add_done_callback(lambda _: self._semaphore.release()) + return future + else: + raise TimeoutError("waiting for semaphore timeout") + + +db_config = { + "host": "127.0.0.1", + "database": "postgres", + "user": "postgres", + "password": "", + "port": "5432", +} + +connector = PieGraphConnector(db_config) + +def execute_sql(query): + try: + connection = psycopg2.connect(**db_config) + cursor = connection.cursor() + cursor.execute("CREATE EXTENSION IF NOT EXISTS age;") + cursor.execute("LOAD 'age';") + cursor.execute("SET search_path = ag_catalog, '$user', public;") + + cursor.execute(query) + + connection.commit() + + result = cursor.fetchall() + + except Exception as e: + print(f"Error executing query '{query}': {e}") + finally: + if connection: + cursor.close() + connection.close() + +semaphore_graph = threading.Semaphore(20) + +drop_graph = "SELECT * FROM drop_graph('test_graph', true)" +create_graph = "SELECT * FROM create_graph('test_graph')" + +# execute_sql(drop_graph) +execute_sql(create_graph) + +def _merge_exec_sql(query: str): + with semaphore_graph: + with connector.conn() as conn: + with conn.cursor() as cursor: + try: + cursor.execute(query) + result = cursor.fetchall() + print(f"Result: OK\n") + except Exception as e: + print(f"Error executing query '{query}': {e}") + conn.commit() + +with BoundedThreadPoolExecutor() as executor: + executor.map(lambda q: _merge_exec_sql(q), sqls) + +print("All threads have finished execution.") + +execute_sql(drop_graph) diff --git a/regress/sql/concurrent.sql b/regress/sql/concurrent.sql new file mode 100644 index 000000000..62c5ee6c0 --- /dev/null +++ b/regress/sql/concurrent.sql @@ -0,0 +1 @@ +\! python3 regress/python/test_concurrent.py diff --git a/src/backend/parser/cypher_clause.c b/src/backend/parser/cypher_clause.c index e301daa0f..49216d058 100644 --- a/src/backend/parser/cypher_clause.c +++ b/src/backend/parser/cypher_clause.c @@ -42,12 +42,14 @@ #include "catalog/ag_graph.h" #include "catalog/ag_label.h" #include "commands/label_commands.h" +#include "common/hashfn.h" #include "parser/cypher_analyze.h" #include "parser/cypher_clause.h" #include "parser/cypher_expr.h" #include "parser/cypher_item.h" #include "parser/cypher_parse_agg.h" #include "parser/cypher_transform_entity.h" +#include "storage/lock.h" #include "utils/ag_cache.h" #include "utils/ag_func.h" #include "utils/ag_guc.h" @@ -5872,15 +5874,23 @@ transform_create_cypher_edge(cypher_parsestate *cpstate, List **target_list, /* create the label entry if it does not exist */ if (!label_exists(edge->label, cpstate->graph_oid)) { + LOCKTAG tag; + uint32 key; List *parent; - rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid, - AG_DEFAULT_LABEL_EDGE); + key = hash_bytes((const unsigned char *)edge->label, strlen(edge->label)); + SET_LOCKTAG_ADVISORY(tag, MyDatabaseId, key, cpstate->graph_oid, 3); + (void) LockAcquire(&tag, ExclusiveLock, false, false); + if (!label_exists(edge->label, cpstate->graph_oid)) + { + rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid, + AG_DEFAULT_LABEL_EDGE); - parent = list_make1(rv); + parent = list_make1(rv); - create_label(cpstate->graph_name, edge->label, LABEL_TYPE_EDGE, - parent); + create_label(cpstate->graph_name, edge->label, LABEL_TYPE_EDGE, + parent); + } } /* lock the relation of the label */ @@ -6149,15 +6159,23 @@ transform_create_cypher_new_node(cypher_parsestate *cpstate, /* create the label entry if it does not exist */ if (!label_exists(node->label, cpstate->graph_oid)) { + LOCKTAG tag; + uint32 key; List *parent; - rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid, - AG_DEFAULT_LABEL_VERTEX); + key = hash_bytes((const unsigned char *)node->label, strlen(node->label)); + SET_LOCKTAG_ADVISORY(tag, MyDatabaseId, key, cpstate->graph_oid, 3); + (void) LockAcquire(&tag, ExclusiveLock, false, false); + if (!label_exists(node->label, cpstate->graph_oid)) + { + rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid, + AG_DEFAULT_LABEL_VERTEX); - parent = list_make1(rv); + parent = list_make1(rv); - create_label(cpstate->graph_name, node->label, LABEL_TYPE_VERTEX, - parent); + create_label(cpstate->graph_name, node->label, LABEL_TYPE_VERTEX, + parent); + } } rel->flags = CYPHER_TARGET_NODE_FLAG_INSERT; @@ -7222,19 +7240,36 @@ transform_merge_cypher_edge(cypher_parsestate *cpstate, List **target_list, /* check to see if the label exists, create the label entry if it does not. */ if (edge->label && !label_exists(edge->label, cpstate->graph_oid)) { + LOCKTAG tag; + uint32 key; List *parent; + /* - * setup the default edge table as the parent table, that we - * will inherit from. + * When merging nodes or edges concurrently, there is label with the same + * name created by different transactions. Advisory lock is acquired before + * creating label, and then check if label already exists. Note, the lock is + * not released until current transaction is over. This can ensure that the + * new tuple inserted in ag_label catalog table will be sent out, so other + * transactions can receive it when checking label exists after acquiring lock. */ - rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid, - AG_DEFAULT_LABEL_EDGE); + key = hash_bytes((const unsigned char *)edge->label, strlen(edge->label)); + SET_LOCKTAG_ADVISORY(tag, MyDatabaseId, key, cpstate->graph_oid, 3); + (void) LockAcquire(&tag, ExclusiveLock, false, false); + if (!label_exists(edge->label, cpstate->graph_oid)) + { + /* + * setup the default edge table as the parent table, that we + * will inherit from. + */ + rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid, + AG_DEFAULT_LABEL_EDGE); - parent = list_make1(rv); + parent = list_make1(rv); - /* create the label */ - create_label(cpstate->graph_name, edge->label, LABEL_TYPE_EDGE, - parent); + /* create the label */ + create_label(cpstate->graph_name, edge->label, LABEL_TYPE_EDGE, + parent); + } } /* lock the relation of the label */ @@ -7357,20 +7392,28 @@ transform_merge_cypher_node(cypher_parsestate *cpstate, List **target_list, /* check to see if the label exists, create the label entry if it does not. */ if (node->label && !label_exists(node->label, cpstate->graph_oid)) { + LOCKTAG tag; + uint32 key; List *parent; - /* - * setup the default vertex table as the parent table, that we - * will inherit from. - */ - rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid, - AG_DEFAULT_LABEL_VERTEX); + key = hash_bytes((const unsigned char *)node->label, strlen(node->label)); + SET_LOCKTAG_ADVISORY(tag, MyDatabaseId, key, cpstate->graph_oid, 3); + (void) LockAcquire(&tag, ExclusiveLock, false, false); + if (!label_exists(node->label, cpstate->graph_oid)) + { + /* + * setup the default vertex table as the parent table, that we + * will inherit from. + */ + rv = get_label_range_var(cpstate->graph_name, cpstate->graph_oid, + AG_DEFAULT_LABEL_VERTEX); - parent = list_make1(rv); + parent = list_make1(rv); - /* create the label */ - create_label(cpstate->graph_name, node->label, LABEL_TYPE_VERTEX, - parent); + /* create the label */ + create_label(cpstate->graph_name, node->label, LABEL_TYPE_VERTEX, + parent); + } } rel->flags |= CYPHER_TARGET_NODE_FLAG_INSERT;