|
| 1 | +#ifndef SQL_ENGINE_DISTRIBUTED_TXN_H |
| 2 | +#define SQL_ENGINE_DISTRIBUTED_TXN_H |
| 3 | + |
| 4 | +#include "sql_engine/transaction_manager.h" |
| 5 | +#include "sql_engine/remote_executor.h" |
| 6 | +#include "sql_engine/shard_map.h" |
| 7 | +#include "sql_parser/common.h" |
| 8 | + |
| 9 | +#include <string> |
| 10 | +#include <vector> |
| 11 | +#include <unordered_map> |
| 12 | +#include <unordered_set> |
| 13 | +#include <chrono> |
| 14 | +#include <random> |
| 15 | +#include <cstdio> |
| 16 | + |
| 17 | +namespace sql_engine { |
| 18 | + |
| 19 | +// DistributedTransactionManager implements two-phase commit (2PC) for |
| 20 | +// transactions spanning multiple backends. |
| 21 | +// |
| 22 | +// MySQL XA protocol: |
| 23 | +// XA START 'txn_id' → begin on each participant |
| 24 | +// XA END 'txn_id' → mark end of work |
| 25 | +// XA PREPARE 'txn_id' → phase 1 |
| 26 | +// XA COMMIT 'txn_id' → phase 2 (success) |
| 27 | +// XA ROLLBACK 'txn_id' → phase 2 (failure) |
| 28 | +// |
| 29 | +// PostgreSQL: |
| 30 | +// BEGIN → PREPARE TRANSACTION 'txn_id' → COMMIT PREPARED 'txn_id' |
| 31 | +class DistributedTransactionManager : public TransactionManager { |
| 32 | +public: |
| 33 | + // Backend dialect for 2PC protocol selection |
| 34 | + enum class BackendDialect : uint8_t { MYSQL, POSTGRESQL }; |
| 35 | + |
| 36 | + DistributedTransactionManager(RemoteExecutor& executor, |
| 37 | + BackendDialect dialect = BackendDialect::MYSQL) |
| 38 | + : executor_(executor), dialect_(dialect) {} |
| 39 | + |
| 40 | + bool begin() override { |
| 41 | + txn_id_ = generate_txn_id(); |
| 42 | + participants_.clear(); |
| 43 | + prepared_.clear(); |
| 44 | + started_.clear(); |
| 45 | + active_ = true; |
| 46 | + return true; |
| 47 | + } |
| 48 | + |
| 49 | + // Enlist a backend as a transaction participant. Called when DML is |
| 50 | + // executed against a backend. Sends XA START / BEGIN to the backend |
| 51 | + // if not already enlisted. |
| 52 | + bool enlist_backend(const char* backend_name) { |
| 53 | + if (!active_) return false; |
| 54 | + std::string name(backend_name); |
| 55 | + if (started_.count(name)) return true; // already enlisted |
| 56 | + |
| 57 | + bool ok = false; |
| 58 | + if (dialect_ == BackendDialect::MYSQL) { |
| 59 | + std::string sql = "XA START '" + txn_id_ + "'"; |
| 60 | + ok = send_sql(backend_name, sql); |
| 61 | + } else { |
| 62 | + ok = send_sql(backend_name, "BEGIN"); |
| 63 | + } |
| 64 | + if (ok) { |
| 65 | + participants_.push_back(name); |
| 66 | + started_.insert(name); |
| 67 | + } |
| 68 | + return ok; |
| 69 | + } |
| 70 | + |
| 71 | + bool commit() override { |
| 72 | + if (!active_) return false; |
| 73 | + if (participants_.empty()) { |
| 74 | + active_ = false; |
| 75 | + return true; |
| 76 | + } |
| 77 | + |
| 78 | + // Phase 1: prepare all participants |
| 79 | + if (!phase1_prepare()) { |
| 80 | + phase2_rollback(); |
| 81 | + active_ = false; |
| 82 | + return false; |
| 83 | + } |
| 84 | + |
| 85 | + // Phase 2: commit all participants |
| 86 | + bool ok = phase2_commit(); |
| 87 | + active_ = false; |
| 88 | + return ok; |
| 89 | + } |
| 90 | + |
| 91 | + bool rollback() override { |
| 92 | + if (!active_) return false; |
| 93 | + phase2_rollback(); |
| 94 | + active_ = false; |
| 95 | + return true; |
| 96 | + } |
| 97 | + |
| 98 | + // Savepoints are not supported for distributed transactions. |
| 99 | + bool savepoint(const char*) override { return false; } |
| 100 | + bool rollback_to(const char*) override { return false; } |
| 101 | + bool release_savepoint(const char*) override { return false; } |
| 102 | + |
| 103 | + bool in_transaction() const override { return active_; } |
| 104 | + bool is_auto_commit() const override { return auto_commit_; } |
| 105 | + void set_auto_commit(bool ac) override { auto_commit_ = ac; } |
| 106 | + |
| 107 | + const std::string& txn_id() const { return txn_id_; } |
| 108 | + const std::vector<std::string>& participants() const { return participants_; } |
| 109 | + |
| 110 | +private: |
| 111 | + RemoteExecutor& executor_; |
| 112 | + BackendDialect dialect_; |
| 113 | + |
| 114 | + std::string txn_id_; |
| 115 | + std::vector<std::string> participants_; |
| 116 | + std::unordered_set<std::string> started_; |
| 117 | + std::unordered_map<std::string, bool> prepared_; |
| 118 | + bool active_ = false; |
| 119 | + bool auto_commit_ = true; |
| 120 | + |
| 121 | + // Generate a unique transaction ID. |
| 122 | + static std::string generate_txn_id() { |
| 123 | + auto now = std::chrono::steady_clock::now(); |
| 124 | + auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>( |
| 125 | + now.time_since_epoch()).count(); |
| 126 | + // Use random suffix to avoid collisions |
| 127 | + static thread_local std::mt19937 rng( |
| 128 | + static_cast<unsigned>(std::chrono::system_clock::now() |
| 129 | + .time_since_epoch().count())); |
| 130 | + std::uniform_int_distribution<uint32_t> dist(0, 999999); |
| 131 | + char buf[64]; |
| 132 | + std::snprintf(buf, sizeof(buf), "parsersql_%ld_%06u", |
| 133 | + static_cast<long>(ns), dist(rng)); |
| 134 | + return buf; |
| 135 | + } |
| 136 | + |
| 137 | + bool send_sql(const char* backend, const std::string& sql) { |
| 138 | + auto r = executor_.execute_dml(backend, |
| 139 | + sql_parser::StringRef{sql.c_str(), |
| 140 | + static_cast<uint32_t>(sql.size())}); |
| 141 | + return r.success; |
| 142 | + } |
| 143 | + |
| 144 | + // Phase 1: XA END + XA PREPARE on all participants (MySQL) |
| 145 | + // PREPARE TRANSACTION on all participants (PostgreSQL) |
| 146 | + bool phase1_prepare() { |
| 147 | + bool all_ok = true; |
| 148 | + for (auto& p : participants_) { |
| 149 | + bool ok = false; |
| 150 | + if (dialect_ == BackendDialect::MYSQL) { |
| 151 | + std::string end_sql = "XA END '" + txn_id_ + "'"; |
| 152 | + ok = send_sql(p.c_str(), end_sql); |
| 153 | + if (ok) { |
| 154 | + std::string prep_sql = "XA PREPARE '" + txn_id_ + "'"; |
| 155 | + ok = send_sql(p.c_str(), prep_sql); |
| 156 | + } |
| 157 | + } else { |
| 158 | + std::string prep_sql = "PREPARE TRANSACTION '" + txn_id_ + "'"; |
| 159 | + ok = send_sql(p.c_str(), prep_sql); |
| 160 | + } |
| 161 | + prepared_[p] = ok; |
| 162 | + if (!ok) all_ok = false; |
| 163 | + } |
| 164 | + return all_ok; |
| 165 | + } |
| 166 | + |
| 167 | + // Phase 2 (success): XA COMMIT on all participants |
| 168 | + bool phase2_commit() { |
| 169 | + bool all_ok = true; |
| 170 | + for (auto& p : participants_) { |
| 171 | + bool ok = false; |
| 172 | + if (dialect_ == BackendDialect::MYSQL) { |
| 173 | + std::string sql = "XA COMMIT '" + txn_id_ + "'"; |
| 174 | + ok = send_sql(p.c_str(), sql); |
| 175 | + } else { |
| 176 | + std::string sql = "COMMIT PREPARED '" + txn_id_ + "'"; |
| 177 | + ok = send_sql(p.c_str(), sql); |
| 178 | + } |
| 179 | + if (!ok) all_ok = false; |
| 180 | + } |
| 181 | + return all_ok; |
| 182 | + } |
| 183 | + |
| 184 | + // Phase 2 (failure): XA ROLLBACK on all participants |
| 185 | + void phase2_rollback() { |
| 186 | + for (auto& p : participants_) { |
| 187 | + if (dialect_ == BackendDialect::MYSQL) { |
| 188 | + // If prepared, XA ROLLBACK; if only started, XA END + XA ROLLBACK |
| 189 | + if (prepared_.count(p) && prepared_[p]) { |
| 190 | + std::string sql = "XA ROLLBACK '" + txn_id_ + "'"; |
| 191 | + send_sql(p.c_str(), sql); |
| 192 | + } else if (started_.count(p)) { |
| 193 | + // Try XA END first (may fail if already ended) |
| 194 | + std::string end_sql = "XA END '" + txn_id_ + "'"; |
| 195 | + send_sql(p.c_str(), end_sql); |
| 196 | + std::string rb_sql = "XA ROLLBACK '" + txn_id_ + "'"; |
| 197 | + send_sql(p.c_str(), rb_sql); |
| 198 | + } |
| 199 | + } else { |
| 200 | + if (prepared_.count(p) && prepared_[p]) { |
| 201 | + std::string sql = "ROLLBACK PREPARED '" + txn_id_ + "'"; |
| 202 | + send_sql(p.c_str(), sql); |
| 203 | + } else if (started_.count(p)) { |
| 204 | + send_sql(p.c_str(), "ROLLBACK"); |
| 205 | + } |
| 206 | + } |
| 207 | + } |
| 208 | + } |
| 209 | +}; |
| 210 | + |
| 211 | +} // namespace sql_engine |
| 212 | + |
| 213 | +#endif // SQL_ENGINE_DISTRIBUTED_TXN_H |
0 commit comments