Skip to content

Commit bd79ba0

Browse files
committed
feat: add transactions (local undo, single-backend, distributed 2PC) and Session API
1 parent 0fae9d5 commit bd79ba0

File tree

10 files changed

+1648
-1
lines changed

10 files changed

+1648
-1
lines changed

Makefile.new

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@ TEST_SRCS = $(TEST_DIR)/test_main.cpp \
7373
$(TEST_DIR)/test_mysql_executor.cpp \
7474
$(TEST_DIR)/test_pgsql_executor.cpp \
7575
$(TEST_DIR)/test_distributed_real.cpp \
76-
$(TEST_DIR)/test_subquery.cpp
76+
$(TEST_DIR)/test_subquery.cpp \
77+
$(TEST_DIR)/test_local_txn.cpp \
78+
$(TEST_DIR)/test_session.cpp \
79+
$(TEST_DIR)/test_single_backend_txn.cpp \
80+
$(TEST_DIR)/test_distributed_txn.cpp
7781
TEST_OBJS = $(TEST_SRCS:.cpp=.o)
7882
TEST_TARGET = $(PROJECT_ROOT)/run_tests
7983

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
#ifndef SQL_ENGINE_DISTRIBUTED_TXN_H
2+
#define SQL_ENGINE_DISTRIBUTED_TXN_H
3+
4+
#include "sql_engine/transaction_manager.h"
5+
#include "sql_engine/remote_executor.h"
6+
#include "sql_engine/shard_map.h"
7+
#include "sql_parser/common.h"
8+
9+
#include <string>
10+
#include <vector>
11+
#include <unordered_map>
12+
#include <unordered_set>
13+
#include <chrono>
14+
#include <random>
15+
#include <cstdio>
16+
17+
namespace sql_engine {
18+
19+
// DistributedTransactionManager implements two-phase commit (2PC) for
20+
// transactions spanning multiple backends.
21+
//
22+
// MySQL XA protocol:
23+
// XA START 'txn_id' → begin on each participant
24+
// XA END 'txn_id' → mark end of work
25+
// XA PREPARE 'txn_id' → phase 1
26+
// XA COMMIT 'txn_id' → phase 2 (success)
27+
// XA ROLLBACK 'txn_id' → phase 2 (failure)
28+
//
29+
// PostgreSQL:
30+
// BEGIN → PREPARE TRANSACTION 'txn_id' → COMMIT PREPARED 'txn_id'
31+
class DistributedTransactionManager : public TransactionManager {
32+
public:
33+
// Backend dialect for 2PC protocol selection
34+
enum class BackendDialect : uint8_t { MYSQL, POSTGRESQL };
35+
36+
DistributedTransactionManager(RemoteExecutor& executor,
37+
BackendDialect dialect = BackendDialect::MYSQL)
38+
: executor_(executor), dialect_(dialect) {}
39+
40+
bool begin() override {
41+
txn_id_ = generate_txn_id();
42+
participants_.clear();
43+
prepared_.clear();
44+
started_.clear();
45+
active_ = true;
46+
return true;
47+
}
48+
49+
// Enlist a backend as a transaction participant. Called when DML is
50+
// executed against a backend. Sends XA START / BEGIN to the backend
51+
// if not already enlisted.
52+
bool enlist_backend(const char* backend_name) {
53+
if (!active_) return false;
54+
std::string name(backend_name);
55+
if (started_.count(name)) return true; // already enlisted
56+
57+
bool ok = false;
58+
if (dialect_ == BackendDialect::MYSQL) {
59+
std::string sql = "XA START '" + txn_id_ + "'";
60+
ok = send_sql(backend_name, sql);
61+
} else {
62+
ok = send_sql(backend_name, "BEGIN");
63+
}
64+
if (ok) {
65+
participants_.push_back(name);
66+
started_.insert(name);
67+
}
68+
return ok;
69+
}
70+
71+
bool commit() override {
72+
if (!active_) return false;
73+
if (participants_.empty()) {
74+
active_ = false;
75+
return true;
76+
}
77+
78+
// Phase 1: prepare all participants
79+
if (!phase1_prepare()) {
80+
phase2_rollback();
81+
active_ = false;
82+
return false;
83+
}
84+
85+
// Phase 2: commit all participants
86+
bool ok = phase2_commit();
87+
active_ = false;
88+
return ok;
89+
}
90+
91+
bool rollback() override {
92+
if (!active_) return false;
93+
phase2_rollback();
94+
active_ = false;
95+
return true;
96+
}
97+
98+
// Savepoints are not supported for distributed transactions.
99+
bool savepoint(const char*) override { return false; }
100+
bool rollback_to(const char*) override { return false; }
101+
bool release_savepoint(const char*) override { return false; }
102+
103+
bool in_transaction() const override { return active_; }
104+
bool is_auto_commit() const override { return auto_commit_; }
105+
void set_auto_commit(bool ac) override { auto_commit_ = ac; }
106+
107+
const std::string& txn_id() const { return txn_id_; }
108+
const std::vector<std::string>& participants() const { return participants_; }
109+
110+
private:
111+
RemoteExecutor& executor_;
112+
BackendDialect dialect_;
113+
114+
std::string txn_id_;
115+
std::vector<std::string> participants_;
116+
std::unordered_set<std::string> started_;
117+
std::unordered_map<std::string, bool> prepared_;
118+
bool active_ = false;
119+
bool auto_commit_ = true;
120+
121+
// Generate a unique transaction ID.
122+
static std::string generate_txn_id() {
123+
auto now = std::chrono::steady_clock::now();
124+
auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
125+
now.time_since_epoch()).count();
126+
// Use random suffix to avoid collisions
127+
static thread_local std::mt19937 rng(
128+
static_cast<unsigned>(std::chrono::system_clock::now()
129+
.time_since_epoch().count()));
130+
std::uniform_int_distribution<uint32_t> dist(0, 999999);
131+
char buf[64];
132+
std::snprintf(buf, sizeof(buf), "parsersql_%ld_%06u",
133+
static_cast<long>(ns), dist(rng));
134+
return buf;
135+
}
136+
137+
bool send_sql(const char* backend, const std::string& sql) {
138+
auto r = executor_.execute_dml(backend,
139+
sql_parser::StringRef{sql.c_str(),
140+
static_cast<uint32_t>(sql.size())});
141+
return r.success;
142+
}
143+
144+
// Phase 1: XA END + XA PREPARE on all participants (MySQL)
145+
// PREPARE TRANSACTION on all participants (PostgreSQL)
146+
bool phase1_prepare() {
147+
bool all_ok = true;
148+
for (auto& p : participants_) {
149+
bool ok = false;
150+
if (dialect_ == BackendDialect::MYSQL) {
151+
std::string end_sql = "XA END '" + txn_id_ + "'";
152+
ok = send_sql(p.c_str(), end_sql);
153+
if (ok) {
154+
std::string prep_sql = "XA PREPARE '" + txn_id_ + "'";
155+
ok = send_sql(p.c_str(), prep_sql);
156+
}
157+
} else {
158+
std::string prep_sql = "PREPARE TRANSACTION '" + txn_id_ + "'";
159+
ok = send_sql(p.c_str(), prep_sql);
160+
}
161+
prepared_[p] = ok;
162+
if (!ok) all_ok = false;
163+
}
164+
return all_ok;
165+
}
166+
167+
// Phase 2 (success): XA COMMIT on all participants
168+
bool phase2_commit() {
169+
bool all_ok = true;
170+
for (auto& p : participants_) {
171+
bool ok = false;
172+
if (dialect_ == BackendDialect::MYSQL) {
173+
std::string sql = "XA COMMIT '" + txn_id_ + "'";
174+
ok = send_sql(p.c_str(), sql);
175+
} else {
176+
std::string sql = "COMMIT PREPARED '" + txn_id_ + "'";
177+
ok = send_sql(p.c_str(), sql);
178+
}
179+
if (!ok) all_ok = false;
180+
}
181+
return all_ok;
182+
}
183+
184+
// Phase 2 (failure): XA ROLLBACK on all participants
185+
void phase2_rollback() {
186+
for (auto& p : participants_) {
187+
if (dialect_ == BackendDialect::MYSQL) {
188+
// If prepared, XA ROLLBACK; if only started, XA END + XA ROLLBACK
189+
if (prepared_.count(p) && prepared_[p]) {
190+
std::string sql = "XA ROLLBACK '" + txn_id_ + "'";
191+
send_sql(p.c_str(), sql);
192+
} else if (started_.count(p)) {
193+
// Try XA END first (may fail if already ended)
194+
std::string end_sql = "XA END '" + txn_id_ + "'";
195+
send_sql(p.c_str(), end_sql);
196+
std::string rb_sql = "XA ROLLBACK '" + txn_id_ + "'";
197+
send_sql(p.c_str(), rb_sql);
198+
}
199+
} else {
200+
if (prepared_.count(p) && prepared_[p]) {
201+
std::string sql = "ROLLBACK PREPARED '" + txn_id_ + "'";
202+
send_sql(p.c_str(), sql);
203+
} else if (started_.count(p)) {
204+
send_sql(p.c_str(), "ROLLBACK");
205+
}
206+
}
207+
}
208+
}
209+
};
210+
211+
} // namespace sql_engine
212+
213+
#endif // SQL_ENGINE_DISTRIBUTED_TXN_H

0 commit comments

Comments
 (0)